diff --git a/inference/__init__.py b/inference/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8dc9d99aed8cfddaa89a90611a9204ea1f41fcea --- /dev/null +++ b/inference/__init__.py @@ -0,0 +1,3 @@ +from inference.core.interfaces.stream.stream import Stream # isort:skip +from inference.core.interfaces.stream.inference_pipeline import InferencePipeline +from inference.models.utils import get_roboflow_model diff --git a/inference/__pycache__/__init__.cpython-310.pyc b/inference/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6857db7b28b1c03433cebff6ee36f88317ec95c1 Binary files /dev/null and b/inference/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/__init__.py b/inference/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64b2f626eeba9988cbaac3abcb6cec4a67751417 --- /dev/null +++ b/inference/core/__init__.py @@ -0,0 +1,52 @@ +import threading +import time + +import requests + +from inference.core.env import DISABLE_VERSION_CHECK, VERSION_CHECK_MODE +from inference.core.logger import logger +from inference.core.version import __version__ + +latest_release = None +last_checked = 0 +cache_duration = 86400 # 24 hours +log_frequency = 300 # 5 minutes + + +def get_latest_release_version(): + global latest_release, last_checked + now = time.time() + if latest_release is None or now - last_checked > cache_duration: + try: + logger.debug("Checking for latest inference release version...") + response = requests.get( + "https://api.github.com/repos/roboflow/inference/releases/latest" + ) + response.raise_for_status() + latest_release = response.json()["tag_name"].lstrip("v") + last_checked = now + except requests.exceptions.RequestException: + pass + + +def check_latest_release_against_current(): + get_latest_release_version() + if latest_release is not None and latest_release != __version__: + logger.warning( + f"Your inference package version {__version__} is out of date! Please upgrade to version {latest_release} of inference for the latest features and bug fixes by running `pip install --upgrade inference`." + ) + + +def check_latest_release_against_current_continuous(): + while True: + check_latest_release_against_current() + time.sleep(log_frequency) + + +if not DISABLE_VERSION_CHECK: + if VERSION_CHECK_MODE == "continuous": + t = threading.Thread(target=check_latest_release_against_current_continuous) + t.daemon = True + t.start() + else: + check_latest_release_against_current() diff --git a/inference/core/__pycache__/__init__.cpython-310.pyc b/inference/core/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec0c9bc57bbc77acae72c0b776e47eb4307e5118 Binary files /dev/null and b/inference/core/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/__pycache__/constants.cpython-310.pyc b/inference/core/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0adfbd176a156186871039a8fc262aeae5599fd4 Binary files /dev/null and b/inference/core/__pycache__/constants.cpython-310.pyc differ diff --git a/inference/core/__pycache__/env.cpython-310.pyc b/inference/core/__pycache__/env.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b751fbc77b9a5d628edabc29abaed9bdb6f7a10a Binary files /dev/null and b/inference/core/__pycache__/env.cpython-310.pyc differ diff --git a/inference/core/__pycache__/exceptions.cpython-310.pyc b/inference/core/__pycache__/exceptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9dfa9c4823bbf7e42f4e4f3a45368afd2d5e2c96 Binary files /dev/null and b/inference/core/__pycache__/exceptions.cpython-310.pyc differ diff --git a/inference/core/__pycache__/logger.cpython-310.pyc b/inference/core/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8088efc30a21f9e6c9903132d118d74ec427ffbf Binary files /dev/null and b/inference/core/__pycache__/logger.cpython-310.pyc differ diff --git a/inference/core/__pycache__/nms.cpython-310.pyc b/inference/core/__pycache__/nms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3d03dc61212468f60dca4957ad42411203c914c Binary files /dev/null and b/inference/core/__pycache__/nms.cpython-310.pyc differ diff --git a/inference/core/__pycache__/roboflow_api.cpython-310.pyc b/inference/core/__pycache__/roboflow_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae7ade9b88cb1657976937d7100f3e55cb9e6add Binary files /dev/null and b/inference/core/__pycache__/roboflow_api.cpython-310.pyc differ diff --git a/inference/core/__pycache__/usage.cpython-310.pyc b/inference/core/__pycache__/usage.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b1f4c151a899104b75ff39e0bbf2b1f2a0c8ed Binary files /dev/null and b/inference/core/__pycache__/usage.cpython-310.pyc differ diff --git a/inference/core/__pycache__/version.cpython-310.pyc b/inference/core/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9adc358f9a6e51e935bd6d5a82afd28572377e61 Binary files /dev/null and b/inference/core/__pycache__/version.cpython-310.pyc differ diff --git a/inference/core/active_learning/__init__.py b/inference/core/active_learning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/active_learning/__pycache__/__init__.cpython-310.pyc b/inference/core/active_learning/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bc335c76e9bc42a332035c24855b1f381835ba6 Binary files /dev/null and b/inference/core/active_learning/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/active_learning/__pycache__/accounting.cpython-310.pyc b/inference/core/active_learning/__pycache__/accounting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e865ba0ed2772b052007c908c3c40c623ba860b Binary files /dev/null and b/inference/core/active_learning/__pycache__/accounting.cpython-310.pyc differ diff --git a/inference/core/active_learning/__pycache__/batching.cpython-310.pyc b/inference/core/active_learning/__pycache__/batching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd01d9f1676c0c13234af59b1ec42897e2071d4a Binary files /dev/null and b/inference/core/active_learning/__pycache__/batching.cpython-310.pyc differ diff --git a/inference/core/active_learning/__pycache__/cache_operations.cpython-310.pyc b/inference/core/active_learning/__pycache__/cache_operations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..989996029091a434989e2091139a08e0d23c970a Binary files /dev/null and b/inference/core/active_learning/__pycache__/cache_operations.cpython-310.pyc differ diff --git a/inference/core/active_learning/__pycache__/configuration.cpython-310.pyc b/inference/core/active_learning/__pycache__/configuration.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..962ad3a880ff9767d7ccd25c86299d401f4cd94b Binary files /dev/null and b/inference/core/active_learning/__pycache__/configuration.cpython-310.pyc differ diff --git a/inference/core/active_learning/__pycache__/core.cpython-310.pyc b/inference/core/active_learning/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03c76efd2cab3bb82806db5870767abd5facc244 Binary files /dev/null and b/inference/core/active_learning/__pycache__/core.cpython-310.pyc differ diff --git a/inference/core/active_learning/__pycache__/entities.cpython-310.pyc b/inference/core/active_learning/__pycache__/entities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41ec0eab9b7032a40c332f493ebfeed1c4bf90ad Binary files /dev/null and b/inference/core/active_learning/__pycache__/entities.cpython-310.pyc differ diff --git a/inference/core/active_learning/__pycache__/middlewares.cpython-310.pyc b/inference/core/active_learning/__pycache__/middlewares.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4c150914ca4d7865482ae182947bf15f7b6496b Binary files /dev/null and b/inference/core/active_learning/__pycache__/middlewares.cpython-310.pyc differ diff --git a/inference/core/active_learning/__pycache__/post_processing.cpython-310.pyc b/inference/core/active_learning/__pycache__/post_processing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b40080a736ee5ebb71733f2f407efe73b2fa2c63 Binary files /dev/null and b/inference/core/active_learning/__pycache__/post_processing.cpython-310.pyc differ diff --git a/inference/core/active_learning/__pycache__/utils.cpython-310.pyc b/inference/core/active_learning/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86fc2ee62ce44d912ceaaf84eaad64a47061918c Binary files /dev/null and b/inference/core/active_learning/__pycache__/utils.cpython-310.pyc differ diff --git a/inference/core/active_learning/accounting.py b/inference/core/active_learning/accounting.py new file mode 100644 index 0000000000000000000000000000000000000000..c65fa110f7a4ccd49864fc030abd8c81bd069472 --- /dev/null +++ b/inference/core/active_learning/accounting.py @@ -0,0 +1,96 @@ +from typing import List, Optional + +from inference.core.entities.types import DatasetID, WorkspaceID +from inference.core.roboflow_api import ( + get_roboflow_labeling_batches, + get_roboflow_labeling_jobs, +) + + +def image_can_be_submitted_to_batch( + batch_name: str, + workspace_id: WorkspaceID, + dataset_id: DatasetID, + max_batch_images: Optional[int], + api_key: str, +) -> bool: + """Check if an image can be submitted to a batch. + + Args: + batch_name: Name of the batch. + workspace_id: ID of the workspace. + dataset_id: ID of the dataset. + max_batch_images: Maximum number of images allowed in the batch. + api_key: API key to use for the request. + + Returns: + True if the image can be submitted to the batch, False otherwise. + """ + if max_batch_images is None: + return True + labeling_batches = get_roboflow_labeling_batches( + api_key=api_key, + workspace_id=workspace_id, + dataset_id=dataset_id, + ) + matching_labeling_batch = get_matching_labeling_batch( + all_labeling_batches=labeling_batches["batches"], + batch_name=batch_name, + ) + if matching_labeling_batch is None: + return max_batch_images > 0 + batch_images_under_labeling = 0 + if matching_labeling_batch["numJobs"] > 0: + labeling_jobs = get_roboflow_labeling_jobs( + api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id + ) + batch_images_under_labeling = get_images_in_labeling_jobs_of_specific_batch( + all_labeling_jobs=labeling_jobs["jobs"], + batch_id=matching_labeling_batch["id"], + ) + total_batch_images = matching_labeling_batch["images"] + batch_images_under_labeling + return max_batch_images > total_batch_images + + +def get_matching_labeling_batch( + all_labeling_batches: List[dict], + batch_name: str, +) -> Optional[dict]: + """Get the matching labeling batch. + + Args: + all_labeling_batches: All labeling batches. + batch_name: Name of the batch. + + Returns: + The matching labeling batch if found, None otherwise. + + """ + matching_batch = None + for labeling_batch in all_labeling_batches: + if labeling_batch["name"] == batch_name: + matching_batch = labeling_batch + break + return matching_batch + + +def get_images_in_labeling_jobs_of_specific_batch( + all_labeling_jobs: List[dict], + batch_id: str, +) -> int: + """Get the number of images in labeling jobs of a specific batch. + + Args: + all_labeling_jobs: All labeling jobs. + batch_id: ID of the batch. + + Returns: + The number of images in labeling jobs of the batch. + + """ + + matching_jobs = [] + for labeling_job in all_labeling_jobs: + if batch_id in labeling_job["sourceBatch"]: + matching_jobs.append(labeling_job) + return sum(job["numImages"] for job in matching_jobs) diff --git a/inference/core/active_learning/batching.py b/inference/core/active_learning/batching.py new file mode 100644 index 0000000000000000000000000000000000000000..3c2cd2f50d76baaf0bc70c0efcaec983f72d01d3 --- /dev/null +++ b/inference/core/active_learning/batching.py @@ -0,0 +1,26 @@ +from inference.core.active_learning.entities import ( + ActiveLearningConfiguration, + BatchReCreationInterval, +) +from inference.core.active_learning.utils import ( + generate_start_timestamp_for_this_month, + generate_start_timestamp_for_this_week, + generate_today_timestamp, +) + +RECREATION_INTERVAL2TIMESTAMP_GENERATOR = { + BatchReCreationInterval.DAILY: generate_today_timestamp, + BatchReCreationInterval.WEEKLY: generate_start_timestamp_for_this_week, + BatchReCreationInterval.MONTHLY: generate_start_timestamp_for_this_month, +} + + +def generate_batch_name(configuration: ActiveLearningConfiguration) -> str: + batch_name = configuration.batches_name_prefix + if configuration.batch_recreation_interval is BatchReCreationInterval.NEVER: + return batch_name + timestamp_generator = RECREATION_INTERVAL2TIMESTAMP_GENERATOR[ + configuration.batch_recreation_interval + ] + timestamp = timestamp_generator() + return f"{batch_name}_{timestamp}" diff --git a/inference/core/active_learning/cache_operations.py b/inference/core/active_learning/cache_operations.py new file mode 100644 index 0000000000000000000000000000000000000000..af916744689061a89d5999e5a1ce26d8f1f2f096 --- /dev/null +++ b/inference/core/active_learning/cache_operations.py @@ -0,0 +1,293 @@ +import threading +from contextlib import contextmanager +from datetime import datetime +from typing import Generator, List, Optional, OrderedDict, Union + +import redis.lock + +from inference.core import logger +from inference.core.active_learning.entities import StrategyLimit, StrategyLimitType +from inference.core.active_learning.utils import TIMESTAMP_FORMAT +from inference.core.cache.base import BaseCache + +MAX_LOCK_TIME = 5 +SECONDS_IN_HOUR = 60 * 60 +USAGE_KEY = "usage" + +LIMIT_TYPE2KEY_INFIX_GENERATOR = { + StrategyLimitType.MINUTELY: lambda: f"minute_{datetime.utcnow().minute}", + StrategyLimitType.HOURLY: lambda: f"hour_{datetime.utcnow().hour}", + StrategyLimitType.DAILY: lambda: f"day_{datetime.utcnow().strftime(TIMESTAMP_FORMAT)}", +} +LIMIT_TYPE2KEY_EXPIRATION = { + StrategyLimitType.MINUTELY: 120, + StrategyLimitType.HOURLY: 2 * SECONDS_IN_HOUR, + StrategyLimitType.DAILY: 25 * SECONDS_IN_HOUR, +} + + +def use_credit_of_matching_strategy( + cache: BaseCache, + workspace: str, + project: str, + matching_strategies_limits: OrderedDict[str, List[StrategyLimit]], +) -> Optional[str]: + # In scope of this function, cache keys updates regarding usage limits for + # specific :workspace and :project are locked - to ensure increment to be done atomically + # Limits are accounted at the moment of registration - which may introduce inaccuracy + # given that registration is postponed from prediction + # Returns: strategy with spare credit if found - else None + with lock_limits(cache=cache, workspace=workspace, project=project): + strategy_with_spare_credit = find_strategy_with_spare_usage_credit( + cache=cache, + workspace=workspace, + project=project, + matching_strategies_limits=matching_strategies_limits, + ) + if strategy_with_spare_credit is None: + return None + consume_strategy_limits_usage_credit( + cache=cache, + workspace=workspace, + project=project, + strategy_name=strategy_with_spare_credit, + ) + return strategy_with_spare_credit + + +def return_strategy_credit( + cache: BaseCache, + workspace: str, + project: str, + strategy_name: str, +) -> None: + # In scope of this function, cache keys updates regarding usage limits for + # specific :workspace and :project are locked - to ensure decrement to be done atomically + # Returning strategy is a bit naive (we may add to a pool of credits from the next period - but only + # if we have previously taken from the previous one and some credits are used in the new pool) - + # in favour of easier implementation. + with lock_limits(cache=cache, workspace=workspace, project=project): + return_strategy_limits_usage_credit( + cache=cache, + workspace=workspace, + project=project, + strategy_name=strategy_name, + ) + + +@contextmanager +def lock_limits( + cache: BaseCache, + workspace: str, + project: str, +) -> Generator[Union[threading.Lock, redis.lock.Lock], None, None]: + limits_lock_key = generate_cache_key_for_active_learning_usage_lock( + workspace=workspace, + project=project, + ) + with cache.lock(key=limits_lock_key, expire=MAX_LOCK_TIME) as lock: + yield lock + + +def find_strategy_with_spare_usage_credit( + cache: BaseCache, + workspace: str, + project: str, + matching_strategies_limits: OrderedDict[str, List[StrategyLimit]], +) -> Optional[str]: + for strategy_name, strategy_limits in matching_strategies_limits.items(): + rejected_by_strategy = ( + datapoint_should_be_rejected_based_on_strategy_usage_limits( + cache=cache, + workspace=workspace, + project=project, + strategy_name=strategy_name, + strategy_limits=strategy_limits, + ) + ) + if not rejected_by_strategy: + return strategy_name + return None + + +def datapoint_should_be_rejected_based_on_strategy_usage_limits( + cache: BaseCache, + workspace: str, + project: str, + strategy_name: str, + strategy_limits: List[StrategyLimit], +) -> bool: + for strategy_limit in strategy_limits: + limit_reached = datapoint_should_be_rejected_based_on_limit_usage( + cache=cache, + workspace=workspace, + project=project, + strategy_name=strategy_name, + strategy_limit=strategy_limit, + ) + if limit_reached: + logger.debug( + f"Violated Active Learning strategy limit: {strategy_limit.limit_type.name} " + f"with value {strategy_limit.value} for sampling strategy: {strategy_name}." + ) + return True + return False + + +def datapoint_should_be_rejected_based_on_limit_usage( + cache: BaseCache, + workspace: str, + project: str, + strategy_name: str, + strategy_limit: StrategyLimit, +) -> bool: + current_usage = get_current_strategy_limit_usage( + cache=cache, + workspace=workspace, + project=project, + strategy_name=strategy_name, + limit_type=strategy_limit.limit_type, + ) + if current_usage is None: + current_usage = 0 + return current_usage >= strategy_limit.value + + +def consume_strategy_limits_usage_credit( + cache: BaseCache, + workspace: str, + project: str, + strategy_name: str, +) -> None: + for limit_type in StrategyLimitType: + consume_strategy_limit_usage_credit( + cache=cache, + workspace=workspace, + project=project, + strategy_name=strategy_name, + limit_type=limit_type, + ) + + +def consume_strategy_limit_usage_credit( + cache: BaseCache, + workspace: str, + project: str, + strategy_name: str, + limit_type: StrategyLimitType, +) -> None: + current_value = get_current_strategy_limit_usage( + cache=cache, + limit_type=limit_type, + workspace=workspace, + project=project, + strategy_name=strategy_name, + ) + if current_value is None: + current_value = 0 + current_value += 1 + set_current_strategy_limit_usage( + current_value=current_value, + cache=cache, + limit_type=limit_type, + workspace=workspace, + project=project, + strategy_name=strategy_name, + ) + + +def return_strategy_limits_usage_credit( + cache: BaseCache, + workspace: str, + project: str, + strategy_name: str, +) -> None: + for limit_type in StrategyLimitType: + return_strategy_limit_usage_credit( + cache=cache, + workspace=workspace, + project=project, + strategy_name=strategy_name, + limit_type=limit_type, + ) + + +def return_strategy_limit_usage_credit( + cache: BaseCache, + workspace: str, + project: str, + strategy_name: str, + limit_type: StrategyLimitType, +) -> None: + current_value = get_current_strategy_limit_usage( + cache=cache, + limit_type=limit_type, + workspace=workspace, + project=project, + strategy_name=strategy_name, + ) + if current_value is None: + return None + current_value = max(current_value - 1, 0) + set_current_strategy_limit_usage( + current_value=current_value, + cache=cache, + limit_type=limit_type, + workspace=workspace, + project=project, + strategy_name=strategy_name, + ) + + +def get_current_strategy_limit_usage( + cache: BaseCache, + workspace: str, + project: str, + strategy_name: str, + limit_type: StrategyLimitType, +) -> Optional[int]: + usage_key = generate_cache_key_for_active_learning_usage( + limit_type=limit_type, + workspace=workspace, + project=project, + strategy_name=strategy_name, + ) + value = cache.get(usage_key) + if value is None: + return value + return value[USAGE_KEY] + + +def set_current_strategy_limit_usage( + current_value: int, + cache: BaseCache, + workspace: str, + project: str, + strategy_name: str, + limit_type: StrategyLimitType, +) -> None: + usage_key = generate_cache_key_for_active_learning_usage( + limit_type=limit_type, + workspace=workspace, + project=project, + strategy_name=strategy_name, + ) + expire = LIMIT_TYPE2KEY_EXPIRATION[limit_type] + cache.set(key=usage_key, value={USAGE_KEY: current_value}, expire=expire) # type: ignore + + +def generate_cache_key_for_active_learning_usage_lock( + workspace: str, + project: str, +) -> str: + return f"active_learning:usage:{workspace}:{project}:usage:lock" + + +def generate_cache_key_for_active_learning_usage( + limit_type: StrategyLimitType, + workspace: str, + project: str, + strategy_name: str, +) -> str: + time_infix = LIMIT_TYPE2KEY_INFIX_GENERATOR[limit_type]() + return f"active_learning:usage:{workspace}:{project}:{strategy_name}:{time_infix}" diff --git a/inference/core/active_learning/configuration.py b/inference/core/active_learning/configuration.py new file mode 100644 index 0000000000000000000000000000000000000000..86274cf39507e3d036893997eb366e37c739dcd7 --- /dev/null +++ b/inference/core/active_learning/configuration.py @@ -0,0 +1,203 @@ +import hashlib +from dataclasses import asdict +from typing import Any, Dict, List, Optional + +from inference.core import logger +from inference.core.active_learning.entities import ( + ActiveLearningConfiguration, + RoboflowProjectMetadata, + SamplingMethod, +) +from inference.core.active_learning.samplers.close_to_threshold import ( + initialize_close_to_threshold_sampling, +) +from inference.core.active_learning.samplers.contains_classes import ( + initialize_classes_based_sampling, +) +from inference.core.active_learning.samplers.number_of_detections import ( + initialize_detections_number_based_sampling, +) +from inference.core.active_learning.samplers.random import initialize_random_sampling +from inference.core.cache.base import BaseCache +from inference.core.exceptions import ( + ActiveLearningConfigurationDecodingError, + ActiveLearningConfigurationError, + RoboflowAPINotAuthorizedError, + RoboflowAPINotNotFoundError, +) +from inference.core.roboflow_api import ( + get_roboflow_active_learning_configuration, + get_roboflow_dataset_type, + get_roboflow_workspace, +) +from inference.core.utils.roboflow import get_model_id_chunks + +TYPE2SAMPLING_INITIALIZERS = { + "random": initialize_random_sampling, + "close_to_threshold": initialize_close_to_threshold_sampling, + "classes_based": initialize_classes_based_sampling, + "detections_number_based": initialize_detections_number_based_sampling, +} +ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE = 900 # 15 min + + +def prepare_active_learning_configuration( + api_key: str, + model_id: str, + cache: BaseCache, +) -> Optional[ActiveLearningConfiguration]: + project_metadata = get_roboflow_project_metadata( + api_key=api_key, + model_id=model_id, + cache=cache, + ) + if not project_metadata.active_learning_configuration.get("enabled", False): + return None + logger.info( + f"Configuring active learning for workspace: {project_metadata.workspace_id}, " + f"project: {project_metadata.dataset_id} of type: {project_metadata.dataset_type}. " + f"AL configuration: {project_metadata.active_learning_configuration}" + ) + return initialise_active_learning_configuration( + project_metadata=project_metadata, + ) + + +def prepare_active_learning_configuration_inplace( + api_key: str, + model_id: str, + active_learning_configuration: Optional[dict], +) -> Optional[ActiveLearningConfiguration]: + if ( + active_learning_configuration is None + or active_learning_configuration.get("enabled", False) is False + ): + return None + dataset_id, version_id = get_model_id_chunks(model_id=model_id) + workspace_id = get_roboflow_workspace(api_key=api_key) + dataset_type = get_roboflow_dataset_type( + api_key=api_key, + workspace_id=workspace_id, + dataset_id=dataset_id, + ) + project_metadata = RoboflowProjectMetadata( + dataset_id=dataset_id, + version_id=version_id, + workspace_id=workspace_id, + dataset_type=dataset_type, + active_learning_configuration=active_learning_configuration, + ) + return initialise_active_learning_configuration( + project_metadata=project_metadata, + ) + + +def get_roboflow_project_metadata( + api_key: str, + model_id: str, + cache: BaseCache, +) -> RoboflowProjectMetadata: + logger.info(f"Fetching active learning configuration.") + config_cache_key = construct_cache_key_for_active_learning_config( + api_key=api_key, model_id=model_id + ) + cached_config = cache.get(config_cache_key) + if cached_config is not None: + logger.info("Found Active Learning configuration in cache.") + return parse_cached_roboflow_project_metadata(cached_config=cached_config) + dataset_id, version_id = get_model_id_chunks(model_id=model_id) + workspace_id = get_roboflow_workspace(api_key=api_key) + dataset_type = get_roboflow_dataset_type( + api_key=api_key, + workspace_id=workspace_id, + dataset_id=dataset_id, + ) + try: + roboflow_api_configuration = get_roboflow_active_learning_configuration( + api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id + ) + except (RoboflowAPINotAuthorizedError, RoboflowAPINotNotFoundError): + # currently backend returns HTTP 404 if dataset does not exist + # or workspace_id from api_key indicate that the owner is different, + # so in the situation when we query for Universe dataset. + # We want the owner of public dataset to be able to set AL configs + # and use them, but not other people. At this point it's known + # that HTTP 404 means not authorised (which will probably change + # in future iteration of backend) - so on both NotAuth and NotFound + # errors we assume that we simply cannot use AL with this model and + # this api_key. + roboflow_api_configuration = {"enabled": False} + configuration = RoboflowProjectMetadata( + dataset_id=dataset_id, + version_id=version_id, + workspace_id=workspace_id, + dataset_type=dataset_type, + active_learning_configuration=roboflow_api_configuration, + ) + cache.set( + key=config_cache_key, + value=asdict(configuration), + expire=ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE, + ) + return configuration + + +def construct_cache_key_for_active_learning_config(api_key: str, model_id: str) -> str: + dataset_id = model_id.split("/")[0] + api_key_hash = hashlib.md5(api_key.encode("utf-8")).hexdigest() + return f"active_learning:configurations:{api_key_hash}:{dataset_id}" + + +def parse_cached_roboflow_project_metadata( + cached_config: dict, +) -> RoboflowProjectMetadata: + try: + return RoboflowProjectMetadata(**cached_config) + except Exception as error: + raise ActiveLearningConfigurationDecodingError( + f"Failed to initialise Active Learning configuration. Cause: {str(error)}" + ) from error + + +def initialise_active_learning_configuration( + project_metadata: RoboflowProjectMetadata, +) -> ActiveLearningConfiguration: + sampling_methods = initialize_sampling_methods( + sampling_strategies_configs=project_metadata.active_learning_configuration[ + "sampling_strategies" + ], + ) + target_workspace_id = project_metadata.active_learning_configuration.get( + "target_workspace", project_metadata.workspace_id + ) + target_dataset_id = project_metadata.active_learning_configuration.get( + "target_project", project_metadata.dataset_id + ) + return ActiveLearningConfiguration.init( + roboflow_api_configuration=project_metadata.active_learning_configuration, + sampling_methods=sampling_methods, + workspace_id=target_workspace_id, + dataset_id=target_dataset_id, + model_id=f"{project_metadata.dataset_id}/{project_metadata.version_id}", + ) + + +def initialize_sampling_methods( + sampling_strategies_configs: List[Dict[str, Any]] +) -> List[SamplingMethod]: + result = [] + for sampling_strategy_config in sampling_strategies_configs: + sampling_type = sampling_strategy_config["type"] + if sampling_type not in TYPE2SAMPLING_INITIALIZERS: + logger.warn( + f"Could not identify sampling method `{sampling_type}` - skipping initialisation." + ) + continue + initializer = TYPE2SAMPLING_INITIALIZERS[sampling_type] + result.append(initializer(sampling_strategy_config)) + names = set(m.name for m in result) + if len(names) != len(result): + raise ActiveLearningConfigurationError( + "Detected duplication of Active Learning strategies names." + ) + return result diff --git a/inference/core/active_learning/core.py b/inference/core/active_learning/core.py new file mode 100644 index 0000000000000000000000000000000000000000..55a26596d9fff334d1608b5ddbd22b1d73a27056 --- /dev/null +++ b/inference/core/active_learning/core.py @@ -0,0 +1,219 @@ +from collections import OrderedDict +from typing import List, Optional, Tuple +from uuid import uuid4 + +import numpy as np + +from inference.core import logger +from inference.core.active_learning.cache_operations import ( + return_strategy_credit, + use_credit_of_matching_strategy, +) +from inference.core.active_learning.entities import ( + ActiveLearningConfiguration, + ImageDimensions, + Prediction, + PredictionType, + SamplingMethod, +) +from inference.core.active_learning.post_processing import ( + adjust_prediction_to_client_scaling_factor, + encode_prediction, +) +from inference.core.cache.base import BaseCache +from inference.core.env import ACTIVE_LEARNING_TAGS +from inference.core.roboflow_api import ( + annotate_image_at_roboflow, + register_image_at_roboflow, +) +from inference.core.utils.image_utils import encode_image_to_jpeg_bytes +from inference.core.utils.preprocess import downscale_image_keeping_aspect_ratio + + +def execute_sampling( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + sampling_methods: List[SamplingMethod], +) -> List[str]: + matching_strategies = [] + for method in sampling_methods: + sampling_result = method.sample(image, prediction, prediction_type) + if sampling_result: + matching_strategies.append(method.name) + return matching_strategies + + +def execute_datapoint_registration( + cache: BaseCache, + matching_strategies: List[str], + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + configuration: ActiveLearningConfiguration, + api_key: str, + batch_name: str, +) -> None: + local_image_id = str(uuid4()) + encoded_image, scaling_factor = prepare_image_to_registration( + image=image, + desired_size=configuration.max_image_size, + jpeg_compression_level=configuration.jpeg_compression_level, + ) + prediction = adjust_prediction_to_client_scaling_factor( + prediction=prediction, + scaling_factor=scaling_factor, + prediction_type=prediction_type, + ) + matching_strategies_limits = OrderedDict( + (strategy_name, configuration.strategies_limits[strategy_name]) + for strategy_name in matching_strategies + ) + strategy_with_spare_credit = use_credit_of_matching_strategy( + cache=cache, + workspace=configuration.workspace_id, + project=configuration.dataset_id, + matching_strategies_limits=matching_strategies_limits, + ) + if strategy_with_spare_credit is None: + logger.debug(f"Limit on Active Learning strategy reached.") + return None + register_datapoint_at_roboflow( + cache=cache, + strategy_with_spare_credit=strategy_with_spare_credit, + encoded_image=encoded_image, + local_image_id=local_image_id, + prediction=prediction, + prediction_type=prediction_type, + configuration=configuration, + api_key=api_key, + batch_name=batch_name, + ) + + +def prepare_image_to_registration( + image: np.ndarray, + desired_size: Optional[ImageDimensions], + jpeg_compression_level: int, +) -> Tuple[bytes, float]: + scaling_factor = 1.0 + if desired_size is not None: + height_before_scale = image.shape[0] + image = downscale_image_keeping_aspect_ratio( + image=image, + desired_size=desired_size.to_wh(), + ) + scaling_factor = image.shape[0] / height_before_scale + return ( + encode_image_to_jpeg_bytes(image=image, jpeg_quality=jpeg_compression_level), + scaling_factor, + ) + + +def register_datapoint_at_roboflow( + cache: BaseCache, + strategy_with_spare_credit: str, + encoded_image: bytes, + local_image_id: str, + prediction: Prediction, + prediction_type: PredictionType, + configuration: ActiveLearningConfiguration, + api_key: str, + batch_name: str, +) -> None: + tags = collect_tags( + configuration=configuration, + sampling_strategy=strategy_with_spare_credit, + ) + roboflow_image_id = safe_register_image_at_roboflow( + cache=cache, + strategy_with_spare_credit=strategy_with_spare_credit, + encoded_image=encoded_image, + local_image_id=local_image_id, + configuration=configuration, + api_key=api_key, + batch_name=batch_name, + tags=tags, + ) + if is_prediction_registration_forbidden( + prediction=prediction, + persist_predictions=configuration.persist_predictions, + roboflow_image_id=roboflow_image_id, + ): + return None + encoded_prediction, prediction_file_type = encode_prediction( + prediction=prediction, prediction_type=prediction_type + ) + _ = annotate_image_at_roboflow( + api_key=api_key, + dataset_id=configuration.dataset_id, + local_image_id=local_image_id, + roboflow_image_id=roboflow_image_id, + annotation_content=encoded_prediction, + annotation_file_type=prediction_file_type, + is_prediction=True, + ) + + +def collect_tags( + configuration: ActiveLearningConfiguration, sampling_strategy: str +) -> List[str]: + tags = ACTIVE_LEARNING_TAGS if ACTIVE_LEARNING_TAGS is not None else [] + tags.extend(configuration.tags) + tags.extend(configuration.strategies_tags[sampling_strategy]) + if configuration.persist_predictions: + # this replacement is needed due to backend input validation + tags.append(configuration.model_id.replace("/", "-")) + return tags + + +def safe_register_image_at_roboflow( + cache: BaseCache, + strategy_with_spare_credit: str, + encoded_image: bytes, + local_image_id: str, + configuration: ActiveLearningConfiguration, + api_key: str, + batch_name: str, + tags: List[str], +) -> Optional[str]: + credit_to_be_returned = False + try: + registration_response = register_image_at_roboflow( + api_key=api_key, + dataset_id=configuration.dataset_id, + local_image_id=local_image_id, + image_bytes=encoded_image, + batch_name=batch_name, + tags=tags, + ) + image_duplicated = registration_response.get("duplicate", False) + if image_duplicated: + credit_to_be_returned = True + logger.warning(f"Image duplication detected: {registration_response}.") + return None + return registration_response["id"] + except Exception as error: + credit_to_be_returned = True + raise error + finally: + if credit_to_be_returned: + return_strategy_credit( + cache=cache, + workspace=configuration.workspace_id, + project=configuration.dataset_id, + strategy_name=strategy_with_spare_credit, + ) + + +def is_prediction_registration_forbidden( + prediction: Prediction, + persist_predictions: bool, + roboflow_image_id: Optional[str], +) -> bool: + return ( + roboflow_image_id is None + or persist_predictions is False + or prediction.get("is_stub", False) is True + or (len(prediction.get("predictions", [])) == 0 and "top" not in prediction) + ) diff --git a/inference/core/active_learning/entities.py b/inference/core/active_learning/entities.py new file mode 100644 index 0000000000000000000000000000000000000000..73d905b3e7d4ea1803d15531d799e07825b10514 --- /dev/null +++ b/inference/core/active_learning/entities.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np + +from inference.core.entities.types import DatasetID, WorkspaceID +from inference.core.exceptions import ActiveLearningConfigurationDecodingError + +LocalImageIdentifier = str +PredictionType = str +Prediction = dict +SerialisedPrediction = str +PredictionFileType = str + + +@dataclass(frozen=True) +class ImageDimensions: + height: int + width: int + + def to_hw(self) -> Tuple[int, int]: + return self.height, self.width + + def to_wh(self) -> Tuple[int, int]: + return self.width, self.height + + +@dataclass(frozen=True) +class SamplingMethod: + name: str + sample: Callable[[np.ndarray, Prediction, PredictionType], bool] + + +class BatchReCreationInterval(Enum): + NEVER = "never" + DAILY = "daily" + WEEKLY = "weekly" + MONTHLY = "monthly" + + +class StrategyLimitType(Enum): + MINUTELY = "minutely" + HOURLY = "hourly" + DAILY = "daily" + + +@dataclass(frozen=True) +class StrategyLimit: + limit_type: StrategyLimitType + value: int + + @classmethod + def from_dict(cls, specification: dict) -> "StrategyLimit": + return cls( + limit_type=StrategyLimitType(specification["type"]), + value=specification["value"], + ) + + +@dataclass(frozen=True) +class ActiveLearningConfiguration: + max_image_size: Optional[ImageDimensions] + jpeg_compression_level: int + persist_predictions: bool + sampling_methods: List[SamplingMethod] + batches_name_prefix: str + batch_recreation_interval: BatchReCreationInterval + max_batch_images: Optional[int] + workspace_id: WorkspaceID + dataset_id: DatasetID + model_id: str + strategies_limits: Dict[str, List[StrategyLimit]] + tags: List[str] + strategies_tags: Dict[str, List[str]] + + @classmethod + def init( + cls, + roboflow_api_configuration: Dict[str, Any], + sampling_methods: List[SamplingMethod], + workspace_id: WorkspaceID, + dataset_id: DatasetID, + model_id: str, + ) -> "ActiveLearningConfiguration": + try: + max_image_size = roboflow_api_configuration.get("max_image_size") + if max_image_size is not None: + max_image_size = ImageDimensions( + height=roboflow_api_configuration["max_image_size"][0], + width=roboflow_api_configuration["max_image_size"][1], + ) + strategies_limits = { + strategy["name"]: [ + StrategyLimit.from_dict(specification=specification) + for specification in strategy.get("limits", []) + ] + for strategy in roboflow_api_configuration["sampling_strategies"] + } + strategies_tags = { + strategy["name"]: strategy.get("tags", []) + for strategy in roboflow_api_configuration["sampling_strategies"] + } + return cls( + max_image_size=max_image_size, + jpeg_compression_level=roboflow_api_configuration.get( + "jpeg_compression_level", 95 + ), + persist_predictions=roboflow_api_configuration["persist_predictions"], + sampling_methods=sampling_methods, + batches_name_prefix=roboflow_api_configuration["batching_strategy"][ + "batches_name_prefix" + ], + batch_recreation_interval=BatchReCreationInterval( + roboflow_api_configuration["batching_strategy"][ + "recreation_interval" + ] + ), + max_batch_images=roboflow_api_configuration["batching_strategy"].get( + "max_batch_images" + ), + workspace_id=workspace_id, + dataset_id=dataset_id, + model_id=model_id, + strategies_limits=strategies_limits, + tags=roboflow_api_configuration.get("tags", []), + strategies_tags=strategies_tags, + ) + except (KeyError, ValueError) as e: + raise ActiveLearningConfigurationDecodingError( + f"Failed to initialise Active Learning configuration. Cause: {str(e)}" + ) from e + + +@dataclass(frozen=True) +class RoboflowProjectMetadata: + dataset_id: DatasetID + version_id: str + workspace_id: WorkspaceID + dataset_type: str + active_learning_configuration: dict diff --git a/inference/core/active_learning/middlewares.py b/inference/core/active_learning/middlewares.py new file mode 100644 index 0000000000000000000000000000000000000000..695d64c4dd95ffd24c68056861505388a55ed4ff --- /dev/null +++ b/inference/core/active_learning/middlewares.py @@ -0,0 +1,307 @@ +import queue +from queue import Queue +from threading import Thread +from typing import Any, List, Optional + +from inference.core import logger +from inference.core.active_learning.accounting import image_can_be_submitted_to_batch +from inference.core.active_learning.batching import generate_batch_name +from inference.core.active_learning.configuration import ( + prepare_active_learning_configuration, + prepare_active_learning_configuration_inplace, +) +from inference.core.active_learning.core import ( + execute_datapoint_registration, + execute_sampling, +) +from inference.core.active_learning.entities import ( + ActiveLearningConfiguration, + Prediction, + PredictionType, +) +from inference.core.cache.base import BaseCache +from inference.core.utils.image_utils import load_image + +MAX_REGISTRATION_QUEUE_SIZE = 512 + + +class NullActiveLearningMiddleware: + def register_batch( + self, + inference_inputs: List[Any], + predictions: List[Prediction], + prediction_type: PredictionType, + disable_preproc_auto_orient: bool = False, + ) -> None: + pass + + def register( + self, + inference_input: Any, + prediction: dict, + prediction_type: PredictionType, + disable_preproc_auto_orient: bool = False, + ) -> None: + pass + + def start_registration_thread(self) -> None: + pass + + def stop_registration_thread(self) -> None: + pass + + def __enter__(self) -> "NullActiveLearningMiddleware": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + pass + + +class ActiveLearningMiddleware: + @classmethod + def init( + cls, api_key: str, model_id: str, cache: BaseCache + ) -> "ActiveLearningMiddleware": + configuration = prepare_active_learning_configuration( + api_key=api_key, + model_id=model_id, + cache=cache, + ) + return cls( + api_key=api_key, + configuration=configuration, + cache=cache, + ) + + @classmethod + def init_from_config( + cls, api_key: str, model_id: str, cache: BaseCache, config: Optional[dict] + ) -> "ActiveLearningMiddleware": + configuration = prepare_active_learning_configuration_inplace( + api_key=api_key, + model_id=model_id, + active_learning_configuration=config, + ) + return cls( + api_key=api_key, + configuration=configuration, + cache=cache, + ) + + def __init__( + self, + api_key: str, + configuration: Optional[ActiveLearningConfiguration], + cache: BaseCache, + ): + self._api_key = api_key + self._configuration = configuration + self._cache = cache + + def register_batch( + self, + inference_inputs: List[Any], + predictions: List[Prediction], + prediction_type: PredictionType, + disable_preproc_auto_orient: bool = False, + ) -> None: + for inference_input, prediction in zip(inference_inputs, predictions): + self.register( + inference_input=inference_input, + prediction=prediction, + prediction_type=prediction_type, + disable_preproc_auto_orient=disable_preproc_auto_orient, + ) + + def register( + self, + inference_input: Any, + prediction: dict, + prediction_type: PredictionType, + disable_preproc_auto_orient: bool = False, + ) -> None: + self._execute_registration( + inference_input=inference_input, + prediction=prediction, + prediction_type=prediction_type, + disable_preproc_auto_orient=disable_preproc_auto_orient, + ) + + def _execute_registration( + self, + inference_input: Any, + prediction: dict, + prediction_type: PredictionType, + disable_preproc_auto_orient: bool = False, + ) -> None: + if self._configuration is None: + return None + image, is_bgr = load_image( + value=inference_input, + disable_preproc_auto_orient=disable_preproc_auto_orient, + ) + if not is_bgr: + image = image[:, :, ::-1] + matching_strategies = execute_sampling( + image=image, + prediction=prediction, + prediction_type=prediction_type, + sampling_methods=self._configuration.sampling_methods, + ) + if len(matching_strategies) == 0: + return None + batch_name = generate_batch_name(configuration=self._configuration) + if not image_can_be_submitted_to_batch( + batch_name=batch_name, + workspace_id=self._configuration.workspace_id, + dataset_id=self._configuration.dataset_id, + max_batch_images=self._configuration.max_batch_images, + api_key=self._api_key, + ): + logger.debug(f"Limit on Active Learning batch size reached.") + return None + execute_datapoint_registration( + cache=self._cache, + matching_strategies=matching_strategies, + image=image, + prediction=prediction, + prediction_type=prediction_type, + configuration=self._configuration, + api_key=self._api_key, + batch_name=batch_name, + ) + + +class ThreadingActiveLearningMiddleware(ActiveLearningMiddleware): + @classmethod + def init( + cls, + api_key: str, + model_id: str, + cache: BaseCache, + max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE, + ) -> "ThreadingActiveLearningMiddleware": + configuration = prepare_active_learning_configuration( + api_key=api_key, + model_id=model_id, + cache=cache, + ) + task_queue = Queue(max_queue_size) + return cls( + api_key=api_key, + configuration=configuration, + cache=cache, + task_queue=task_queue, + ) + + @classmethod + def init_from_config( + cls, + api_key: str, + model_id: str, + cache: BaseCache, + config: Optional[dict], + max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE, + ) -> "ThreadingActiveLearningMiddleware": + configuration = prepare_active_learning_configuration_inplace( + api_key=api_key, + model_id=model_id, + active_learning_configuration=config, + ) + task_queue = Queue(max_queue_size) + return cls( + api_key=api_key, + configuration=configuration, + cache=cache, + task_queue=task_queue, + ) + + def __init__( + self, + api_key: str, + configuration: ActiveLearningConfiguration, + cache: BaseCache, + task_queue: Queue, + ): + super().__init__(api_key=api_key, configuration=configuration, cache=cache) + self._task_queue = task_queue + self._registration_thread: Optional[Thread] = None + + def register( + self, + inference_input: Any, + prediction: dict, + prediction_type: PredictionType, + disable_preproc_auto_orient: bool = False, + ) -> None: + logger.debug(f"Putting registration task into queue") + try: + self._task_queue.put_nowait( + ( + inference_input, + prediction, + prediction_type, + disable_preproc_auto_orient, + ) + ) + except queue.Full: + logger.warning( + f"Dropping datapoint registered in Active Learning due to insufficient processing " + f"capabilities." + ) + + def start_registration_thread(self) -> None: + if self._registration_thread is not None: + logger.warning(f"Registration thread already started.") + return None + logger.debug("Staring registration thread") + self._registration_thread = Thread(target=self._consume_queue) + self._registration_thread.start() + + def stop_registration_thread(self) -> None: + if self._registration_thread is None: + logger.warning("Registration thread is already stopped.") + return None + logger.debug("Stopping registration thread") + self._task_queue.put(None) + self._registration_thread.join() + if self._registration_thread.is_alive(): + logger.warning(f"Registration thread stopping was unsuccessful.") + self._registration_thread = None + + def _consume_queue(self) -> None: + queue_closed = False + while not queue_closed: + queue_closed = self._consume_queue_task() + + def _consume_queue_task(self) -> bool: + logger.debug("Consuming registration task") + task = self._task_queue.get() + logger.debug("Received registration task") + if task is None: + logger.debug("Terminating registration thread") + self._task_queue.task_done() + return True + inference_input, prediction, prediction_type, disable_preproc_auto_orient = task + try: + self._execute_registration( + inference_input=inference_input, + prediction=prediction, + prediction_type=prediction_type, + disable_preproc_auto_orient=disable_preproc_auto_orient, + ) + except Exception as error: + # Error handling to be decided + logger.warning( + f"Error in datapoint registration for Active Learning. Details: {error}. " + f"Error is suppressed in favour of normal operations of registration thread." + ) + self._task_queue.task_done() + return False + + def __enter__(self) -> "ThreadingActiveLearningMiddleware": + self.start_registration_thread() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.stop_registration_thread() diff --git a/inference/core/active_learning/post_processing.py b/inference/core/active_learning/post_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..a8381fef4a37f7b9b7f25f33794e12e69033fe59 --- /dev/null +++ b/inference/core/active_learning/post_processing.py @@ -0,0 +1,128 @@ +import json +from typing import List, Tuple + +from inference.core.active_learning.entities import ( + Prediction, + PredictionFileType, + PredictionType, + SerialisedPrediction, +) +from inference.core.constants import ( + CLASSIFICATION_TASK, + INSTANCE_SEGMENTATION_TASK, + OBJECT_DETECTION_TASK, +) +from inference.core.exceptions import PredictionFormatNotSupported + + +def adjust_prediction_to_client_scaling_factor( + prediction: dict, scaling_factor: float, prediction_type: PredictionType +) -> dict: + if abs(scaling_factor - 1.0) < 1e-5: + return prediction + if "image" in prediction: + prediction["image"] = { + "width": round(prediction["image"]["width"] / scaling_factor), + "height": round(prediction["image"]["height"] / scaling_factor), + } + if predictions_should_not_be_post_processed( + prediction=prediction, prediction_type=prediction_type + ): + return prediction + if prediction_type == INSTANCE_SEGMENTATION_TASK: + prediction["predictions"] = ( + adjust_prediction_with_bbox_and_points_to_client_scaling_factor( + predictions=prediction["predictions"], + scaling_factor=scaling_factor, + points_key="points", + ) + ) + if prediction_type == OBJECT_DETECTION_TASK: + prediction["predictions"] = ( + adjust_object_detection_predictions_to_client_scaling_factor( + predictions=prediction["predictions"], + scaling_factor=scaling_factor, + ) + ) + return prediction + + +def predictions_should_not_be_post_processed( + prediction: dict, prediction_type: PredictionType +) -> bool: + # excluding from post-processing classification output, stub-output and empty predictions + return ( + "is_stub" in prediction + or "predictions" not in prediction + or CLASSIFICATION_TASK in prediction_type + or len(prediction["predictions"]) == 0 + ) + + +def adjust_object_detection_predictions_to_client_scaling_factor( + predictions: List[dict], + scaling_factor: float, +) -> List[dict]: + result = [] + for prediction in predictions: + prediction = adjust_bbox_coordinates_to_client_scaling_factor( + bbox=prediction, + scaling_factor=scaling_factor, + ) + result.append(prediction) + return result + + +def adjust_prediction_with_bbox_and_points_to_client_scaling_factor( + predictions: List[dict], + scaling_factor: float, + points_key: str, +) -> List[dict]: + result = [] + for prediction in predictions: + prediction = adjust_bbox_coordinates_to_client_scaling_factor( + bbox=prediction, + scaling_factor=scaling_factor, + ) + prediction[points_key] = adjust_points_coordinates_to_client_scaling_factor( + points=prediction[points_key], + scaling_factor=scaling_factor, + ) + result.append(prediction) + return result + + +def adjust_bbox_coordinates_to_client_scaling_factor( + bbox: dict, + scaling_factor: float, +) -> dict: + bbox["x"] = bbox["x"] / scaling_factor + bbox["y"] = bbox["y"] / scaling_factor + bbox["width"] = bbox["width"] / scaling_factor + bbox["height"] = bbox["height"] / scaling_factor + return bbox + + +def adjust_points_coordinates_to_client_scaling_factor( + points: List[dict], + scaling_factor: float, +) -> List[dict]: + result = [] + for point in points: + point["x"] = point["x"] / scaling_factor + point["y"] = point["y"] / scaling_factor + result.append(point) + return result + + +def encode_prediction( + prediction: Prediction, + prediction_type: PredictionType, +) -> Tuple[SerialisedPrediction, PredictionFileType]: + if CLASSIFICATION_TASK not in prediction_type: + return json.dumps(prediction), "json" + if "top" in prediction: + return prediction["top"], "txt" + raise PredictionFormatNotSupported( + f"Prediction type or prediction format not supported." + ) diff --git a/inference/core/active_learning/samplers/__init__.py b/inference/core/active_learning/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/active_learning/samplers/__pycache__/__init__.cpython-310.pyc b/inference/core/active_learning/samplers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb6fc332eb2fd2a08d4d858095f2f24abe032485 Binary files /dev/null and b/inference/core/active_learning/samplers/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/active_learning/samplers/__pycache__/close_to_threshold.cpython-310.pyc b/inference/core/active_learning/samplers/__pycache__/close_to_threshold.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9865ec4a3379e64ebfb2bd92a0492bdaedcefb17 Binary files /dev/null and b/inference/core/active_learning/samplers/__pycache__/close_to_threshold.cpython-310.pyc differ diff --git a/inference/core/active_learning/samplers/__pycache__/contains_classes.cpython-310.pyc b/inference/core/active_learning/samplers/__pycache__/contains_classes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f83b6511904a7d3e353cb850a8ed2a57c6536e7 Binary files /dev/null and b/inference/core/active_learning/samplers/__pycache__/contains_classes.cpython-310.pyc differ diff --git a/inference/core/active_learning/samplers/__pycache__/number_of_detections.cpython-310.pyc b/inference/core/active_learning/samplers/__pycache__/number_of_detections.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f87c9d68ffdd42b70ee442193fda30d356ab00d9 Binary files /dev/null and b/inference/core/active_learning/samplers/__pycache__/number_of_detections.cpython-310.pyc differ diff --git a/inference/core/active_learning/samplers/__pycache__/random.cpython-310.pyc b/inference/core/active_learning/samplers/__pycache__/random.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c01d59fdb14f9de4e6e95927b9dffb34b01477d Binary files /dev/null and b/inference/core/active_learning/samplers/__pycache__/random.cpython-310.pyc differ diff --git a/inference/core/active_learning/samplers/close_to_threshold.py b/inference/core/active_learning/samplers/close_to_threshold.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba4c879dfc453ddbfc6c1d4cfbec4c360e29aff --- /dev/null +++ b/inference/core/active_learning/samplers/close_to_threshold.py @@ -0,0 +1,227 @@ +import random +from functools import partial +from typing import Any, Dict, Optional, Set + +import numpy as np + +from inference.core.active_learning.entities import ( + Prediction, + PredictionType, + SamplingMethod, +) +from inference.core.constants import ( + CLASSIFICATION_TASK, + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +) +from inference.core.exceptions import ActiveLearningConfigurationError + +ELIGIBLE_PREDICTION_TYPES = { + CLASSIFICATION_TASK, + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +} + + +def initialize_close_to_threshold_sampling( + strategy_config: Dict[str, Any] +) -> SamplingMethod: + try: + selected_class_names = strategy_config.get("selected_class_names") + if selected_class_names is not None: + selected_class_names = set(selected_class_names) + sample_function = partial( + sample_close_to_threshold, + selected_class_names=selected_class_names, + threshold=strategy_config["threshold"], + epsilon=strategy_config["epsilon"], + only_top_classes=strategy_config.get("only_top_classes", True), + minimum_objects_close_to_threshold=strategy_config.get( + "minimum_objects_close_to_threshold", + 1, + ), + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `close_to_threshold_sampling` missing key detected: {error}." + ) from error + + +def sample_close_to_threshold( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + only_top_classes: bool, + minimum_objects_close_to_threshold: int, + probability: float, +) -> bool: + if is_prediction_a_stub(prediction=prediction): + return False + if prediction_type not in ELIGIBLE_PREDICTION_TYPES: + return False + close_to_threshold = prediction_is_close_to_threshold( + prediction=prediction, + prediction_type=prediction_type, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + only_top_classes=only_top_classes, + minimum_objects_close_to_threshold=minimum_objects_close_to_threshold, + ) + if not close_to_threshold: + return False + return random.random() < probability + + +def is_prediction_a_stub(prediction: Prediction) -> bool: + return prediction.get("is_stub", False) + + +def prediction_is_close_to_threshold( + prediction: Prediction, + prediction_type: PredictionType, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + only_top_classes: bool, + minimum_objects_close_to_threshold: int, +) -> bool: + if CLASSIFICATION_TASK not in prediction_type: + return detections_are_close_to_threshold( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + minimum_objects_close_to_threshold=minimum_objects_close_to_threshold, + ) + checker = multi_label_classification_prediction_is_close_to_threshold + if "top" in prediction: + checker = multi_class_classification_prediction_is_close_to_threshold + return checker( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + only_top_classes=only_top_classes, + ) + + +def multi_class_classification_prediction_is_close_to_threshold( + prediction: Prediction, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + only_top_classes: bool, +) -> bool: + if only_top_classes: + return ( + multi_class_classification_prediction_is_close_to_threshold_for_top_class( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + ) + ) + for prediction_details in prediction["predictions"]: + if class_to_be_excluded( + class_name=prediction_details["class"], + selected_class_names=selected_class_names, + ): + continue + if is_close_to_threshold( + value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon + ): + return True + return False + + +def multi_class_classification_prediction_is_close_to_threshold_for_top_class( + prediction: Prediction, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, +) -> bool: + if ( + selected_class_names is not None + and prediction["top"] not in selected_class_names + ): + return False + return abs(prediction["confidence"] - threshold) < epsilon + + +def multi_label_classification_prediction_is_close_to_threshold( + prediction: Prediction, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + only_top_classes: bool, +) -> bool: + predicted_classes = set(prediction["predicted_classes"]) + for class_name, prediction_details in prediction["predictions"].items(): + if only_top_classes and class_name not in predicted_classes: + continue + if class_to_be_excluded( + class_name=class_name, selected_class_names=selected_class_names + ): + continue + if is_close_to_threshold( + value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon + ): + return True + return False + + +def detections_are_close_to_threshold( + prediction: Prediction, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, + minimum_objects_close_to_threshold: int, +) -> bool: + detections_close_to_threshold = count_detections_close_to_threshold( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=threshold, + epsilon=epsilon, + ) + return detections_close_to_threshold >= minimum_objects_close_to_threshold + + +def count_detections_close_to_threshold( + prediction: Prediction, + selected_class_names: Optional[Set[str]], + threshold: float, + epsilon: float, +) -> int: + counter = 0 + for prediction_details in prediction["predictions"]: + if class_to_be_excluded( + class_name=prediction_details["class"], + selected_class_names=selected_class_names, + ): + continue + if is_close_to_threshold( + value=prediction_details["confidence"], threshold=threshold, epsilon=epsilon + ): + counter += 1 + return counter + + +def class_to_be_excluded( + class_name: str, selected_class_names: Optional[Set[str]] +) -> bool: + return selected_class_names is not None and class_name not in selected_class_names + + +def is_close_to_threshold(value: float, threshold: float, epsilon: float) -> bool: + return abs(value - threshold) < epsilon diff --git a/inference/core/active_learning/samplers/contains_classes.py b/inference/core/active_learning/samplers/contains_classes.py new file mode 100644 index 0000000000000000000000000000000000000000..854dc3716204b6b4cdb863f1b08abce024377a33 --- /dev/null +++ b/inference/core/active_learning/samplers/contains_classes.py @@ -0,0 +1,58 @@ +from functools import partial +from typing import Any, Dict, Set + +import numpy as np + +from inference.core.active_learning.entities import ( + Prediction, + PredictionType, + SamplingMethod, +) +from inference.core.active_learning.samplers.close_to_threshold import ( + sample_close_to_threshold, +) +from inference.core.constants import CLASSIFICATION_TASK +from inference.core.exceptions import ActiveLearningConfigurationError + +ELIGIBLE_PREDICTION_TYPES = {CLASSIFICATION_TASK} + + +def initialize_classes_based_sampling( + strategy_config: Dict[str, Any] +) -> SamplingMethod: + try: + sample_function = partial( + sample_based_on_classes, + selected_class_names=set(strategy_config["selected_class_names"]), + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `classes_based_sampling` missing key detected: {error}." + ) from error + + +def sample_based_on_classes( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + selected_class_names: Set[str], + probability: float, +) -> bool: + if prediction_type not in ELIGIBLE_PREDICTION_TYPES: + return False + return sample_close_to_threshold( + image=image, + prediction=prediction, + prediction_type=prediction_type, + selected_class_names=selected_class_names, + threshold=0.5, + epsilon=1.0, + only_top_classes=True, + minimum_objects_close_to_threshold=1, + probability=probability, + ) diff --git a/inference/core/active_learning/samplers/number_of_detections.py b/inference/core/active_learning/samplers/number_of_detections.py new file mode 100644 index 0000000000000000000000000000000000000000..4b80351bab8c8c379e05ae2006bfcc0863c32fa0 --- /dev/null +++ b/inference/core/active_learning/samplers/number_of_detections.py @@ -0,0 +1,107 @@ +import random +from functools import partial +from typing import Any, Dict, Optional, Set + +import numpy as np + +from inference.core.active_learning.entities import ( + Prediction, + PredictionType, + SamplingMethod, +) +from inference.core.active_learning.samplers.close_to_threshold import ( + count_detections_close_to_threshold, + is_prediction_a_stub, +) +from inference.core.constants import ( + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +) +from inference.core.exceptions import ActiveLearningConfigurationError + +ELIGIBLE_PREDICTION_TYPES = { + INSTANCE_SEGMENTATION_TASK, + KEYPOINTS_DETECTION_TASK, + OBJECT_DETECTION_TASK, +} + + +def initialize_detections_number_based_sampling( + strategy_config: Dict[str, Any] +) -> SamplingMethod: + try: + more_than = strategy_config.get("more_than") + less_than = strategy_config.get("less_than") + ensure_range_configuration_is_valid(more_than=more_than, less_than=less_than) + selected_class_names = strategy_config.get("selected_class_names") + if selected_class_names is not None: + selected_class_names = set(selected_class_names) + sample_function = partial( + sample_based_on_detections_number, + less_than=less_than, + more_than=more_than, + selected_class_names=selected_class_names, + probability=strategy_config["probability"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `detections_number_based_sampling` missing key detected: {error}." + ) from error + + +def sample_based_on_detections_number( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + more_than: Optional[int], + less_than: Optional[int], + selected_class_names: Optional[Set[str]], + probability: float, +) -> bool: + if is_prediction_a_stub(prediction=prediction): + return False + if prediction_type not in ELIGIBLE_PREDICTION_TYPES: + return False + detections_close_to_threshold = count_detections_close_to_threshold( + prediction=prediction, + selected_class_names=selected_class_names, + threshold=0.5, + epsilon=1.0, + ) + if is_in_range( + value=detections_close_to_threshold, less_than=less_than, more_than=more_than + ): + return random.random() < probability + return False + + +def is_in_range( + value: int, + more_than: Optional[int], + less_than: Optional[int], +) -> bool: + # calculates value > more_than and value < less_than, with optional borders of range + less_than_satisfied, more_than_satisfied = less_than is None, more_than is None + if less_than is not None and value < less_than: + less_than_satisfied = True + if more_than is not None and value > more_than: + more_than_satisfied = True + return less_than_satisfied and more_than_satisfied + + +def ensure_range_configuration_is_valid( + more_than: Optional[int], + less_than: Optional[int], +) -> None: + if more_than is None or less_than is None: + return None + if more_than >= less_than: + raise ActiveLearningConfigurationError( + f"Misconfiguration of detections number sampling: " + f"`more_than` parameter ({more_than}) >= `less_than` ({less_than})." + ) diff --git a/inference/core/active_learning/samplers/random.py b/inference/core/active_learning/samplers/random.py new file mode 100644 index 0000000000000000000000000000000000000000..42df157e530377735bfd36c703755c21bace66a4 --- /dev/null +++ b/inference/core/active_learning/samplers/random.py @@ -0,0 +1,37 @@ +import random +from functools import partial +from typing import Any, Dict + +import numpy as np + +from inference.core.active_learning.entities import ( + Prediction, + PredictionType, + SamplingMethod, +) +from inference.core.exceptions import ActiveLearningConfigurationError + + +def initialize_random_sampling(strategy_config: Dict[str, Any]) -> SamplingMethod: + try: + sample_function = partial( + sample_randomly, + traffic_percentage=strategy_config["traffic_percentage"], + ) + return SamplingMethod( + name=strategy_config["name"], + sample=sample_function, + ) + except KeyError as error: + raise ActiveLearningConfigurationError( + f"In configuration of `random_sampling` missing key detected: {error}." + ) from error + + +def sample_randomly( + image: np.ndarray, + prediction: Prediction, + prediction_type: PredictionType, + traffic_percentage: float, +) -> bool: + return random.random() < traffic_percentage diff --git a/inference/core/active_learning/utils.py b/inference/core/active_learning/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1e0a4077fe3afc3fcacdaf029a65fccd2a89e6 --- /dev/null +++ b/inference/core/active_learning/utils.py @@ -0,0 +1,16 @@ +from datetime import datetime, timedelta + +TIMESTAMP_FORMAT = "%Y_%m_%d" + + +def generate_today_timestamp() -> str: + return datetime.today().strftime(TIMESTAMP_FORMAT) + + +def generate_start_timestamp_for_this_week() -> str: + today = datetime.today() + return (today - timedelta(days=today.weekday())).strftime(TIMESTAMP_FORMAT) + + +def generate_start_timestamp_for_this_month() -> str: + return datetime.today().replace(day=1).strftime(TIMESTAMP_FORMAT) diff --git a/inference/core/cache/__init__.py b/inference/core/cache/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b670d5c16625505468634dc360c700c590447495 --- /dev/null +++ b/inference/core/cache/__init__.py @@ -0,0 +1,22 @@ +from redis.exceptions import ConnectionError, TimeoutError + +from inference.core import logger +from inference.core.cache.memory import MemoryCache +from inference.core.cache.redis import RedisCache +from inference.core.env import REDIS_HOST, REDIS_PORT, REDIS_SSL, REDIS_TIMEOUT + +if REDIS_HOST is not None: + try: + cache = RedisCache( + host=REDIS_HOST, port=REDIS_PORT, ssl=REDIS_SSL, timeout=REDIS_TIMEOUT + ) + logger.info(f"Redis Cache initialised") + except (TimeoutError, ConnectionError): + logger.error( + f"Could not connect to Redis under {REDIS_HOST}:{REDIS_PORT}. MemoryCache to be used." + ) + cache = MemoryCache() + logger.info(f"Memory Cache initialised") +else: + cache = MemoryCache() + logger.info(f"Memory Cache initialised") diff --git a/inference/core/cache/__pycache__/__init__.cpython-310.pyc b/inference/core/cache/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9ba45456bd1cb09784cc828e75f9d8ef4d495c1 Binary files /dev/null and b/inference/core/cache/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/cache/__pycache__/base.cpython-310.pyc b/inference/core/cache/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec93ec5cd5e263cfe650cfdae782c0112ef37f53 Binary files /dev/null and b/inference/core/cache/__pycache__/base.cpython-310.pyc differ diff --git a/inference/core/cache/__pycache__/memory.cpython-310.pyc b/inference/core/cache/__pycache__/memory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4455cb49e3474a24abff3890a7a36de948f9b4a Binary files /dev/null and b/inference/core/cache/__pycache__/memory.cpython-310.pyc differ diff --git a/inference/core/cache/__pycache__/model_artifacts.cpython-310.pyc b/inference/core/cache/__pycache__/model_artifacts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff5f4dc324779699a53c02669daf04764976d4fe Binary files /dev/null and b/inference/core/cache/__pycache__/model_artifacts.cpython-310.pyc differ diff --git a/inference/core/cache/__pycache__/redis.cpython-310.pyc b/inference/core/cache/__pycache__/redis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..424da34ba94aad5ce1b87c3b3895f33bc3ca4088 Binary files /dev/null and b/inference/core/cache/__pycache__/redis.cpython-310.pyc differ diff --git a/inference/core/cache/__pycache__/serializers.cpython-310.pyc b/inference/core/cache/__pycache__/serializers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64c5ec40a2f24d35475c83f6dd8ca6e65842549d Binary files /dev/null and b/inference/core/cache/__pycache__/serializers.cpython-310.pyc differ diff --git a/inference/core/cache/base.py b/inference/core/cache/base.py new file mode 100644 index 0000000000000000000000000000000000000000..9f538032a8b6dce73bb1c32293763bd3df358d2a --- /dev/null +++ b/inference/core/cache/base.py @@ -0,0 +1,130 @@ +from contextlib import contextmanager +from typing import Any, Optional + +from inference.core import logger + + +class BaseCache: + """ + BaseCache is an abstract base class that defines the interface for a cache. + """ + + def get(self, key: str): + """ + Gets the value associated with the given key. + + Args: + key (str): The key to retrieve the value. + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() + + def set(self, key: str, value: str, expire: float = None): + """ + Sets a value for a given key with an optional expire time. + + Args: + key (str): The key to store the value. + value (str): The value to store. + expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None. + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() + + def zadd(self, key: str, value: str, score: float, expire: float = None): + """ + Adds a member with the specified score to the sorted set stored at key. + + Args: + key (str): The key of the sorted set. + value (str): The value to add to the sorted set. + score (float): The score associated with the value. + expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None. + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() + + def zrangebyscore( + self, + key: str, + min: Optional[float] = -1, + max: Optional[float] = float("inf"), + withscores: bool = False, + ): + """ + Retrieves a range of members from a sorted set. + + Args: + key (str): The key of the sorted set. + start (int, optional): The starting index of the range. Defaults to -1. + stop (int, optional): The ending index of the range. Defaults to float("inf"). + withscores (bool, optional): Whether to return the scores along with the values. Defaults to False. + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() + + def zremrangebyscore( + self, + key: str, + start: Optional[int] = -1, + stop: Optional[int] = float("inf"), + ): + """ + Removes all members in a sorted set within the given scores. + + Args: + key (str): The key of the sorted set. + start (int, optional): The minimum score of the range. Defaults to -1. + stop (int, optional): The maximum score of the range. Defaults to float("inf"). + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() + + def acquire_lock(self, key: str, expire: float = None) -> Any: + raise NotImplementedError() + + @contextmanager + def lock(self, key: str, expire: float = None) -> Any: + logger.debug(f"Acquiring lock at cache key: {key}") + l = self.acquire_lock(key, expire=expire) + try: + yield l + finally: + logger.debug(f"Releasing lock at cache key: {key}") + l.release() + + def set_numpy(self, key: str, value: Any, expire: float = None): + """ + Caches a numpy array. + + Args: + key (str): The key to store the value. + value (Any): The value to store. + expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None. + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() + + def get_numpy(self, key: str) -> Any: + """ + Retrieves a numpy array from the cache. + + Args: + key (str): The key of the value to retrieve. + + Raises: + NotImplementedError: This method must be implemented by subclasses. + """ + raise NotImplementedError() diff --git a/inference/core/cache/memory.py b/inference/core/cache/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..03973bb31ad7dd4cfbe9de1b1bc40e224958d8b9 --- /dev/null +++ b/inference/core/cache/memory.py @@ -0,0 +1,172 @@ +import threading +import time +from threading import Lock +from typing import Any, Optional + +from inference.core.cache.base import BaseCache +from inference.core.env import MEMORY_CACHE_EXPIRE_INTERVAL + + +class MemoryCache(BaseCache): + """ + MemoryCache is an in-memory cache that implements the BaseCache interface. + + Attributes: + cache (dict): A dictionary to store the cache values. + expires (dict): A dictionary to store the expiration times of the cache values. + zexpires (dict): A dictionary to store the expiration times of the sorted set values. + _expire_thread (threading.Thread): A thread that runs the _expire method. + """ + + def __init__(self) -> None: + """ + Initializes a new instance of the MemoryCache class. + """ + self.cache = dict() + self.expires = dict() + self.zexpires = dict() + + self._expire_thread = threading.Thread(target=self._expire) + self._expire_thread.daemon = True + self._expire_thread.start() + + def _expire(self): + """ + Removes the expired keys from the cache and zexpires dictionaries. + + This method runs in an infinite loop and sleeps for MEMORY_CACHE_EXPIRE_INTERVAL seconds between each iteration. + """ + while True: + now = time.time() + keys_to_delete = [] + for k, v in self.expires.copy().items(): + if v < now: + keys_to_delete.append(k) + for k in keys_to_delete: + del self.cache[k] + del self.expires[k] + keys_to_delete = [] + for k, v in self.zexpires.copy().items(): + if v < now: + keys_to_delete.append(k) + for k in keys_to_delete: + del self.cache[k[0]][k[1]] + del self.zexpires[k] + while time.time() - now < MEMORY_CACHE_EXPIRE_INTERVAL: + time.sleep(0.1) + + def get(self, key: str): + """ + Gets the value associated with the given key. + + Args: + key (str): The key to retrieve the value. + + Returns: + str: The value associated with the key, or None if the key does not exist or is expired. + """ + if key in self.expires: + if self.expires[key] < time.time(): + del self.cache[key] + del self.expires[key] + return None + return self.cache.get(key) + + def set(self, key: str, value: str, expire: float = None): + """ + Sets a value for a given key with an optional expire time. + + Args: + key (str): The key to store the value. + value (str): The value to store. + expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None. + """ + self.cache[key] = value + if expire: + self.expires[key] = expire + time.time() + + def zadd(self, key: str, value: Any, score: float, expire: float = None): + """ + Adds a member with the specified score to the sorted set stored at key. + + Args: + key (str): The key of the sorted set. + value (str): The value to add to the sorted set. + score (float): The score associated with the value. + expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None. + """ + if not key in self.cache: + self.cache[key] = dict() + self.cache[key][score] = value + if expire: + self.zexpires[(key, score)] = expire + time.time() + + def zrangebyscore( + self, + key: str, + min: Optional[float] = -1, + max: Optional[float] = float("inf"), + withscores: bool = False, + ): + """ + Retrieves a range of members from a sorted set. + + Args: + key (str): The key of the sorted set. + start (int, optional): The starting score of the range. Defaults to -1. + stop (int, optional): The ending score of the range. Defaults to float("inf"). + withscores (bool, optional): Whether to return the scores along with the values. Defaults to False. + + Returns: + list: A list of values (or value-score pairs if withscores is True) in the specified score range. + """ + if not key in self.cache: + return [] + keys = sorted([k for k in self.cache[key].keys() if min <= k <= max]) + if withscores: + return [(self.cache[key][k], k) for k in keys] + else: + return [self.cache[key][k] for k in keys] + + def zremrangebyscore( + self, + key: str, + min: Optional[float] = -1, + max: Optional[float] = float("inf"), + ): + """ + Removes all members in a sorted set within the given scores. + + Args: + key (str): The key of the sorted set. + start (int, optional): The minimum score of the range. Defaults to -1. + stop (int, optional): The maximum score of the range. Defaults to float("inf"). + + Returns: + int: The number of members removed from the sorted set. + """ + res = self.zrangebyscore(key, min=min, max=max, withscores=True) + keys_to_delete = [k[1] for k in res] + for k in keys_to_delete: + del self.cache[key][k] + return len(keys_to_delete) + + def acquire_lock(self, key: str, expire=None) -> Any: + lock: Optional[Lock] = self.get(key) + if lock is None: + lock = Lock() + self.set(key, lock, expire=expire) + if expire is None: + expire = -1 + acquired = lock.acquire(timeout=expire) + if not acquired: + raise TimeoutError() + # refresh the lock + self.set(key, lock, expire=expire) + return lock + + def set_numpy(self, key: str, value: Any, expire: float = None): + return self.set(key, value, expire=expire) + + def get_numpy(self, key: str): + return self.get(key) diff --git a/inference/core/cache/model_artifacts.py b/inference/core/cache/model_artifacts.py new file mode 100644 index 0000000000000000000000000000000000000000..38679dd97227fb0d946ab3c52891162c898f185d --- /dev/null +++ b/inference/core/cache/model_artifacts.py @@ -0,0 +1,99 @@ +import os.path +import shutil +from typing import List, Optional, Union + +from inference.core.env import MODEL_CACHE_DIR +from inference.core.utils.file_system import ( + dump_bytes, + dump_json, + dump_text_lines, + read_json, + read_text_file, +) + + +def initialise_cache(model_id: Optional[str] = None) -> None: + cache_dir = get_cache_dir(model_id=model_id) + os.makedirs(cache_dir, exist_ok=True) + + +def are_all_files_cached(files: List[str], model_id: Optional[str] = None) -> bool: + return all(is_file_cached(file=file, model_id=model_id) for file in files) + + +def is_file_cached(file: str, model_id: Optional[str] = None) -> bool: + cached_file_path = get_cache_file_path(file=file, model_id=model_id) + return os.path.isfile(cached_file_path) + + +def load_text_file_from_cache( + file: str, + model_id: Optional[str] = None, + split_lines: bool = False, + strip_white_chars: bool = False, +) -> Union[str, List[str]]: + cached_file_path = get_cache_file_path(file=file, model_id=model_id) + return read_text_file( + path=cached_file_path, + split_lines=split_lines, + strip_white_chars=strip_white_chars, + ) + + +def load_json_from_cache( + file: str, model_id: Optional[str] = None, **kwargs +) -> Optional[Union[dict, list]]: + cached_file_path = get_cache_file_path(file=file, model_id=model_id) + return read_json(path=cached_file_path, **kwargs) + + +def save_bytes_in_cache( + content: bytes, + file: str, + model_id: Optional[str] = None, + allow_override: bool = True, +) -> None: + cached_file_path = get_cache_file_path(file=file, model_id=model_id) + dump_bytes(path=cached_file_path, content=content, allow_override=allow_override) + + +def save_json_in_cache( + content: Union[dict, list], + file: str, + model_id: Optional[str] = None, + allow_override: bool = True, + **kwargs, +) -> None: + cached_file_path = get_cache_file_path(file=file, model_id=model_id) + dump_json( + path=cached_file_path, content=content, allow_override=allow_override, **kwargs + ) + + +def save_text_lines_in_cache( + content: List[str], + file: str, + model_id: Optional[str] = None, + allow_override: bool = True, +) -> None: + cached_file_path = get_cache_file_path(file=file, model_id=model_id) + dump_text_lines( + path=cached_file_path, content=content, allow_override=allow_override + ) + + +def get_cache_file_path(file: str, model_id: Optional[str] = None) -> str: + cache_dir = get_cache_dir(model_id=model_id) + return os.path.join(cache_dir, file) + + +def clear_cache(model_id: Optional[str] = None) -> None: + cache_dir = get_cache_dir(model_id=model_id) + if os.path.exists(cache_dir): + shutil.rmtree(cache_dir) + + +def get_cache_dir(model_id: Optional[str] = None) -> str: + if model_id is not None: + return os.path.join(MODEL_CACHE_DIR, model_id) + return MODEL_CACHE_DIR diff --git a/inference/core/cache/redis.py b/inference/core/cache/redis.py new file mode 100644 index 0000000000000000000000000000000000000000..5633e0e939ba2fade6fd8aecabe6762c84a57a37 --- /dev/null +++ b/inference/core/cache/redis.py @@ -0,0 +1,196 @@ +import asyncio +import inspect +import json +import pickle +import threading +import time +from contextlib import asynccontextmanager +from copy import copy +from typing import Any, Optional + +import redis + +from inference.core import logger +from inference.core.cache.base import BaseCache +from inference.core.entities.responses.inference import InferenceResponseImage +from inference.core.env import MEMORY_CACHE_EXPIRE_INTERVAL + + +class RedisCache(BaseCache): + """ + MemoryCache is an in-memory cache that implements the BaseCache interface. + + Attributes: + cache (dict): A dictionary to store the cache values. + expires (dict): A dictionary to store the expiration times of the cache values. + zexpires (dict): A dictionary to store the expiration times of the sorted set values. + _expire_thread (threading.Thread): A thread that runs the _expire method. + """ + + def __init__( + self, + host: str = "localhost", + port: int = 6379, + db: int = 0, + ssl: bool = False, + timeout: float = 2.0, + ) -> None: + """ + Initializes a new instance of the MemoryCache class. + """ + self.client = redis.Redis( + host=host, + port=port, + db=db, + decode_responses=True, + ssl=ssl, + socket_timeout=timeout, + socket_connect_timeout=timeout, + ) + logger.debug("Attempting to diagnose Redis connection...") + self.client.ping() + logger.debug("Redis connection established.") + self.zexpires = dict() + + self._expire_thread = threading.Thread(target=self._expire, daemon=True) + self._expire_thread.start() + + def _expire(self): + """ + Removes the expired keys from the cache and zexpires dictionaries. + + This method runs in an infinite loop and sleeps for MEMORY_CACHE_EXPIRE_INTERVAL seconds between each iteration. + """ + while True: + logger.debug("Redis cleaner thread starts cleaning...") + now = time.time() + for k, v in copy(list(self.zexpires.items())): + if v < now: + tolerance_factor = 1e-14 # floating point accuracy + self.zremrangebyscore( + k[0], k[1] - tolerance_factor, k[1] + tolerance_factor + ) + del self.zexpires[k] + logger.debug("Redis cleaner finished task.") + sleep_time = MEMORY_CACHE_EXPIRE_INTERVAL - (time.time() - now) + time.sleep(max(sleep_time, 0)) + + def get(self, key: str): + """ + Gets the value associated with the given key. + + Args: + key (str): The key to retrieve the value. + + Returns: + str: The value associated with the key, or None if the key does not exist or is expired. + """ + item = self.client.get(key) + if item is not None: + try: + return json.loads(item) + except TypeError: + return item + + def set(self, key: str, value: str, expire: float = None): + """ + Sets a value for a given key with an optional expire time. + + Args: + key (str): The key to store the value. + value (str): The value to store. + expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None. + """ + if not isinstance(value, bytes): + value = json.dumps(value) + self.client.set(key, value, ex=expire) + + def zadd(self, key: str, value: Any, score: float, expire: float = None): + """ + Adds a member with the specified score to the sorted set stored at key. + + Args: + key (str): The key of the sorted set. + value (str): The value to add to the sorted set. + score (float): The score associated with the value. + expire (float, optional): The time, in seconds, after which the key will expire. Defaults to None. + """ + # serializable_value = self.ensure_serializable(value) + value = json.dumps(value) + self.client.zadd(key, {value: score}) + if expire: + self.zexpires[(key, score)] = expire + time.time() + + def zrangebyscore( + self, + key: str, + min: Optional[float] = -1, + max: Optional[float] = float("inf"), + withscores: bool = False, + ): + """ + Retrieves a range of members from a sorted set. + + Args: + key (str): The key of the sorted set. + start (int, optional): The starting score of the range. Defaults to -1. + stop (int, optional): The ending score of the range. Defaults to float("inf"). + withscores (bool, optional): Whether to return the scores along with the values. Defaults to False. + + Returns: + list: A list of values (or value-score pairs if withscores is True) in the specified score range. + """ + res = self.client.zrangebyscore(key, min, max, withscores=withscores) + if withscores: + return [(json.loads(x), y) for x, y in res] + else: + return [json.loads(x) for x in res] + + def zremrangebyscore( + self, + key: str, + min: Optional[float] = -1, + max: Optional[float] = float("inf"), + ): + """ + Removes all members in a sorted set within the given scores. + + Args: + key (str): The key of the sorted set. + start (int, optional): The minimum score of the range. Defaults to -1. + stop (int, optional): The maximum score of the range. Defaults to float("inf"). + + Returns: + int: The number of members removed from the sorted set. + """ + return self.client.zremrangebyscore(key, min, max) + + def ensure_serializable(self, value: Any): + if isinstance(value, dict): + for k, v in value.items(): + if isinstance(v, Exception): + value[k] = str(v) + elif inspect.isclass(v) and isinstance(v, InferenceResponseImage): + value[k] = v.dict() + return value + + def acquire_lock(self, key: str, expire=None) -> Any: + l = self.client.lock(key, blocking=True, timeout=expire) + acquired = l.acquire(blocking_timeout=expire) + if not acquired: + raise TimeoutError("Couldn't get lock") + # refresh the lock + if expire is not None: + l.extend(expire) + return l + + def set_numpy(self, key: str, value: Any, expire: float = None): + serialized_value = pickle.dumps(value) + self.set(key, serialized_value, expire=expire) + + def get_numpy(self, key: str) -> Any: + serialized_value = self.get(key) + if serialized_value is not None: + return pickle.loads(serialized_value) + else: + return None diff --git a/inference/core/cache/serializers.py b/inference/core/cache/serializers.py new file mode 100644 index 0000000000000000000000000000000000000000..6fa8c06f23c17ecd0f29f83e6c0bd3524f264fd6 --- /dev/null +++ b/inference/core/cache/serializers.py @@ -0,0 +1,68 @@ +from typing import Union + +from fastapi.encoders import jsonable_encoder + +from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID +from inference.core.entities.requests.inference import InferenceRequest +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.env import TINY_CACHE +from inference.core.logger import logger +from inference.core.version import __version__ + + +def to_cachable_inference_item( + infer_request: InferenceRequest, + infer_response: Union[InferenceResponse, list[InferenceResponse]], +) -> dict: + if not TINY_CACHE: + return { + "inference_id": infer_request.id, + "inference_server_version": __version__, + "inference_server_id": GLOBAL_INFERENCE_SERVER_ID, + "request": jsonable_encoder(infer_request), + "response": jsonable_encoder(infer_response), + } + + included_request_fields = { + "api_key", + "confidence", + "model_id", + "model_type", + "source", + "source_info", + } + request = infer_request.dict(include=included_request_fields) + response = build_condensed_response(infer_response) + + return { + "inference_id": infer_request.id, + "inference_server_version": __version__, + "inference_server_id": GLOBAL_INFERENCE_SERVER_ID, + "request": jsonable_encoder(request), + "response": jsonable_encoder(response), + } + + +def build_condensed_response(responses): + if not isinstance(responses, list): + responses = [responses] + + formatted_responses = [] + for response in responses: + if not getattr(response, "predictions", None): + continue + try: + predictions = [ + {"confidence": pred.confidence, "class": pred.class_name} + for pred in response.predictions + ] + formatted_responses.append( + { + "predictions": predictions, + "time": response.time, + } + ) + except Exception as e: + logger.warning(f"Error formatting response, skipping caching: {e}") + + return formatted_responses diff --git a/inference/core/constants.py b/inference/core/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..6958784c43c1c66f238694e74a94710b57a463bd --- /dev/null +++ b/inference/core/constants.py @@ -0,0 +1,4 @@ +CLASSIFICATION_TASK = "classification" +OBJECT_DETECTION_TASK = "object-detection" +INSTANCE_SEGMENTATION_TASK = "instance-segmentation" +KEYPOINTS_DETECTION_TASK = "keypoint-detection" diff --git a/inference/core/devices/__init__.py b/inference/core/devices/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/devices/__pycache__/__init__.cpython-310.pyc b/inference/core/devices/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03faf25547bfb759d27579bfdc3eac7d67966a9b Binary files /dev/null and b/inference/core/devices/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/devices/__pycache__/utils.cpython-310.pyc b/inference/core/devices/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24cc3edf06815019177dac38e5ff26cb02c6ef83 Binary files /dev/null and b/inference/core/devices/__pycache__/utils.cpython-310.pyc differ diff --git a/inference/core/devices/utils.py b/inference/core/devices/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a26d47472ec870c85439c9dba95f44877e1b4918 --- /dev/null +++ b/inference/core/devices/utils.py @@ -0,0 +1,140 @@ +import os +import platform +import random +import string +import uuid + +from inference.core.env import DEVICE_ID, INFERENCE_SERVER_ID + + +def is_running_in_docker(): + """Checks if the current process is running inside a Docker container. + + Returns: + bool: True if running inside a Docker container, False otherwise. + """ + return os.path.exists("/.dockerenv") + + +def get_gpu_id(): + """Fetches the GPU ID if a GPU is present. + + Tries to import and use the `GPUtil` module to retrieve the GPU information. + + Returns: + Optional[int]: GPU ID if available, None otherwise. + """ + try: + import GPUtil + + GPUs = GPUtil.getGPUs() + if GPUs: + return GPUs[0].id + except ImportError: + return None + except Exception as e: + return None + + +def get_cpu_id(): + """Fetches the CPU ID based on the operating system. + + Attempts to get the CPU ID for Windows, Linux, and MacOS. + In case of any error or an unsupported OS, returns None. + + Returns: + Optional[str]: CPU ID string if available, None otherwise. + """ + try: + if platform.system() == "Windows": + return os.popen("wmic cpu get ProcessorId").read().strip() + elif platform.system() == "Linux": + return ( + open("/proc/cpuinfo").read().split("processor")[0].split(":")[1].strip() + ) + elif platform.system() == "Darwin": + import subprocess + + return ( + subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"]) + .strip() + .decode() + ) + except Exception as e: + return None + + +def get_jetson_id(): + """Fetches the Jetson device's serial number. + + Attempts to read the serial number from the device tree. + In case of any error, returns None. + + Returns: + Optional[str]: Jetson device serial number if available, None otherwise. + """ + try: + # Fetch the device's serial number + if not os.path.exists("/proc/device-tree/serial-number"): + return None + serial_number = os.popen("cat /proc/device-tree/serial-number").read().strip() + if serial_number == "": + return None + return serial_number + except Exception as e: + return None + + +def get_container_id(): + if is_running_in_docker(): + return ( + os.popen( + "cat /proc/self/cgroup | grep 'docker' | sed 's/^.*\///' | tail -n1" + ) + .read() + .strip() + ) + else: + return str(uuid.uuid4()) + + +def random_string(length): + letters = string.ascii_letters + string.digits + return "".join(random.choice(letters) for i in range(length)) + + +def get_device_hostname(): + """Fetches the device's hostname. + + Returns: + str: The device's hostname. + """ + return platform.node() + + +def get_inference_server_id(): + """Fetches a unique device ID. + + Tries to get the GPU ID first, then falls back to CPU ID. + If the application is running inside Docker, the Docker container ID is appended to the hostname. + + Returns: + str: A unique string representing the device. If unable to determine, returns "UNKNOWN". + """ + try: + if INFERENCE_SERVER_ID is not None: + return INFERENCE_SERVER_ID + id = random_string(6) + gpu_id = get_gpu_id() + if gpu_id is not None: + return f"{id}-GPU-{gpu_id}" + jetson_id = get_jetson_id() + if jetson_id is not None: + return f"{id}-JETSON-{jetson_id}" + return id + except Exception as e: + return "UNKNOWN" + + +GLOBAL_INFERENCE_SERVER_ID = get_inference_server_id() +GLOBAL_DEVICE_ID = DEVICE_ID if DEVICE_ID is not None else get_device_hostname() diff --git a/inference/core/entities/__init__.py b/inference/core/entities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/entities/__pycache__/__init__.cpython-310.pyc b/inference/core/entities/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b207f31b0833c903b145fe6dcf1cadc3565867e Binary files /dev/null and b/inference/core/entities/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/entities/__pycache__/common.cpython-310.pyc b/inference/core/entities/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7e6249f6663dffafc75ae8d745eeed71895ff1e Binary files /dev/null and b/inference/core/entities/__pycache__/common.cpython-310.pyc differ diff --git a/inference/core/entities/__pycache__/types.cpython-310.pyc b/inference/core/entities/__pycache__/types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79da4b263e8b03a044d36c96ff3317fe121b09ed Binary files /dev/null and b/inference/core/entities/__pycache__/types.cpython-310.pyc differ diff --git a/inference/core/entities/common.py b/inference/core/entities/common.py new file mode 100644 index 0000000000000000000000000000000000000000..768c972093594c8c96b2e71c19b97fd022f6cfed --- /dev/null +++ b/inference/core/entities/common.py @@ -0,0 +1,12 @@ +from pydantic import Field + +ModelID = Field(example="raccoon-detector-1", description="A unique model identifier") +ModelType = Field( + default=None, + example="object-detection", + description="The type of the model, usually referring to what task the model performs", +) +ApiKey = Field( + default=None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", +) diff --git a/inference/core/entities/requests/__init__.py b/inference/core/entities/requests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/entities/requests/__pycache__/__init__.cpython-310.pyc b/inference/core/entities/requests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2458b49b0788f7d1cee98a9890449dd785c1135 Binary files /dev/null and b/inference/core/entities/requests/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/clip.cpython-310.pyc b/inference/core/entities/requests/__pycache__/clip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2709615065606cdbd79f5f943728fb587e10ce8d Binary files /dev/null and b/inference/core/entities/requests/__pycache__/clip.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/cogvlm.cpython-310.pyc b/inference/core/entities/requests/__pycache__/cogvlm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4083754227c3c610b06157423f7ae63f50c252ce Binary files /dev/null and b/inference/core/entities/requests/__pycache__/cogvlm.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/doctr.cpython-310.pyc b/inference/core/entities/requests/__pycache__/doctr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..317f90defef83174dcf7403d35316065cce35ae2 Binary files /dev/null and b/inference/core/entities/requests/__pycache__/doctr.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/dynamic_class_base.cpython-310.pyc b/inference/core/entities/requests/__pycache__/dynamic_class_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34e729a9fdd453320723fcf5ba1e88cb01cc0fa8 Binary files /dev/null and b/inference/core/entities/requests/__pycache__/dynamic_class_base.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/gaze.cpython-310.pyc b/inference/core/entities/requests/__pycache__/gaze.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..001dcbe6319e698415fd4501aa16659ab0bd5b20 Binary files /dev/null and b/inference/core/entities/requests/__pycache__/gaze.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/groundingdino.cpython-310.pyc b/inference/core/entities/requests/__pycache__/groundingdino.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9153c202afcc6fa60b688bc30088b7eee2c14503 Binary files /dev/null and b/inference/core/entities/requests/__pycache__/groundingdino.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/inference.cpython-310.pyc b/inference/core/entities/requests/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..603e5cb33f316584d2d7882cef4e1a3b74c66831 Binary files /dev/null and b/inference/core/entities/requests/__pycache__/inference.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/sam.cpython-310.pyc b/inference/core/entities/requests/__pycache__/sam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd7912b761ecb48379eb9e34e606abdb5a0f7692 Binary files /dev/null and b/inference/core/entities/requests/__pycache__/sam.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/server_state.cpython-310.pyc b/inference/core/entities/requests/__pycache__/server_state.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30d4a8ae478536fd6142a7ddc9709c8214fc7309 Binary files /dev/null and b/inference/core/entities/requests/__pycache__/server_state.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/workflows.cpython-310.pyc b/inference/core/entities/requests/__pycache__/workflows.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a4655bafb55beb5311fa8e2200b5f6f9688aa7d Binary files /dev/null and b/inference/core/entities/requests/__pycache__/workflows.cpython-310.pyc differ diff --git a/inference/core/entities/requests/__pycache__/yolo_world.cpython-310.pyc b/inference/core/entities/requests/__pycache__/yolo_world.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0311906fa988f7151f43d96a943122cf4670a69b Binary files /dev/null and b/inference/core/entities/requests/__pycache__/yolo_world.cpython-310.pyc differ diff --git a/inference/core/entities/requests/clip.py b/inference/core/entities/requests/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..7e35fd422eaed4c0daba2bd93dc071a8c62926ac --- /dev/null +++ b/inference/core/entities/requests/clip.py @@ -0,0 +1,91 @@ +from typing import Dict, List, Optional, Union + +from pydantic import Field, validator + +from inference.core.entities.requests.inference import ( + BaseRequest, + InferenceRequestImage, +) +from inference.core.env import CLIP_VERSION_ID + + +class ClipInferenceRequest(BaseRequest): + """Request for CLIP inference. + + Attributes: + api_key (Optional[str]): Roboflow API Key. + clip_version_id (Optional[str]): The version ID of CLIP to be used for this request. + """ + + clip_version_id: Optional[str] = Field( + default=CLIP_VERSION_ID, + examples=["ViT-B-16"], + description="The version ID of CLIP to be used for this request. Must be one of RN101, RN50, RN50x16, RN50x4, RN50x64, ViT-B-16, ViT-B-32, ViT-L-14-336px, and ViT-L-14.", + ) + model_id: Optional[str] = Field(None) + + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. + @validator("model_id", always=True) + def validate_model_id(cls, value, values): + if value is not None: + return value + if values.get("clip_version_id") is None: + return None + return f"clip/{values['clip_version_id']}" + + +class ClipImageEmbeddingRequest(ClipInferenceRequest): + """Request for CLIP image embedding. + + Attributes: + image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) to be embedded. + """ + + image: Union[List[InferenceRequestImage], InferenceRequestImage] + + +class ClipTextEmbeddingRequest(ClipInferenceRequest): + """Request for CLIP text embedding. + + Attributes: + text (Union[List[str], str]): A string or list of strings. + """ + + text: Union[List[str], str] = Field( + examples=["The quick brown fox jumps over the lazy dog"], + description="A string or list of strings", + ) + + +class ClipCompareRequest(ClipInferenceRequest): + """Request for CLIP comparison. + + Attributes: + subject (Union[InferenceRequestImage, str]): The type of image data provided, one of 'url' or 'base64'. + subject_type (str): The type of subject, one of 'image' or 'text'. + prompt (Union[List[InferenceRequestImage], InferenceRequestImage, str, List[str], Dict[str, Union[InferenceRequestImage, str]]]): The prompt for comparison. + prompt_type (str): The type of prompt, one of 'image' or 'text'. + """ + + subject: Union[InferenceRequestImage, str] = Field( + examples=["url"], + description="The type of image data provided, one of 'url' or 'base64'", + ) + subject_type: str = Field( + default="image", + examples=["image"], + description="The type of subject, one of 'image' or 'text'", + ) + prompt: Union[ + List[InferenceRequestImage], + InferenceRequestImage, + str, + List[str], + Dict[str, Union[InferenceRequestImage, str]], + ] + prompt_type: str = Field( + default="text", + examples=["text"], + description="The type of prompt, one of 'image' or 'text'", + ) diff --git a/inference/core/entities/requests/cogvlm.py b/inference/core/entities/requests/cogvlm.py new file mode 100644 index 0000000000000000000000000000000000000000..fbd326fa0b8703c6f66649f325d4982ca639df2b --- /dev/null +++ b/inference/core/entities/requests/cogvlm.py @@ -0,0 +1,47 @@ +from typing import Dict, List, Optional, Tuple, Union + +from pydantic import Field, validator + +from inference.core.entities.requests.inference import ( + BaseRequest, + InferenceRequestImage, +) +from inference.core.env import COGVLM_VERSION_ID + + +class CogVLMInferenceRequest(BaseRequest): + """Request for CogVLM inference. + + Attributes: + api_key (Optional[str]): Roboflow API Key. + cog_version_id (Optional[str]): The version ID of CLIP to be used for this request. + """ + + cogvlm_version_id: Optional[str] = Field( + default=COGVLM_VERSION_ID, + examples=["cogvlm-chat-hf"], + description="The version ID of CogVLM to be used for this request. See the huggingface model repo at THUDM.", + ) + model_id: Optional[str] = Field(None) + image: InferenceRequestImage = Field( + description="Image for CogVLM to look at. Use prompt to specify what you want it to do with the image." + ) + prompt: str = Field( + description="Text to be passed to CogVLM. Use to prompt it to describe an image or provide only text to chat with the model.", + examples=["Describe this image."], + ) + history: Optional[List[Tuple[str, str]]] = Field( + None, + description="Optional chat history, formatted as a list of 2-tuples where the first entry is the user prompt" + " and the second entry is the generated model response", + ) + + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. + @validator("model_id", always=True) + def validate_model_id(cls, value, values): + if value is not None: + return value + if values.get("cogvlm_version_id") is None: + return None + return f"cogvlm/{values['cogvlm_version_id']}" diff --git a/inference/core/entities/requests/doctr.py b/inference/core/entities/requests/doctr.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c04f13c776836d281cc6ce72ae09830f1ff8f --- /dev/null +++ b/inference/core/entities/requests/doctr.py @@ -0,0 +1,31 @@ +from typing import List, Optional, Union + +from pydantic import Field, validator + +from inference.core.entities.requests.inference import ( + BaseRequest, + InferenceRequestImage, +) + + +class DoctrOCRInferenceRequest(BaseRequest): + """ + DocTR inference request. + + Attributes: + api_key (Optional[str]): Roboflow API Key. + """ + + image: Union[List[InferenceRequestImage], InferenceRequestImage] + doctr_version_id: Optional[str] = "default" + model_id: Optional[str] = Field(None) + + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. + @validator("model_id", always=True, allow_reuse=True) + def validate_model_id(cls, value, values): + if value is not None: + return value + if values.get("doctr_version_id") is None: + return None + return f"doctr/{values['doctr_version_id']}" diff --git a/inference/core/entities/requests/dynamic_class_base.py b/inference/core/entities/requests/dynamic_class_base.py new file mode 100644 index 0000000000000000000000000000000000000000..a8515b283619b29bc96fa9c3166ca315aa97c341 --- /dev/null +++ b/inference/core/entities/requests/dynamic_class_base.py @@ -0,0 +1,19 @@ +from typing import List, Optional + +from pydantic import Field + +from inference.core.entities.requests.inference import CVInferenceRequest + + +class DynamicClassBaseInferenceRequest(CVInferenceRequest): + """Request for zero-shot object detection models (with dynamic class lists). + + Attributes: + text (List[str]): A list of strings. + """ + + model_id: Optional[str] = Field(None) + text: List[str] = Field( + examples=[["person", "dog", "cat"]], + description="A list of strings", + ) diff --git a/inference/core/entities/requests/gaze.py b/inference/core/entities/requests/gaze.py new file mode 100644 index 0000000000000000000000000000000000000000..009738f8ccf2e113c2572a23cb884b485389166a --- /dev/null +++ b/inference/core/entities/requests/gaze.py @@ -0,0 +1,46 @@ +from typing import List, Optional, Union + +from pydantic import Field, validator + +from inference.core.entities.common import ApiKey +from inference.core.entities.requests.inference import ( + BaseRequest, + InferenceRequestImage, +) +from inference.core.env import GAZE_VERSION_ID + + +class GazeDetectionInferenceRequest(BaseRequest): + """Request for gaze detection inference. + + Attributes: + api_key (Optional[str]): Roboflow API Key. + gaze_version_id (Optional[str]): The version ID of Gaze to be used for this request. + do_run_face_detection (Optional[bool]): If true, face detection will be applied; if false, face detection will be ignored and the whole input image will be used for gaze detection. + image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) for inference. + """ + + gaze_version_id: Optional[str] = Field( + default=GAZE_VERSION_ID, + examples=["l2cs"], + description="The version ID of Gaze to be used for this request. Must be one of l2cs.", + ) + + do_run_face_detection: Optional[bool] = Field( + default=True, + examples=[False], + description="If true, face detection will be applied; if false, face detection will be ignored and the whole input image will be used for gaze detection", + ) + + image: Union[List[InferenceRequestImage], InferenceRequestImage] + model_id: Optional[str] = Field(None) + + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. + @validator("model_id", always=True, allow_reuse=True) + def validate_model_id(cls, value, values): + if value is not None: + return value + if values.get("gaze_version_id") is None: + return None + return f"gaze/{values['gaze_version_id']}" diff --git a/inference/core/entities/requests/groundingdino.py b/inference/core/entities/requests/groundingdino.py new file mode 100644 index 0000000000000000000000000000000000000000..29206a994037b7923b915c84a7ac9d9bcaff5bac --- /dev/null +++ b/inference/core/entities/requests/groundingdino.py @@ -0,0 +1,15 @@ +from typing import List, Optional + +from inference.core.entities.requests.dynamic_class_base import ( + DynamicClassBaseInferenceRequest, +) + + +class GroundingDINOInferenceRequest(DynamicClassBaseInferenceRequest): + """Request for Grounding DINO zero-shot predictions. + + Attributes: + text (List[str]): A list of strings. + """ + + grounding_dino_version_id: Optional[str] = "default" diff --git a/inference/core/entities/requests/inference.py b/inference/core/entities/requests/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb26742c282bae10d50de3505be3324d5b8bb08 --- /dev/null +++ b/inference/core/entities/requests/inference.py @@ -0,0 +1,233 @@ +from typing import Any, List, Optional, Union +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field + +from inference.core.entities.common import ApiKey, ModelID, ModelType + + +class BaseRequest(BaseModel): + """Base request for inference. + + Attributes: + id (str_): A unique request identifier. + api_key (Optional[str]): Roboflow API Key that will be passed to the model during initialization for artifact retrieval. + start (Optional[float]): start time of request + """ + + def __init__(self, **kwargs): + kwargs["id"] = str(uuid4()) + super().__init__(**kwargs) + + model_config = ConfigDict(protected_namespaces=()) + id: str + api_key: Optional[str] = ApiKey + start: Optional[float] = None + source: Optional[str] = None + source_info: Optional[str] = None + + +class InferenceRequest(BaseRequest): + """Base request for inference. + + Attributes: + model_id (str): A unique model identifier. + model_type (Optional[str]): The type of the model, usually referring to what task the model performs. + """ + + model_id: Optional[str] = ModelID + model_type: Optional[str] = ModelType + + +class InferenceRequestImage(BaseModel): + """Image data for inference request. + + Attributes: + type (str): The type of image data provided, one of 'url', 'base64', or 'numpy'. + value (Optional[Any]): Image data corresponding to the image type. + """ + + type: str = Field( + examples=["url"], + description="The type of image data provided, one of 'url', 'base64', or 'numpy'", + ) + value: Optional[Any] = Field( + None, + examples=["http://www.example-image-url.com"], + description="Image data corresponding to the image type, if type = 'url' then value is a string containing the url of an image, else if type = 'base64' then value is a string containing base64 encoded image data, else if type = 'numpy' then value is binary numpy data serialized using pickle.dumps(); array should 3 dimensions, channels last, with values in the range [0,255].", + ) + + +class CVInferenceRequest(InferenceRequest): + """Computer Vision inference request. + + Attributes: + image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) for inference. + disable_preproc_auto_orient (Optional[bool]): If true, the auto orient preprocessing step is disabled for this call. Default is False. + disable_preproc_contrast (Optional[bool]): If true, the auto contrast preprocessing step is disabled for this call. Default is False. + disable_preproc_grayscale (Optional[bool]): If true, the grayscale preprocessing step is disabled for this call. Default is False. + disable_preproc_static_crop (Optional[bool]): If true, the static crop preprocessing step is disabled for this call. Default is False. + """ + + image: Union[List[InferenceRequestImage], InferenceRequestImage] + disable_preproc_auto_orient: Optional[bool] = Field( + default=False, + description="If true, the auto orient preprocessing step is disabled for this call.", + ) + disable_preproc_contrast: Optional[bool] = Field( + default=False, + description="If true, the auto contrast preprocessing step is disabled for this call.", + ) + disable_preproc_grayscale: Optional[bool] = Field( + default=False, + description="If true, the grayscale preprocessing step is disabled for this call.", + ) + disable_preproc_static_crop: Optional[bool] = Field( + default=False, + description="If true, the static crop preprocessing step is disabled for this call.", + ) + + +class ObjectDetectionInferenceRequest(CVInferenceRequest): + """Object Detection inference request. + + Attributes: + class_agnostic_nms (Optional[bool]): If true, NMS is applied to all detections at once, if false, NMS is applied per class. + class_filter (Optional[List[str]]): If provided, only predictions for the listed classes will be returned. + confidence (Optional[float]): The confidence threshold used to filter out predictions. + fix_batch_size (Optional[bool]): If true, the batch size will be fixed to the maximum batch size configured for this server. + iou_threshold (Optional[float]): The IoU threshold that must be met for a box pair to be considered duplicate during NMS. + max_detections (Optional[int]): The maximum number of detections that will be returned. + max_candidates (Optional[int]): The maximum number of candidate detections passed to NMS. + visualization_labels (Optional[bool]): If true, labels will be rendered on prediction visualizations. + visualization_stroke_width (Optional[int]): The stroke width used when visualizing predictions. + visualize_predictions (Optional[bool]): If true, the predictions will be drawn on the original image and returned as a base64 string. + """ + + class_agnostic_nms: Optional[bool] = Field( + default=False, + examples=[False], + description="If true, NMS is applied to all detections at once, if false, NMS is applied per class", + ) + class_filter: Optional[List[str]] = Field( + default=None, + examples=[["class-1", "class-2", "class-n"]], + description="If provided, only predictions for the listed classes will be returned", + ) + confidence: Optional[float] = Field( + default=0.4, + examples=[0.5], + description="The confidence threshold used to filter out predictions", + ) + fix_batch_size: Optional[bool] = Field( + default=False, + examples=[False], + description="If true, the batch size will be fixed to the maximum batch size configured for this server", + ) + iou_threshold: Optional[float] = Field( + default=0.3, + examples=[0.5], + description="The IoU threhsold that must be met for a box pair to be considered duplicate during NMS", + ) + max_detections: Optional[int] = Field( + default=300, + examples=[300], + description="The maximum number of detections that will be returned", + ) + max_candidates: Optional[int] = Field( + default=3000, + description="The maximum number of candidate detections passed to NMS", + ) + visualization_labels: Optional[bool] = Field( + default=False, + examples=[False], + description="If true, labels will be rendered on prediction visualizations", + ) + visualization_stroke_width: Optional[int] = Field( + default=1, + examples=[1], + description="The stroke width used when visualizing predictions", + ) + visualize_predictions: Optional[bool] = Field( + default=False, + examples=[False], + description="If true, the predictions will be drawn on the original image and returned as a base64 string", + ) + disable_active_learning: Optional[bool] = Field( + default=False, + examples=[False], + description="If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled)", + ) + + +class KeypointsDetectionInferenceRequest(ObjectDetectionInferenceRequest): + keypoint_confidence: Optional[float] = Field( + default=0.0, + examples=[0.5], + description="The confidence threshold used to filter out non visible keypoints", + ) + + +class InstanceSegmentationInferenceRequest(ObjectDetectionInferenceRequest): + """Instance Segmentation inference request. + + Attributes: + mask_decode_mode (Optional[str]): The mode used to decode instance segmentation masks, one of 'accurate', 'fast', 'tradeoff'. + tradeoff_factor (Optional[float]): The amount to tradeoff between 0='fast' and 1='accurate'. + """ + + mask_decode_mode: Optional[str] = Field( + default="accurate", + examples=["accurate"], + description="The mode used to decode instance segmentation masks, one of 'accurate', 'fast', 'tradeoff'", + ) + tradeoff_factor: Optional[float] = Field( + default=0.0, + examples=[0.5], + description="The amount to tradeoff between 0='fast' and 1='accurate'", + ) + + +class ClassificationInferenceRequest(CVInferenceRequest): + """Classification inference request. + + Attributes: + confidence (Optional[float]): The confidence threshold used to filter out predictions. + visualization_stroke_width (Optional[int]): The stroke width used when visualizing predictions. + visualize_predictions (Optional[bool]): If true, the predictions will be drawn on the original image and returned as a base64 string. + """ + + confidence: Optional[float] = Field( + default=0.4, + examples=[0.5], + description="The confidence threshold used to filter out predictions", + ) + visualization_stroke_width: Optional[int] = Field( + default=1, + examples=[1], + description="The stroke width used when visualizing predictions", + ) + visualize_predictions: Optional[bool] = Field( + default=False, + examples=[False], + description="If true, the predictions will be drawn on the original image and returned as a base64 string", + ) + disable_active_learning: Optional[bool] = Field( + default=False, + examples=[False], + description="If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled)", + ) + + +def request_from_type(model_type, request_dict): + """Uses original request id""" + if model_type == "classification": + request = ClassificationInferenceRequest(**request_dict) + elif model_type == "instance-segmentation": + request = InstanceSegmentationInferenceRequest(**request_dict) + elif model_type == "object-detection": + request = ObjectDetectionInferenceRequest(**request_dict) + else: + raise ValueError(f"Uknown task type {model_type}") + request.id = request_dict.get("id") + return request diff --git a/inference/core/entities/requests/sam.py b/inference/core/entities/requests/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..6c971dae869c875fa57e652dfde8c6407601388b --- /dev/null +++ b/inference/core/entities/requests/sam.py @@ -0,0 +1,139 @@ +from typing import Any, List, Optional, Union + +from pydantic import Field, root_validator, validator + +from inference.core.entities.requests.inference import ( + BaseRequest, + InferenceRequestImage, +) +from inference.core.env import SAM_VERSION_ID + + +class SamInferenceRequest(BaseRequest): + """SAM inference request. + + Attributes: + api_key (Optional[str]): Roboflow API Key. + sam_version_id (Optional[str]): The version ID of SAM to be used for this request. + """ + + sam_version_id: Optional[str] = Field( + default=SAM_VERSION_ID, + examples=["vit_h"], + description="The version ID of SAM to be used for this request. Must be one of vit_h, vit_l, or vit_b.", + ) + + model_id: Optional[str] = Field(None) + + # TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually. + # Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information. + @validator("model_id", always=True) + def validate_model_id(cls, value, values): + if value is not None: + return value + if values.get("sam_version_id") is None: + return None + return f"sam/{values['sam_version_id']}" + + +class SamEmbeddingRequest(SamInferenceRequest): + """SAM embedding request. + + Attributes: + image (Optional[inference.core.entities.requests.inference.InferenceRequestImage]): The image to be embedded. + image_id (Optional[str]): The ID of the image to be embedded used to cache the embedding. + format (Optional[str]): The format of the response. Must be one of json or binary. + """ + + image: Optional[InferenceRequestImage] = Field( + default=None, + description="The image to be embedded", + ) + image_id: Optional[str] = Field( + default=None, + examples=["image_id"], + description="The ID of the image to be embedded used to cache the embedding.", + ) + format: Optional[str] = Field( + default="json", + examples=["json"], + description="The format of the response. Must be one of json or binary. If binary, embedding is returned as a binary numpy array.", + ) + + +class SamSegmentationRequest(SamInferenceRequest): + """SAM segmentation request. + + Attributes: + embeddings (Optional[Union[List[List[List[List[float]]]], Any]]): The embeddings to be decoded. + embeddings_format (Optional[str]): The format of the embeddings. + format (Optional[str]): The format of the response. + image (Optional[InferenceRequestImage]): The image to be segmented. + image_id (Optional[str]): The ID of the image to be segmented used to retrieve cached embeddings. + has_mask_input (Optional[bool]): Whether or not the request includes a mask input. + mask_input (Optional[Union[List[List[List[float]]], Any]]): The set of output masks. + mask_input_format (Optional[str]): The format of the mask input. + orig_im_size (Optional[List[int]]): The original size of the image used to generate the embeddings. + point_coords (Optional[List[List[float]]]): The coordinates of the interactive points used during decoding. + point_labels (Optional[List[float]]): The labels of the interactive points used during decoding. + use_mask_input_cache (Optional[bool]): Whether or not to use the mask input cache. + """ + + embeddings: Optional[Union[List[List[List[List[float]]]], Any]] = Field( + None, + examples=["[[[[0.1, 0.2, 0.3, ...] ...] ...]]"], + description="The embeddings to be decoded. The dimensions of the embeddings are 1 x 256 x 64 x 64. If embeddings is not provided, image must be provided.", + ) + embeddings_format: Optional[str] = Field( + default="json", + examples=["json"], + description="The format of the embeddings. Must be one of json or binary. If binary, embeddings are expected to be a binary numpy array.", + ) + format: Optional[str] = Field( + default="json", + examples=["json"], + description="The format of the response. Must be one of json or binary. If binary, masks are returned as binary numpy arrays. If json, masks are converted to polygons, then returned as json.", + ) + image: Optional[InferenceRequestImage] = Field( + default=None, + description="The image to be segmented. Only required if embeddings are not provided.", + ) + image_id: Optional[str] = Field( + default=None, + examples=["image_id"], + description="The ID of the image to be segmented used to retrieve cached embeddings. If an embedding is cached, it will be used instead of generating a new embedding. If no embedding is cached, a new embedding will be generated and cached.", + ) + has_mask_input: Optional[bool] = Field( + default=False, + examples=[True], + description="Whether or not the request includes a mask input. If true, the mask input must be provided.", + ) + mask_input: Optional[Union[List[List[List[float]]], Any]] = Field( + default=None, + description="The set of output masks. If request format is json, masks is a list of polygons, where each polygon is a list of points, where each point is a tuple containing the x,y pixel coordinates of the point. If request format is binary, masks is a list of binary numpy arrays. The dimensions of each mask are 256 x 256. This is the same as the output, low resolution mask from the previous inference.", + ) + mask_input_format: Optional[str] = Field( + default="json", + examples=["json"], + description="The format of the mask input. Must be one of json or binary. If binary, mask input is expected to be a binary numpy array.", + ) + orig_im_size: Optional[List[int]] = Field( + default=None, + examples=[[640, 320]], + description="The original size of the image used to generate the embeddings. This is only required if the image is not provided.", + ) + point_coords: Optional[List[List[float]]] = Field( + default=[[0.0, 0.0]], + examples=[[[10.0, 10.0]]], + description="The coordinates of the interactive points used during decoding. Each point (x,y pair) corresponds to a label in point_labels.", + ) + point_labels: Optional[List[float]] = Field( + default=[-1], + examples=[[1]], + description="The labels of the interactive points used during decoding. A 1 represents a positive point (part of the object to be segmented). A -1 represents a negative point (not part of the object to be segmented). Each label corresponds to a point in point_coords.", + ) + use_mask_input_cache: Optional[bool] = Field( + default=True, + examples=[True], + description="Whether or not to use the mask input cache. If true, the mask input cache will be used if it exists. If false, the mask input cache will not be used.", + ) diff --git a/inference/core/entities/requests/server_state.py b/inference/core/entities/requests/server_state.py new file mode 100644 index 0000000000000000000000000000000000000000..917c665d66a406da41a71f7cce464a15fdd0544a --- /dev/null +++ b/inference/core/entities/requests/server_state.py @@ -0,0 +1,31 @@ +from typing import Optional + +from pydantic import BaseModel, ConfigDict + +from inference.core.entities.common import ApiKey, ModelID, ModelType + + +class AddModelRequest(BaseModel): + """Request to add a model to the inference server. + + Attributes: + model_id (str): A unique model identifier. + model_type (Optional[str]): The type of the model, usually referring to what task the model performs. + api_key (Optional[str]): Roboflow API Key that will be passed to the model during initialization for artifact retrieval. + """ + + model_config = ConfigDict(protected_namespaces=()) + model_id: str = ModelID + model_type: Optional[str] = ModelType + api_key: Optional[str] = ApiKey + + +class ClearModelRequest(BaseModel): + """Request to clear a model from the inference server. + + Attributes: + model_id (str): A unique model identifier. + """ + + model_config = ConfigDict(protected_namespaces=()) + model_id: str = ModelID diff --git a/inference/core/entities/requests/workflows.py b/inference/core/entities/requests/workflows.py new file mode 100644 index 0000000000000000000000000000000000000000..9f0bf30e37a68a77f840e9e20b9bf10861e46f5e --- /dev/null +++ b/inference/core/entities/requests/workflows.py @@ -0,0 +1,24 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + +from inference.enterprise.workflows.entities.workflows_specification import ( + WorkflowSpecificationV1, +) + + +class WorkflowInferenceRequest(BaseModel): + api_key: str = Field( + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ) + inputs: Dict[str, Any] = Field( + description="Dictionary that contains each parameter defined as an input for chosen workflow" + ) + excluded_fields: Optional[List[str]] = Field( + default=None, + description="List of field that shall be excluded from the response (among those defined in workflow specification)", + ) + + +class WorkflowSpecificationInferenceRequest(WorkflowInferenceRequest): + specification: WorkflowSpecificationV1 diff --git a/inference/core/entities/requests/yolo_world.py b/inference/core/entities/requests/yolo_world.py new file mode 100644 index 0000000000000000000000000000000000000000..a2b3167cf11e6b57b4f91e628663a9155ff82ff3 --- /dev/null +++ b/inference/core/entities/requests/yolo_world.py @@ -0,0 +1,17 @@ +from typing import List, Optional + +from inference.core.entities.requests.dynamic_class_base import ( + DynamicClassBaseInferenceRequest, +) +from inference.core.models.defaults import DEFAULT_CONFIDENCE + + +class YOLOWorldInferenceRequest(DynamicClassBaseInferenceRequest): + """Request for Grounding DINO zero-shot predictions. + + Attributes: + text (List[str]): A list of strings. + """ + + yolo_world_version_id: Optional[str] = "l" + confidence: Optional[float] = DEFAULT_CONFIDENCE diff --git a/inference/core/entities/responses/__init__.py b/inference/core/entities/responses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/entities/responses/__pycache__/__init__.cpython-310.pyc b/inference/core/entities/responses/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e4966921204859c731350a7db859208a970b8e2 Binary files /dev/null and b/inference/core/entities/responses/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/entities/responses/__pycache__/clip.cpython-310.pyc b/inference/core/entities/responses/__pycache__/clip.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27216bd2114ac50e9bbaa9997738e99ca8d70c50 Binary files /dev/null and b/inference/core/entities/responses/__pycache__/clip.cpython-310.pyc differ diff --git a/inference/core/entities/responses/__pycache__/cogvlm.cpython-310.pyc b/inference/core/entities/responses/__pycache__/cogvlm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb9e2ff181f4844c0a53ec5a566232eb04efd588 Binary files /dev/null and b/inference/core/entities/responses/__pycache__/cogvlm.cpython-310.pyc differ diff --git a/inference/core/entities/responses/__pycache__/doctr.cpython-310.pyc b/inference/core/entities/responses/__pycache__/doctr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6400d8b4ed6ffdd007d5057e16be96340b5423ac Binary files /dev/null and b/inference/core/entities/responses/__pycache__/doctr.cpython-310.pyc differ diff --git a/inference/core/entities/responses/__pycache__/gaze.cpython-310.pyc b/inference/core/entities/responses/__pycache__/gaze.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7acac181e9a19f0d0459d967b22658502b92238a Binary files /dev/null and b/inference/core/entities/responses/__pycache__/gaze.cpython-310.pyc differ diff --git a/inference/core/entities/responses/__pycache__/groundingdino.cpython-310.pyc b/inference/core/entities/responses/__pycache__/groundingdino.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1653a671b5f8b3c5b7d044be760cb231bcf87613 Binary files /dev/null and b/inference/core/entities/responses/__pycache__/groundingdino.cpython-310.pyc differ diff --git a/inference/core/entities/responses/__pycache__/inference.cpython-310.pyc b/inference/core/entities/responses/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c595588ca8794bba65f263f54ff7521544d1a5e6 Binary files /dev/null and b/inference/core/entities/responses/__pycache__/inference.cpython-310.pyc differ diff --git a/inference/core/entities/responses/__pycache__/notebooks.cpython-310.pyc b/inference/core/entities/responses/__pycache__/notebooks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b676485e1cbda4baecb2bc79a2d3438ff61b862 Binary files /dev/null and b/inference/core/entities/responses/__pycache__/notebooks.cpython-310.pyc differ diff --git a/inference/core/entities/responses/__pycache__/sam.cpython-310.pyc b/inference/core/entities/responses/__pycache__/sam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20716285c11e04298a65d1d47d6d148a13603ae8 Binary files /dev/null and b/inference/core/entities/responses/__pycache__/sam.cpython-310.pyc differ diff --git a/inference/core/entities/responses/__pycache__/server_state.cpython-310.pyc b/inference/core/entities/responses/__pycache__/server_state.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fff9ee26aa6001f640ee166069ddcc1b8d70d94 Binary files /dev/null and b/inference/core/entities/responses/__pycache__/server_state.cpython-310.pyc differ diff --git a/inference/core/entities/responses/__pycache__/workflows.cpython-310.pyc b/inference/core/entities/responses/__pycache__/workflows.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4beb896926a2bf1a14432cf4965031536f7d37e0 Binary files /dev/null and b/inference/core/entities/responses/__pycache__/workflows.cpython-310.pyc differ diff --git a/inference/core/entities/responses/clip.py b/inference/core/entities/responses/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..1d6922929351226f7c76f3e336f7964d034868d9 --- /dev/null +++ b/inference/core/entities/responses/clip.py @@ -0,0 +1,42 @@ +from typing import Dict, List, Optional, Union + +from pydantic import Field + +from inference.core.entities.responses.inference import InferenceResponse + + +class ClipEmbeddingResponse(InferenceResponse): + """Response for CLIP embedding. + + Attributes: + embeddings (List[List[float]]): A list of embeddings, each embedding is a list of floats. + time (float): The time in seconds it took to produce the embeddings including preprocessing. + """ + + embeddings: List[List[float]] = Field( + examples=["[[0.12, 0.23, 0.34, ..., 0.43]]"], + description="A list of embeddings, each embedding is a list of floats", + ) + time: Optional[float] = Field( + None, + description="The time in seconds it took to produce the embeddings including preprocessing", + ) + + +class ClipCompareResponse(InferenceResponse): + """Response for CLIP comparison. + + Attributes: + similarity (Union[List[float], Dict[str, float]]): Similarity scores. + time (float): The time in seconds it took to produce the similarity scores including preprocessing. + """ + + similarity: Union[List[float], Dict[str, float]] + time: Optional[float] = Field( + None, + description="The time in seconds it took to produce the similarity scores including preprocessing", + ) + parent_id: Optional[str] = Field( + description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference", + default=None, + ) diff --git a/inference/core/entities/responses/cogvlm.py b/inference/core/entities/responses/cogvlm.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e5e9589017f28ab70f46528901d42633f10ecf --- /dev/null +++ b/inference/core/entities/responses/cogvlm.py @@ -0,0 +1,11 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class CogVLMResponse(BaseModel): + response: str = Field(description="Text generated by CogVLM") + time: Optional[float] = Field( + None, + description="The time in seconds it took to produce the response including preprocessing", + ) diff --git a/inference/core/entities/responses/doctr.py b/inference/core/entities/responses/doctr.py new file mode 100644 index 0000000000000000000000000000000000000000..6b53daa60d5c3869800c30d134ca8488323b1826 --- /dev/null +++ b/inference/core/entities/responses/doctr.py @@ -0,0 +1,22 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class DoctrOCRInferenceResponse(BaseModel): + """ + DocTR Inference response. + + Attributes: + result (str): The result from OCR. + time: The time in seconds it took to produce the segmentation including preprocessing. + """ + + result: str = Field(description="The result from OCR.") + time: float = Field( + description="The time in seconds it took to produce the segmentation including preprocessing." + ) + parent_id: Optional[str] = Field( + description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference", + default=None, + ) diff --git a/inference/core/entities/responses/gaze.py b/inference/core/entities/responses/gaze.py new file mode 100644 index 0000000000000000000000000000000000000000..06e5b005c77e0ba5260dfada05e7b02fb35dfbdb --- /dev/null +++ b/inference/core/entities/responses/gaze.py @@ -0,0 +1,39 @@ +from typing import List, Optional + +from pydantic import BaseModel, Field + +from inference.core.entities.responses.inference import FaceDetectionPrediction + + +class GazeDetectionPrediction(BaseModel): + """Gaze Detection prediction. + + Attributes: + face (inference.core.entities.responses.inference.FaceDetectionPrediction): The face prediction. + yaw (float): Yaw (radian) of the detected face. + pitch (float): Pitch (radian) of the detected face. + """ + + face: FaceDetectionPrediction + + yaw: float = Field(description="Yaw (radian) of the detected face") + pitch: float = Field(description="Pitch (radian) of the detected face") + + +class GazeDetectionInferenceResponse(BaseModel): + """Response for gaze detection inference. + + Attributes: + predictions (List[inference.core.entities.responses.gaze.GazeDetectionPrediction]): List of gaze detection predictions. + time (float): The processing time (second). + """ + + predictions: List[GazeDetectionPrediction] + + time: float = Field(description="The processing time (second)") + time_face_det: Optional[float] = Field( + None, description="The face detection time (second)" + ) + time_gaze_det: Optional[float] = Field( + None, description="The gaze detection time (second)" + ) diff --git a/inference/core/entities/responses/groundingdino.py b/inference/core/entities/responses/groundingdino.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/entities/responses/inference.py b/inference/core/entities/responses/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..06474a69752c32b8a7f863d692821ba26f6a0669 --- /dev/null +++ b/inference/core/entities/responses/inference.py @@ -0,0 +1,322 @@ +import base64 +from typing import Any, Dict, List, Optional, Union +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_serializer + + +class ObjectDetectionPrediction(BaseModel): + """Object Detection prediction. + + Attributes: + x (float): The center x-axis pixel coordinate of the prediction. + y (float): The center y-axis pixel coordinate of the prediction. + width (float): The width of the prediction bounding box in number of pixels. + height (float): The height of the prediction bounding box in number of pixels. + confidence (float): The detection confidence as a fraction between 0 and 1. + class_name (str): The predicted class label. + class_confidence (Union[float, None]): The class label confidence as a fraction between 0 and 1. + class_id (int): The class id of the prediction + """ + + x: float = Field(description="The center x-axis pixel coordinate of the prediction") + y: float = Field(description="The center y-axis pixel coordinate of the prediction") + width: float = Field( + description="The width of the prediction bounding box in number of pixels" + ) + height: float = Field( + description="The height of the prediction bounding box in number of pixels" + ) + confidence: float = Field( + description="The detection confidence as a fraction between 0 and 1" + ) + class_name: str = Field(alias="class", description="The predicted class label") + + class_confidence: Union[float, None] = Field( + None, description="The class label confidence as a fraction between 0 and 1" + ) + class_id: int = Field(description="The class id of the prediction") + tracker_id: Optional[int] = Field( + description="The tracker id of the prediction if tracking is enabled", + default=None, + ) + detection_id: str = Field( + description="Unique identifier of detection", + default_factory=lambda: str(uuid4()), + ) + parent_id: Optional[str] = Field( + description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference", + default=None, + ) + + +class Point(BaseModel): + """Point coordinates. + + Attributes: + x (float): The x-axis pixel coordinate of the point. + y (float): The y-axis pixel coordinate of the point. + """ + + x: float = Field(description="The x-axis pixel coordinate of the point") + y: float = Field(description="The y-axis pixel coordinate of the point") + + +class Point3D(Point): + """3D Point coordinates. + + Attributes: + z (float): The z-axis pixel coordinate of the point. + """ + + z: float = Field(description="The z-axis pixel coordinate of the point") + + +class InstanceSegmentationPrediction(BaseModel): + """Instance Segmentation prediction. + + Attributes: + x (float): The center x-axis pixel coordinate of the prediction. + y (float): The center y-axis pixel coordinate of the prediction. + width (float): The width of the prediction bounding box in number of pixels. + height (float): The height of the prediction bounding box in number of pixels. + confidence (float): The detection confidence as a fraction between 0 and 1. + class_name (str): The predicted class label. + class_confidence (Union[float, None]): The class label confidence as a fraction between 0 and 1. + points (List[Point]): The list of points that make up the instance polygon. + class_id: int = Field(description="The class id of the prediction") + """ + + x: float = Field(description="The center x-axis pixel coordinate of the prediction") + y: float = Field(description="The center y-axis pixel coordinate of the prediction") + width: float = Field( + description="The width of the prediction bounding box in number of pixels" + ) + height: float = Field( + description="The height of the prediction bounding box in number of pixels" + ) + confidence: float = Field( + description="The detection confidence as a fraction between 0 and 1" + ) + class_name: str = Field(alias="class", description="The predicted class label") + + class_confidence: Union[float, None] = Field( + None, description="The class label confidence as a fraction between 0 and 1" + ) + points: List[Point] = Field( + description="The list of points that make up the instance polygon" + ) + class_id: int = Field(description="The class id of the prediction") + detection_id: str = Field( + description="Unique identifier of detection", + default_factory=lambda: str(uuid4()), + ) + parent_id: Optional[str] = Field( + description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference", + default=None, + ) + + +class ClassificationPrediction(BaseModel): + """Classification prediction. + + Attributes: + class_name (str): The predicted class label. + class_id (int): Numeric ID associated with the class label. + confidence (float): The class label confidence as a fraction between 0 and 1. + """ + + class_name: str = Field(alias="class", description="The predicted class label") + class_id: int = Field(description="Numeric ID associated with the class label") + confidence: float = Field( + description="The class label confidence as a fraction between 0 and 1" + ) + + +class MultiLabelClassificationPrediction(BaseModel): + """Multi-label Classification prediction. + + Attributes: + confidence (float): The class label confidence as a fraction between 0 and 1. + """ + + confidence: float = Field( + description="The class label confidence as a fraction between 0 and 1" + ) + + +class InferenceResponseImage(BaseModel): + """Inference response image information. + + Attributes: + width (int): The original width of the image used in inference. + height (int): The original height of the image used in inference. + """ + + width: int = Field(description="The original width of the image used in inference") + height: int = Field( + description="The original height of the image used in inference" + ) + + +class InferenceResponse(BaseModel): + """Base inference response. + + Attributes: + frame_id (Optional[int]): The frame id of the image used in inference if the input was a video. + time (Optional[float]): The time in seconds it took to produce the predictions including image preprocessing. + """ + + model_config = ConfigDict(protected_namespaces=()) + frame_id: Optional[int] = Field( + default=None, + description="The frame id of the image used in inference if the input was a video", + ) + time: Optional[float] = Field( + default=None, + description="The time in seconds it took to produce the predictions including image preprocessing", + ) + + +class CvInferenceResponse(InferenceResponse): + """Computer Vision inference response. + + Attributes: + image (Union[List[inference.core.entities.responses.inference.InferenceResponseImage], inference.core.entities.responses.inference.InferenceResponseImage]): Image(s) used in inference. + """ + + image: Union[List[InferenceResponseImage], InferenceResponseImage] + + +class WithVisualizationResponse(BaseModel): + """Response with visualization. + + Attributes: + visualization (Optional[Any]): Base64 encoded string containing prediction visualization image data. + """ + + visualization: Optional[Any] = Field( + default=None, + description="Base64 encoded string containing prediction visualization image data", + ) + + @field_serializer("visualization", when_used="json") + def serialize_visualisation(self, visualization: Optional[Any]) -> Optional[str]: + if visualization is None: + return None + return base64.b64encode(visualization).decode("utf-8") + + +class ObjectDetectionInferenceResponse(CvInferenceResponse, WithVisualizationResponse): + """Object Detection inference response. + + Attributes: + predictions (List[inference.core.entities.responses.inference.ObjectDetectionPrediction]): List of object detection predictions. + """ + + predictions: List[ObjectDetectionPrediction] + + +class Keypoint(Point): + confidence: float = Field( + description="Model confidence regarding keypoint visibility." + ) + class_id: int = Field(description="Identifier of keypoint.") + class_name: str = Field(field="class", description="Type of keypoint.") + + +class KeypointsPrediction(ObjectDetectionPrediction): + keypoints: List[Keypoint] + + +class KeypointsDetectionInferenceResponse( + CvInferenceResponse, WithVisualizationResponse +): + predictions: List[KeypointsPrediction] + + +class InstanceSegmentationInferenceResponse( + CvInferenceResponse, WithVisualizationResponse +): + """Instance Segmentation inference response. + + Attributes: + predictions (List[inference.core.entities.responses.inference.InstanceSegmentationPrediction]): List of instance segmentation predictions. + """ + + predictions: List[InstanceSegmentationPrediction] + + +class ClassificationInferenceResponse(CvInferenceResponse, WithVisualizationResponse): + """Classification inference response. + + Attributes: + predictions (List[inference.core.entities.responses.inference.ClassificationPrediction]): List of classification predictions. + top (str): The top predicted class label. + confidence (float): The confidence of the top predicted class label. + """ + + predictions: List[ClassificationPrediction] + top: str = Field(description="The top predicted class label") + confidence: float = Field( + description="The confidence of the top predicted class label" + ) + parent_id: Optional[str] = Field( + description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference", + default=None, + ) + + +class MultiLabelClassificationInferenceResponse( + CvInferenceResponse, WithVisualizationResponse +): + """Multi-label Classification inference response. + + Attributes: + predictions (Dict[str, inference.core.entities.responses.inference.MultiLabelClassificationPrediction]): Dictionary of multi-label classification predictions. + predicted_classes (List[str]): The list of predicted classes. + """ + + predictions: Dict[str, MultiLabelClassificationPrediction] + predicted_classes: List[str] = Field(description="The list of predicted classes") + parent_id: Optional[str] = Field( + description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference", + default=None, + ) + + +class FaceDetectionPrediction(ObjectDetectionPrediction): + """Face Detection prediction. + + Attributes: + class_name (str): fixed value "face". + landmarks (Union[List[inference.core.entities.responses.inference.Point], List[inference.core.entities.responses.inference.Point3D]]): The detected face landmarks. + """ + + class_id: Optional[int] = Field( + description="The class id of the prediction", default=0 + ) + class_name: str = Field( + alias="class", default="face", description="The predicted class label" + ) + landmarks: Union[List[Point], List[Point3D]] + + +def response_from_type(model_type, response_dict): + if model_type == "classification": + try: + return ClassificationInferenceResponse(**response_dict) + except ValidationError: + return MultiLabelClassificationInferenceResponse(**response_dict) + elif model_type == "instance-segmentation": + return InstanceSegmentationInferenceResponse(**response_dict) + elif model_type == "object-detection": + return ObjectDetectionInferenceResponse(**response_dict) + else: + raise ValueError(f"Uknown task type {model_type}") + + +class StubResponse(InferenceResponse, WithVisualizationResponse): + is_stub: bool = Field(description="Field to mark prediction type as stub") + model_id: str = Field(description="Identifier of a model stub that was called") + task_type: str = Field(description="Task type of the project") diff --git a/inference/core/entities/responses/notebooks.py b/inference/core/entities/responses/notebooks.py new file mode 100644 index 0000000000000000000000000000000000000000..bd2ca292db9dd7f390259b99d20f7a03b2a81765 --- /dev/null +++ b/inference/core/entities/responses/notebooks.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel, Field, ValidationError + + +class NotebookStartResponse(BaseModel): + """Response model for notebook start request""" + + success: str = Field(..., description="Status of the request") + message: str = Field(..., description="Message of the request", optional=True) diff --git a/inference/core/entities/responses/sam.py b/inference/core/entities/responses/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..f455219ca5d5e5497ea2405dd3f76911050e490e --- /dev/null +++ b/inference/core/entities/responses/sam.py @@ -0,0 +1,40 @@ +from typing import Any, List, Union + +from pydantic import BaseModel, Field + + +class SamEmbeddingResponse(BaseModel): + """SAM embedding response. + + Attributes: + embeddings (Union[List[List[List[List[float]]]], Any]): The SAM embedding. + time (float): The time in seconds it took to produce the embeddings including preprocessing. + """ + + embeddings: Union[List[List[List[List[float]]]], Any] = Field( + examples=["[[[[0.1, 0.2, 0.3, ...] ...] ...]]"], + description="If request format is json, embeddings is a series of nested lists representing the SAM embedding. If request format is binary, embeddings is a binary numpy array. The dimensions of the embedding are 1 x 256 x 64 x 64.", + ) + time: float = Field( + description="The time in seconds it took to produce the embeddings including preprocessing" + ) + + +class SamSegmentationResponse(BaseModel): + """SAM segmentation response. + + Attributes: + masks (Union[List[List[List[int]]], Any]): The set of output masks. + low_res_masks (Union[List[List[List[int]]], Any]): The set of output low-resolution masks. + time (float): The time in seconds it took to produce the segmentation including preprocessing. + """ + + masks: Union[List[List[List[int]]], Any] = Field( + description="The set of output masks. If request format is json, masks is a list of polygons, where each polygon is a list of points, where each point is a tuple containing the x,y pixel coordinates of the point. If request format is binary, masks is a list of binary numpy arrays. The dimensions of each mask are the same as the dimensions of the input image.", + ) + low_res_masks: Union[List[List[List[int]]], Any] = Field( + description="The set of output masks. If request format is json, masks is a list of polygons, where each polygon is a list of points, where each point is a tuple containing the x,y pixel coordinates of the point. If request format is binary, masks is a list of binary numpy arrays. The dimensions of each mask are 256 x 256", + ) + time: float = Field( + description="The time in seconds it took to produce the segmentation including preprocessing" + ) diff --git a/inference/core/entities/responses/server_state.py b/inference/core/entities/responses/server_state.py new file mode 100644 index 0000000000000000000000000000000000000000..6dcdd5d727735a45872bab662f22620254bff0f8 --- /dev/null +++ b/inference/core/entities/responses/server_state.py @@ -0,0 +1,73 @@ +from typing import List, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field + +from inference.core.managers.entities import ModelDescription + + +class ServerVersionInfo(BaseModel): + """Server version information. + + Attributes: + name (str): Server name. + version (str): Server version. + uuid (str): Server UUID. + """ + + name: str = Field(examples=["Roboflow Inference Server"]) + version: str = Field(examples=["0.0.1"]) + uuid: str = Field(examples=["9c18c6f4-2266-41fb-8a0f-c12ae28f6fbe"]) + + +class ModelDescriptionEntity(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + model_id: str = Field( + description="Identifier of the model", examples=["some-project/3"] + ) + task_type: str = Field( + description="Type of the task that the model performs", + examples=["classification"], + ) + batch_size: Optional[Union[int, str]] = Field( + None, + description="Batch size accepted by the model (if registered).", + ) + input_height: Optional[int] = Field( + None, + description="Image input height accepted by the model (if registered).", + ) + input_width: Optional[int] = Field( + None, + description="Image input width accepted by the model (if registered).", + ) + + @classmethod + def from_model_description( + cls, model_description: ModelDescription + ) -> "ModelDescriptionEntity": + return cls( + model_id=model_description.model_id, + task_type=model_description.task_type, + batch_size=model_description.batch_size, + input_height=model_description.input_height, + input_width=model_description.input_width, + ) + + +class ModelsDescriptions(BaseModel): + models: List[ModelDescriptionEntity] = Field( + description="List of models that are loaded by model manager.", + ) + + @classmethod + def from_models_descriptions( + cls, models_descriptions: List[ModelDescription] + ) -> "ModelsDescriptions": + return cls( + models=[ + ModelDescriptionEntity.from_model_description( + model_description=model_description + ) + for model_description in models_descriptions + ] + ) diff --git a/inference/core/entities/responses/workflows.py b/inference/core/entities/responses/workflows.py new file mode 100644 index 0000000000000000000000000000000000000000..f8d72ee017cdaaddc0713f3350bcc999d58b780f --- /dev/null +++ b/inference/core/entities/responses/workflows.py @@ -0,0 +1,9 @@ +from typing import Any, Dict + +from pydantic import BaseModel, Field + + +class WorkflowInferenceResponse(BaseModel): + outputs: Dict[str, Any] = Field( + description="Dictionary with keys defined in workflow output and serialised values" + ) diff --git a/inference/core/entities/types.py b/inference/core/entities/types.py new file mode 100644 index 0000000000000000000000000000000000000000..16611627d741c1c7e9ecb2787f67192821fa93f5 --- /dev/null +++ b/inference/core/entities/types.py @@ -0,0 +1,5 @@ +DatasetID = str +VersionID = str +TaskType = str +ModelType = str +WorkspaceID = str diff --git a/inference/core/env.py b/inference/core/env.py new file mode 100644 index 0000000000000000000000000000000000000000..c72e3f0dc72858207fb00d8b8bd8a275cb513746 --- /dev/null +++ b/inference/core/env.py @@ -0,0 +1,370 @@ +import os +import uuid + +from dotenv import load_dotenv + +from inference.core.utils.environment import safe_split_value, str2bool + +load_dotenv(os.getcwd() + "/.env") + +# The project name, default is "roboflow-platform" +PROJECT = os.getenv("PROJECT", "roboflow-platform") + +# Allow numpy input, default is True +ALLOW_NUMPY_INPUT = str2bool(os.getenv("ALLOW_NUMPY_INPUT", True)) + +# List of allowed origins +ALLOW_ORIGINS = os.getenv("ALLOW_ORIGINS", "") +ALLOW_ORIGINS = ALLOW_ORIGINS.split(",") + +# Base URL for the API +API_BASE_URL = os.getenv( + "API_BASE_URL", + ( + "https://api.roboflow.com" + if PROJECT == "roboflow-platform" + else "https://api.roboflow.one" + ), +) + +# Debug flag for the API, default is False +API_DEBUG = os.getenv("API_DEBUG", False) + +# API key, default is None +API_KEY_ENV_NAMES = ["ROBOFLOW_API_KEY", "API_KEY"] +API_KEY = os.getenv(API_KEY_ENV_NAMES[0], None) or os.getenv(API_KEY_ENV_NAMES[1], None) + +# AWS access key ID, default is None +AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID", None) + +# AWS secret access key, default is None +AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY", None) + +COGVLM_LOAD_4BIT = str2bool(os.getenv("COGVLM_LOAD_4BIT", True)) +COGVLM_LOAD_8BIT = str2bool(os.getenv("COGVLM_LOAD_8BIT", False)) +COGVLM_VERSION_ID = os.getenv("COGVLM_VERSION_ID", "cogvlm-chat-hf") +# CLIP version ID, default is "ViT-B-16" +CLIP_VERSION_ID = os.getenv("CLIP_VERSION_ID", "ViT-B-16") + +# CLIP model ID +CLIP_MODEL_ID = f"clip/{CLIP_VERSION_ID}" + +# Gaze version ID, default is "L2CS" +GAZE_VERSION_ID = os.getenv("GAZE_VERSION_ID", "L2CS") + +# Gaze model ID +GAZE_MODEL_ID = f"gaze/{CLIP_VERSION_ID}" + +# Maximum batch size for GAZE, default is 8 +GAZE_MAX_BATCH_SIZE = int(os.getenv("GAZE_MAX_BATCH_SIZE", 8)) + +# If true, this will store a non-verbose version of the inference request and repsonse in the cache +TINY_CACHE = str2bool(os.getenv("TINY_CACHE", True)) + +# Maximum batch size for CLIP, default is 8 +CLIP_MAX_BATCH_SIZE = int(os.getenv("CLIP_MAX_BATCH_SIZE", 8)) + +# Class agnostic NMS flag, default is False +CLASS_AGNOSTIC_NMS_ENV = "CLASS_AGNOSTIC_NMS" +DEFAULT_CLASS_AGNOSTIC_NMS = False +CLASS_AGNOSTIC_NMS = str2bool( + os.getenv(CLASS_AGNOSTIC_NMS_ENV, DEFAULT_CLASS_AGNOSTIC_NMS) +) + +# Confidence threshold, default is 50% +CONFIDENCE_ENV = "CONFIDENCE" +DEFAULT_CONFIDENCE = 0.4 +CONFIDENCE = float(os.getenv(CONFIDENCE_ENV, DEFAULT_CONFIDENCE)) + +# Flag to enable core models, default is True +CORE_MODELS_ENABLED = str2bool(os.getenv("CORE_MODELS_ENABLED", True)) + +# Flag to enable CLIP core model, default is True +CORE_MODEL_CLIP_ENABLED = str2bool(os.getenv("CORE_MODEL_CLIP_ENABLED", True)) + +# Flag to enable SAM core model, default is True +CORE_MODEL_SAM_ENABLED = str2bool(os.getenv("CORE_MODEL_SAM_ENABLED", True)) + +# Flag to enable GAZE core model, default is True +CORE_MODEL_GAZE_ENABLED = str2bool(os.getenv("CORE_MODEL_GAZE_ENABLED", True)) + +# Flag to enable DocTR core model, default is True +CORE_MODEL_DOCTR_ENABLED = str2bool(os.getenv("CORE_MODEL_DOCTR_ENABLED", True)) + +# Flag to enable GROUNDINGDINO core model, default is True +CORE_MODEL_GROUNDINGDINO_ENABLED = str2bool( + os.getenv("CORE_MODEL_GROUNDINGDINO_ENABLED", True) +) + +# Flag to enable CogVLM core model, default is True +CORE_MODEL_COGVLM_ENABLED = str2bool(os.getenv("CORE_MODEL_COGVLM_ENABLED", True)) + +# Flag to enable YOLO-World core model, default is True +CORE_MODEL_YOLO_WORLD_ENABLED = str2bool( + os.getenv("CORE_MODEL_YOLO_WORLD_ENABLED", True) +) + +# ID of host device, default is None +DEVICE_ID = os.getenv("DEVICE_ID", None) + +# Flag to disable inference cache, default is False +DISABLE_INFERENCE_CACHE = str2bool(os.getenv("DISABLE_INFERENCE_CACHE", False)) + +# Flag to disable auto-orientation preprocessing, default is False +DISABLE_PREPROC_AUTO_ORIENT = str2bool(os.getenv("DISABLE_PREPROC_AUTO_ORIENT", False)) + +# Flag to disable contrast preprocessing, default is False +DISABLE_PREPROC_CONTRAST = str2bool(os.getenv("DISABLE_PREPROC_CONTRAST", False)) + +# Flag to disable grayscale preprocessing, default is False +DISABLE_PREPROC_GRAYSCALE = str2bool(os.getenv("DISABLE_PREPROC_GRAYSCALE", False)) + +# Flag to disable static crop preprocessing, default is False +DISABLE_PREPROC_STATIC_CROP = str2bool(os.getenv("DISABLE_PREPROC_STATIC_CROP", False)) + +# Flag to disable version check, default is False +DISABLE_VERSION_CHECK = str2bool(os.getenv("DISABLE_VERSION_CHECK", False)) + +# ElastiCache endpoint +ELASTICACHE_ENDPOINT = os.environ.get( + "ELASTICACHE_ENDPOINT", + ( + "roboflow-infer-prod.ljzegl.cfg.use2.cache.amazonaws.com:11211" + if PROJECT == "roboflow-platform" + else "roboflow-infer.ljzegl.cfg.use2.cache.amazonaws.com:11211" + ), +) + +# Flag to enable byte track, default is False +ENABLE_BYTE_TRACK = str2bool(os.getenv("ENABLE_BYTE_TRACK", False)) + +# Flag to enforce FPS, default is False +ENFORCE_FPS = str2bool(os.getenv("ENFORCE_FPS", False)) +MAX_FPS = os.getenv("MAX_FPS") +if MAX_FPS is not None: + MAX_FPS = int(MAX_FPS) + +# Flag to fix batch size, default is False +FIX_BATCH_SIZE = str2bool(os.getenv("FIX_BATCH_SIZE", False)) + +# Host, default is "0.0.0.0" +HOST = os.getenv("HOST", "0.0.0.0") + +# IoU threshold, default is 0.3 +IOU_THRESHOLD_ENV = "IOU_THRESHOLD" +DEFAULT_IOU_THRESHOLD = 0.3 +IOU_THRESHOLD = float(os.getenv(IOU_THRESHOLD_ENV, DEFAULT_IOU_THRESHOLD)) + +# IP broadcast address, default is "127.0.0.1" +IP_BROADCAST_ADDR = os.getenv("IP_BROADCAST_ADDR", "127.0.0.1") + +# IP broadcast port, default is 37020 +IP_BROADCAST_PORT = int(os.getenv("IP_BROADCAST_PORT", 37020)) + +# Flag to enable JSON response, default is True +JSON_RESPONSE = str2bool(os.getenv("JSON_RESPONSE", True)) + +# Lambda flag, default is False +LAMBDA = str2bool(os.getenv("LAMBDA", False)) + +# Flag to enable legacy route, default is True +LEGACY_ROUTE_ENABLED = str2bool(os.getenv("LEGACY_ROUTE_ENABLED", True)) + +# License server, default is None +LICENSE_SERVER = os.getenv("LICENSE_SERVER", None) + +# Log level, default is "INFO" +LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING") + +# Maximum number of active models, default is 8 +MAX_ACTIVE_MODELS = int(os.getenv("MAX_ACTIVE_MODELS", 8)) + +# Maximum batch size, default is infinite +MAX_BATCH_SIZE = os.getenv("MAX_BATCH_SIZE", None) +if MAX_BATCH_SIZE is not None: + MAX_BATCH_SIZE = int(MAX_BATCH_SIZE) +else: + MAX_BATCH_SIZE = float("inf") + +# Maximum number of candidates, default is 3000 +MAX_CANDIDATES_ENV = "MAX_CANDIDATES" +DEFAULT_MAX_CANDIDATES = 3000 +MAX_CANDIDATES = int(os.getenv(MAX_CANDIDATES_ENV, DEFAULT_MAX_CANDIDATES)) + +# Maximum number of detections, default is 300 +MAX_DETECTIONS_ENV = "MAX_DETECTIONS" +DEFAULT_MAX_DETECTIONS = 300 +MAX_DETECTIONS = int(os.getenv(MAX_DETECTIONS_ENV, DEFAULT_MAX_DETECTIONS)) + +# Loop interval for expiration of memory cache, default is 5 +MEMORY_CACHE_EXPIRE_INTERVAL = int(os.getenv("MEMORY_CACHE_EXPIRE_INTERVAL", 5)) + +# Metrics enabled flag, default is True +METRICS_ENABLED = str2bool(os.getenv("METRICS_ENABLED", True)) +if LAMBDA: + METRICS_ENABLED = False + +# Interval for metrics aggregation, default is 60 +METRICS_INTERVAL = int(os.getenv("METRICS_INTERVAL", 60)) + +# URL for posting metrics to Roboflow API, default is "{API_BASE_URL}/inference-stats" +METRICS_URL = os.getenv("METRICS_URL", f"{API_BASE_URL}/inference-stats") + +# Model cache directory, default is "/tmp/cache" +MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "/tmp/cache") + +# Model ID, default is None +MODEL_ID = os.getenv("MODEL_ID") + +# Enable jupyter notebook server route, default is False +NOTEBOOK_ENABLED = str2bool(os.getenv("NOTEBOOK_ENABLED", False)) + +# Jupyter notebook password, default is "roboflow" +NOTEBOOK_PASSWORD = os.getenv("NOTEBOOK_PASSWORD", "roboflow") + +# Jupyter notebook port, default is 9002 +NOTEBOOK_PORT = int(os.getenv("NOTEBOOK_PORT", 9002)) + +# Number of workers, default is 1 +NUM_WORKERS = int(os.getenv("NUM_WORKERS", 1)) + +ONNXRUNTIME_EXECUTION_PROVIDERS = os.getenv( + "ONNXRUNTIME_EXECUTION_PROVIDERS", "[CUDAExecutionProvider,CPUExecutionProvider]" +) + +# Port, default is 9001 +PORT = int(os.getenv("PORT", 9001)) + +# Profile flag, default is False +PROFILE = str2bool(os.getenv("PROFILE", False)) + +# Redis host, default is None +REDIS_HOST = os.getenv("REDIS_HOST", None) + +# Redis port, default is 6379 +REDIS_PORT = int(os.getenv("REDIS_PORT", 6379)) +REDIS_SSL = str2bool(os.getenv("REDIS_SSL", False)) +REDIS_TIMEOUT = float(os.getenv("REDIS_TIMEOUT", 2.0)) + +# Required ONNX providers, default is None +REQUIRED_ONNX_PROVIDERS = safe_split_value(os.getenv("REQUIRED_ONNX_PROVIDERS", None)) + +# Roboflow server UUID +ROBOFLOW_SERVER_UUID = os.getenv("ROBOFLOW_SERVER_UUID", str(uuid.uuid4())) + +# Roboflow service secret, default is None +ROBOFLOW_SERVICE_SECRET = os.getenv("ROBOFLOW_SERVICE_SECRET", None) + +# Maximum embedding cache size for SAM, default is 10 +SAM_MAX_EMBEDDING_CACHE_SIZE = int(os.getenv("SAM_MAX_EMBEDDING_CACHE_SIZE", 10)) + +# SAM version ID, default is "vit_h" +SAM_VERSION_ID = os.getenv("SAM_VERSION_ID", "vit_h") + + +# Device ID, default is "sample-device-id" +INFERENCE_SERVER_ID = os.getenv("INFERENCE_SERVER_ID", None) + +# Stream ID, default is None +STREAM_ID = os.getenv("STREAM_ID") +try: + STREAM_ID = int(STREAM_ID) +except (TypeError, ValueError): + pass + +# Tags used for device management +TAGS = safe_split_value(os.getenv("TAGS", "")) + +# TensorRT cache path, default is MODEL_CACHE_DIR +TENSORRT_CACHE_PATH = os.getenv("TENSORRT_CACHE_PATH", MODEL_CACHE_DIR) + +# Set TensorRT cache path +os.environ["ORT_TENSORRT_CACHE_PATH"] = TENSORRT_CACHE_PATH + +# Version check mode, one of "once" or "continuous", default is "once" +VERSION_CHECK_MODE = os.getenv("VERSION_CHECK_MODE", "once") + +# Metlo key, default is None +METLO_KEY = os.getenv("METLO_KEY", None) + +# Core model bucket +CORE_MODEL_BUCKET = os.getenv( + "CORE_MODEL_BUCKET", + ( + "roboflow-core-model-prod" + if PROJECT == "roboflow-platform" + else "roboflow-core-model-staging" + ), +) + +# Inference bucket +INFER_BUCKET = os.getenv( + "INFER_BUCKET", + ( + "roboflow-infer-prod" + if PROJECT == "roboflow-platform" + else "roboflow-infer-staging" + ), +) + +ACTIVE_LEARNING_ENABLED = str2bool(os.getenv("ACTIVE_LEARNING_ENABLED", True)) +ACTIVE_LEARNING_TAGS = safe_split_value(os.getenv("ACTIVE_LEARNING_TAGS", None)) + +# Number inflight async tasks for async model manager +NUM_PARALLEL_TASKS = int(os.getenv("NUM_PARALLEL_TASKS", 512)) +STUB_CACHE_SIZE = int(os.getenv("STUB_CACHE_SIZE", 256)) +# New stream interface variables +PREDICTIONS_QUEUE_SIZE = int( + os.getenv("INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE", 512) +) +RESTART_ATTEMPT_DELAY = int(os.getenv("INFERENCE_PIPELINE_RESTART_ATTEMPT_DELAY", 1)) +DEFAULT_BUFFER_SIZE = int(os.getenv("VIDEO_SOURCE_BUFFER_SIZE", "64")) +DEFAULT_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE = float( + os.getenv("VIDEO_SOURCE_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE", "0.1") +) +DEFAULT_ADAPTIVE_MODE_READER_PACE_TOLERANCE = float( + os.getenv("VIDEO_SOURCE_ADAPTIVE_MODE_READER_PACE_TOLERANCE", "5.0") +) +DEFAULT_MINIMUM_ADAPTIVE_MODE_SAMPLES = int( + os.getenv("VIDEO_SOURCE_MINIMUM_ADAPTIVE_MODE_SAMPLES", "10") +) +DEFAULT_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW = int( + os.getenv("VIDEO_SOURCE_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW", "16") +) + +NUM_CELERY_WORKERS = os.getenv("NUM_CELERY_WORKERS", 4) +CELERY_LOG_LEVEL = os.getenv("CELERY_LOG_LEVEL", "WARNING") + + +LOCAL_INFERENCE_API_URL = os.getenv("LOCAL_INFERENCE_API_URL", "http://127.0.0.1:9001") +HOSTED_DETECT_URL = ( + "https://detect.roboflow.com" + if PROJECT == "roboflow-platform" + else "https://lambda-object-detection.staging.roboflow.com" +) +HOSTED_INSTANCE_SEGMENTATION_URL = ( + "https://outline.roboflow.com" + if PROJECT == "roboflow-platform" + else "https://lambda-instance-segmentation.staging.roboflow.com" +) +HOSTED_CLASSIFICATION_URL = ( + "https://classify.roboflow.com" + if PROJECT == "roboflow-platform" + else "https://lambda-classification.staging.roboflow.com" +) +HOSTED_CORE_MODEL_URL = ( + "https://infer.roboflow.com" + if PROJECT == "roboflow-platform" + else "https://3hkaykeh3j.execute-api.us-east-1.amazonaws.com" +) + +DISABLE_WORKFLOW_ENDPOINTS = str2bool(os.getenv("DISABLE_WORKFLOW_ENDPOINTS", False)) +WORKFLOWS_STEP_EXECUTION_MODE = os.getenv("WORKFLOWS_STEP_EXECUTION_MODE", "remote") +WORKFLOWS_REMOTE_API_TARGET = os.getenv("WORKFLOWS_REMOTE_API_TARGET", "hosted") +WORKFLOWS_MAX_CONCURRENT_STEPS = int(os.getenv("WORKFLOWS_MAX_CONCURRENT_STEPS", "8")) +WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE = int( + os.getenv("WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE", "1") +) +WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS = int( + os.getenv("WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS", "8") +) diff --git a/inference/core/exceptions.py b/inference/core/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..2d71ffe9cd7125c6c9bbb14e8bb4ed8ef9130b32 --- /dev/null +++ b/inference/core/exceptions.py @@ -0,0 +1,182 @@ +class ContentTypeInvalid(Exception): + """Raised when the content type is invalid. + + Attributes: + message (str): Optional message describing the error. + """ + + +class ContentTypeMissing(Exception): + """Raised when the content type is missing. + + Attributes: + message (str): Optional message describing the error. + """ + + +class EngineIgnitionFailure(Exception): + """Raised when the engine fails to ignite. + + Attributes: + message (str): Optional message describing the error. + """ + + +class InferenceModelNotFound(Exception): + """Raised when the inference model is not found. + + Attributes: + message (str): Optional message describing the error. + """ + + +class InvalidEnvironmentVariableError(Exception): + """Raised when an environment variable is invalid. + + Attributes: + message (str): Optional message describing the error. + """ + + +class InvalidMaskDecodeArgument(Exception): + """Raised when an invalid argument is provided for mask decoding. + + Attributes: + message (str): Optional message describing the error. + """ + + +class MissingApiKeyError(Exception): + """Raised when the API key is missing. + + Attributes: + message (str): Optional message describing the error. + """ + + +class MissingServiceSecretError(Exception): + """Raised when the service secret is missing. + + Attributes: + message (str): Optional message describing the error. + """ + + +class OnnxProviderNotAvailable(Exception): + """Raised when the ONNX provider is not available. + + Attributes: + message (str): Optional message describing the error. + """ + + +class WorkspaceLoadError(Exception): + """Raised when there is an error loading the workspace. + + Attributes: + message (str): Optional message describing the error. + """ + + +class InputImageLoadError(Exception): + pass + + +class InvalidNumpyInput(InputImageLoadError): + """Raised when the input is an invalid NumPy array. + + Attributes: + message (str): Optional message describing the error. + """ + + +class InvalidImageTypeDeclared(InputImageLoadError): + pass + + +class InputFormatInferenceFailed(InputImageLoadError): + pass + + +class PreProcessingError(Exception): + pass + + +class PostProcessingError(Exception): + pass + + +class InvalidModelIDError(Exception): + pass + + +class MalformedRoboflowAPIResponseError(Exception): + pass + + +class ServiceConfigurationError(Exception): + pass + + +class MissingDefaultModelError(ServiceConfigurationError): + pass + + +class ModelNotRecognisedError(ServiceConfigurationError): + pass + + +class RoboflowAPIRequestError(Exception): + pass + + +class RoboflowAPIUnsuccessfulRequestError(RoboflowAPIRequestError): + pass + + +class RoboflowAPINotAuthorizedError(RoboflowAPIUnsuccessfulRequestError): + pass + + +class RoboflowAPINotNotFoundError(RoboflowAPIUnsuccessfulRequestError): + pass + + +class RoboflowAPIConnectionError(RoboflowAPIRequestError): + pass + + +class RoboflowAPIImageUploadRejectionError(RoboflowAPIRequestError): + pass + + +class RoboflowAPIIAnnotationRejectionError(RoboflowAPIRequestError): + pass + + +class MalformedWorkflowResponseError(RoboflowAPIRequestError): + pass + + +class RoboflowAPIIAlreadyAnnotatedError(RoboflowAPIIAnnotationRejectionError): + pass + + +class ModelArtefactError(Exception): + pass + + +class ActiveLearningError(Exception): + pass + + +class PredictionFormatNotSupported(ActiveLearningError): + pass + + +class ActiveLearningConfigurationDecodingError(ActiveLearningError): + pass + + +class ActiveLearningConfigurationError(ActiveLearningError): + pass diff --git a/inference/core/interfaces/__init__.py b/inference/core/interfaces/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/interfaces/__pycache__/__init__.cpython-310.pyc b/inference/core/interfaces/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a5303ca13991774c9d6f35f8811805ba545ab76 Binary files /dev/null and b/inference/core/interfaces/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/interfaces/__pycache__/base.cpython-310.pyc b/inference/core/interfaces/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08759002b41054bde751a38dce0af6152d5cf733 Binary files /dev/null and b/inference/core/interfaces/__pycache__/base.cpython-310.pyc differ diff --git a/inference/core/interfaces/base.py b/inference/core/interfaces/base.py new file mode 100644 index 0000000000000000000000000000000000000000..cc298c1947a28ac0bccb083a883198f9ffffb8bb --- /dev/null +++ b/inference/core/interfaces/base.py @@ -0,0 +1,8 @@ +from inference.core.managers.base import ModelManager + + +class BaseInterface: + """Base interface class which accepts a model manager on initialization""" + + def __init__(self, model_manager: ModelManager) -> None: + self.model_manager = model_manager diff --git a/inference/core/interfaces/camera/__init__.py b/inference/core/interfaces/camera/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/interfaces/camera/__pycache__/__init__.cpython-310.pyc b/inference/core/interfaces/camera/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21c807940698ed0f46e0569909c55b354f00818a Binary files /dev/null and b/inference/core/interfaces/camera/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/interfaces/camera/__pycache__/camera.cpython-310.pyc b/inference/core/interfaces/camera/__pycache__/camera.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f224c0cf3aefd28b35167f3959c2978c20853de Binary files /dev/null and b/inference/core/interfaces/camera/__pycache__/camera.cpython-310.pyc differ diff --git a/inference/core/interfaces/camera/__pycache__/entities.cpython-310.pyc b/inference/core/interfaces/camera/__pycache__/entities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6bc0e61f47acea3d84dcdcd4c5cd3b1803f1765 Binary files /dev/null and b/inference/core/interfaces/camera/__pycache__/entities.cpython-310.pyc differ diff --git a/inference/core/interfaces/camera/__pycache__/exceptions.cpython-310.pyc b/inference/core/interfaces/camera/__pycache__/exceptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27f36d45c6352183dd779d9aec90e649d5be5758 Binary files /dev/null and b/inference/core/interfaces/camera/__pycache__/exceptions.cpython-310.pyc differ diff --git a/inference/core/interfaces/camera/__pycache__/utils.cpython-310.pyc b/inference/core/interfaces/camera/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c13f2004903471afad9f91184fb14466f9e91360 Binary files /dev/null and b/inference/core/interfaces/camera/__pycache__/utils.cpython-310.pyc differ diff --git a/inference/core/interfaces/camera/__pycache__/video_source.cpython-310.pyc b/inference/core/interfaces/camera/__pycache__/video_source.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d6d2d5cea13fbc8dd0071a9fa858280dba193d5 Binary files /dev/null and b/inference/core/interfaces/camera/__pycache__/video_source.cpython-310.pyc differ diff --git a/inference/core/interfaces/camera/camera.py b/inference/core/interfaces/camera/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..b5fecbf162eda73a2427b05879843668c3902ecf --- /dev/null +++ b/inference/core/interfaces/camera/camera.py @@ -0,0 +1,137 @@ +import os +import time +from threading import Thread + +import cv2 +from PIL import Image + +from inference.core.logger import logger + + +class WebcamStream: + """Class to handle webcam streaming using a separate thread. + + Attributes: + stream_id (int): The ID of the webcam stream. + frame_id (int): A counter for the current frame. + vcap (VideoCapture): OpenCV video capture object. + width (int): The width of the video frame. + height (int): The height of the video frame. + fps_input_stream (int): Frames per second of the input stream. + grabbed (bool): A flag indicating if a frame was successfully grabbed. + frame (array): The current frame as a NumPy array. + pil_image (Image): The current frame as a PIL image. + stopped (bool): A flag indicating if the stream is stopped. + t (Thread): The thread used to update the stream. + """ + + def __init__(self, stream_id=0, enforce_fps=False): + """Initialize the webcam stream. + + Args: + stream_id (int, optional): The ID of the webcam stream. Defaults to 0. + """ + self.stream_id = stream_id + self.enforce_fps = enforce_fps + self.frame_id = 0 + self.vcap = cv2.VideoCapture(self.stream_id) + + for key in os.environ: + if key.startswith("CV2_CAP_PROP"): + opencv_prop = key[4:] + opencv_constant = getattr(cv2, opencv_prop, None) + if opencv_constant is not None: + value = int(os.getenv(key)) + self.vcap.set(opencv_constant, value) + logger.info(f"set {opencv_prop} to {value}") + else: + logger.warn(f"Property {opencv_prop} not found in cv2") + + self.width = int(self.vcap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.height = int(self.vcap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + self.file_mode = self.vcap.get(cv2.CAP_PROP_FRAME_COUNT) > 0 + if self.enforce_fps and not self.file_mode: + logger.warn( + "Ignoring enforce_fps flag for this stream. It is not compatible with streams and will cause the process to crash" + ) + self.enforce_fps = False + self.max_fps = None + if self.vcap.isOpened() is False: + logger.debug("[Exiting]: Error accessing webcam stream.") + exit(0) + self.fps_input_stream = int(self.vcap.get(cv2.CAP_PROP_FPS)) + logger.debug( + "FPS of webcam hardware/input stream: {}".format(self.fps_input_stream) + ) + self.grabbed, self.frame = self.vcap.read() + self.pil_image = Image.fromarray(cv2.cvtColor(self.frame, cv2.COLOR_BGR2RGB)) + if self.grabbed is False: + logger.debug("[Exiting] No more frames to read") + exit(0) + self.stopped = True + self.t = Thread(target=self.update, args=()) + self.t.daemon = True + + def start(self): + """Start the thread for reading frames.""" + self.stopped = False + self.t.start() + + def update(self): + """Update the frame by reading from the webcam.""" + frame_id = 0 + next_frame_time = 0 + t0 = time.perf_counter() + while True: + t1 = time.perf_counter() + if self.stopped is True: + break + + self.grabbed = self.vcap.grab() + if self.grabbed is False: + logger.debug("[Exiting] No more frames to read") + self.stopped = True + break + frame_id += 1 + # We can't retrieve each frame on nano and other lower powered devices quickly enough to keep up with the stream. + # By default, we will only retrieve frames when we'll be ready process them (determined by self.max_fps). + if t1 > next_frame_time: + ret, frame = self.vcap.retrieve() + if frame is None: + logger.debug("[Exiting] Frame not available for read") + self.stopped = True + break + logger.debug( + f"retrieved frame {frame_id}, effective FPS: {frame_id / (t1 - t0):.2f}" + ) + self.frame_id = frame_id + self.frame = frame + while self.file_mode and self.enforce_fps and self.max_fps is None: + # sleep until we have processed the first frame and we know what our FPS should be + time.sleep(0.01) + if self.max_fps is None: + self.max_fps = 30 + next_frame_time = t1 + (1 / self.max_fps) + 0.02 + if self.file_mode: + t2 = time.perf_counter() + if self.enforce_fps: + # when enforce_fps is true, grab video frames 1:1 with inference speed + time_to_sleep = next_frame_time - t2 + else: + # otherwise, grab at native FPS of the video file + time_to_sleep = (1 / self.fps_input_stream) - (t2 - t1) + if time_to_sleep > 0: + time.sleep(time_to_sleep) + self.vcap.release() + + def read_opencv(self): + """Read the current frame using OpenCV. + + Returns: + array, int: The current frame as a NumPy array, and the frame ID. + """ + return self.frame, self.frame_id + + def stop(self): + """Stop the webcam stream.""" + self.stopped = True diff --git a/inference/core/interfaces/camera/entities.py b/inference/core/interfaces/camera/entities.py new file mode 100644 index 0000000000000000000000000000000000000000..55844c6dae1acad9cc10a62038600e70a8728756 --- /dev/null +++ b/inference/core/interfaces/camera/entities.py @@ -0,0 +1,59 @@ +import logging +from dataclasses import dataclass +from datetime import datetime +from enum import Enum + +import numpy as np + +FrameTimestamp = datetime +FrameID = int + + +class UpdateSeverity(Enum): + """Enumeration for defining different levels of update severity. + + Attributes: + DEBUG (int): A debugging severity level. + INFO (int): An informational severity level. + WARNING (int): A warning severity level. + ERROR (int): An error severity level. + """ + + DEBUG = logging.DEBUG + INFO = logging.INFO + WARNING = logging.WARNING + ERROR = logging.ERROR + + +@dataclass(frozen=True) +class StatusUpdate: + """Represents a status update event in the system. + + Attributes: + timestamp (datetime): The timestamp when the status update was created. + severity (UpdateSeverity): The severity level of the update. + event_type (str): A string representing the type of the event. + payload (dict): A dictionary containing data relevant to the update. + context (str): A string providing additional context about the update. + """ + + timestamp: datetime + severity: UpdateSeverity + event_type: str + payload: dict + context: str + + +@dataclass(frozen=True) +class VideoFrame: + """Represents a single frame of video data. + + Attributes: + image (np.ndarray): The image data of the frame as a NumPy array. + frame_id (FrameID): A unique identifier for the frame. + frame_timestamp (FrameTimestamp): The timestamp when the frame was captured. + """ + + image: np.ndarray + frame_id: FrameID + frame_timestamp: FrameTimestamp diff --git a/inference/core/interfaces/camera/exceptions.py b/inference/core/interfaces/camera/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..f5d356e9a7f613a3a27679f67d5b5c0dd06970bb --- /dev/null +++ b/inference/core/interfaces/camera/exceptions.py @@ -0,0 +1,14 @@ +class StreamError(Exception): + pass + + +class StreamOperationNotAllowedError(StreamError): + pass + + +class EndOfStreamError(StreamError): + pass + + +class SourceConnectionError(StreamError): + pass diff --git a/inference/core/interfaces/camera/utils.py b/inference/core/interfaces/camera/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9882ac1f345c331d99d5d35cab97f3a83a7906d6 --- /dev/null +++ b/inference/core/interfaces/camera/utils.py @@ -0,0 +1,116 @@ +import time +from enum import Enum +from typing import Generator, Iterable, Optional, Tuple, Union + +import numpy as np + +from inference.core.interfaces.camera.entities import ( + FrameID, + FrameTimestamp, + VideoFrame, +) +from inference.core.interfaces.camera.video_source import SourceProperties, VideoSource + +MINIMAL_FPS = 0.01 + + +class FPSLimiterStrategy(Enum): + DROP = "drop" + WAIT = "wait" + + +def get_video_frames_generator( + video: Union[VideoSource, str, int], + max_fps: Optional[Union[float, int]] = None, + limiter_strategy: Optional[FPSLimiterStrategy] = None, +) -> Generator[VideoFrame, None, None]: + """ + Util function to create a frames generator from `VideoSource` with possibility to + limit FPS of consumed frames and dictate what to do if frames are produced to fast. + + Args: + video (Union[VideoSource, str, int]): Either instance of VideoSource or video reference accepted + by VideoSource.init(...) + max_fps (Optional[Union[float, int]]): value of maximum FPS rate of generated frames - can be used to limit + generation frequency + limiter_strategy (Optional[FPSLimiterStrategy]): strategy used to deal with frames decoding exceeding + limit of `max_fps`. By default - for files, in the interest of processing all frames - + generation will be awaited, for streams - frames will be dropped on the floor. + Returns: generator of `VideoFrame` + + Example: + ```python + for frame in get_video_frames_generator( + video="./some.mp4", + max_fps=50, + ): + pass + ``` + """ + if issubclass(type(video), str) or issubclass(type(video), int): + video = VideoSource.init( + video_reference=video, + ) + video.start() + if max_fps is None: + yield from video + return None + limiter_strategy = resolve_limiter_strategy( + explicitly_defined_strategy=limiter_strategy, + source_properties=video.describe_source().source_properties, + ) + yield from limit_frame_rate( + frames_generator=video, max_fps=max_fps, strategy=limiter_strategy + ) + + +def resolve_limiter_strategy( + explicitly_defined_strategy: Optional[FPSLimiterStrategy], + source_properties: Optional[SourceProperties], +) -> FPSLimiterStrategy: + if explicitly_defined_strategy is not None: + return explicitly_defined_strategy + limiter_strategy = FPSLimiterStrategy.DROP + if source_properties is not None and source_properties.is_file: + limiter_strategy = FPSLimiterStrategy.WAIT + return limiter_strategy + + +def limit_frame_rate( + frames_generator: Iterable[Tuple[FrameTimestamp, FrameID, np.ndarray]], + max_fps: Union[float, int], + strategy: FPSLimiterStrategy, +) -> Generator[Tuple[FrameTimestamp, FrameID, np.ndarray], None, None]: + rate_limiter = RateLimiter(desired_fps=max_fps) + for frame_data in frames_generator: + delay = rate_limiter.estimate_next_action_delay() + if delay <= 0.0: + rate_limiter.tick() + yield frame_data + continue + if strategy is FPSLimiterStrategy.WAIT: + time.sleep(delay) + rate_limiter.tick() + yield frame_data + + +class RateLimiter: + """ + Implements rate upper-bound rate limiting by ensuring estimate_next_tick_delay() + to be at min 1 / desired_fps, not letting the client obeying outcomes to exceed + assumed rate. + """ + + def __init__(self, desired_fps: Union[float, int]): + self._desired_fps = max(desired_fps, MINIMAL_FPS) + self._last_tick: Optional[float] = None + + def tick(self) -> None: + self._last_tick = time.monotonic() + + def estimate_next_action_delay(self) -> float: + if self._last_tick is None: + return 0.0 + desired_delay = 1 / self._desired_fps + time_since_last_tick = time.monotonic() - self._last_tick + return max(desired_delay - time_since_last_tick, 0.0) diff --git a/inference/core/interfaces/camera/video_source.py b/inference/core/interfaces/camera/video_source.py new file mode 100644 index 0000000000000000000000000000000000000000..f251e5ee89f9e2983c1b217fb1a27fcdbdfb2a1e --- /dev/null +++ b/inference/core/interfaces/camera/video_source.py @@ -0,0 +1,1006 @@ +import time +from dataclasses import dataclass +from datetime import datetime +from enum import Enum +from queue import Empty, Queue +from threading import Event, Lock, Thread +from typing import Any, Callable, List, Optional, Protocol, Union + +import cv2 +import supervision as sv + +from inference.core import logger +from inference.core.env import ( + DEFAULT_ADAPTIVE_MODE_READER_PACE_TOLERANCE, + DEFAULT_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE, + DEFAULT_BUFFER_SIZE, + DEFAULT_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW, + DEFAULT_MINIMUM_ADAPTIVE_MODE_SAMPLES, +) +from inference.core.interfaces.camera.entities import ( + StatusUpdate, + UpdateSeverity, + VideoFrame, +) +from inference.core.interfaces.camera.exceptions import ( + EndOfStreamError, + SourceConnectionError, + StreamOperationNotAllowedError, +) + +VIDEO_SOURCE_CONTEXT = "video_source" +VIDEO_CONSUMER_CONTEXT = "video_consumer" +SOURCE_STATE_UPDATE_EVENT = "SOURCE_STATE_UPDATE" +SOURCE_ERROR_EVENT = "SOURCE_ERROR" +FRAME_CAPTURED_EVENT = "FRAME_CAPTURED" +FRAME_DROPPED_EVENT = "FRAME_DROPPED" +FRAME_CONSUMED_EVENT = "FRAME_CONSUMED" +VIDEO_CONSUMPTION_STARTED_EVENT = "VIDEO_CONSUMPTION_STARTED" +VIDEO_CONSUMPTION_FINISHED_EVENT = "VIDEO_CONSUMPTION_FINISHED" + + +class StreamState(Enum): + NOT_STARTED = "NOT_STARTED" + INITIALISING = "INITIALISING" + RESTARTING = "RESTARTING" + RUNNING = "RUNNING" + PAUSED = "PAUSED" + MUTED = "MUTED" + TERMINATING = "TERMINATING" + ENDED = "ENDED" + ERROR = "ERROR" + + +START_ELIGIBLE_STATES = { + StreamState.NOT_STARTED, + StreamState.RESTARTING, + StreamState.ENDED, +} +PAUSE_ELIGIBLE_STATES = {StreamState.RUNNING} +MUTE_ELIGIBLE_STATES = {StreamState.RUNNING} +RESUME_ELIGIBLE_STATES = {StreamState.PAUSED, StreamState.MUTED} +TERMINATE_ELIGIBLE_STATES = { + StreamState.MUTED, + StreamState.RUNNING, + StreamState.PAUSED, + StreamState.RESTARTING, + StreamState.ENDED, + StreamState.ERROR, +} +RESTART_ELIGIBLE_STATES = { + StreamState.MUTED, + StreamState.RUNNING, + StreamState.PAUSED, + StreamState.ENDED, + StreamState.ERROR, +} + + +class BufferFillingStrategy(Enum): + WAIT = "WAIT" + DROP_OLDEST = "DROP_OLDEST" + ADAPTIVE_DROP_OLDEST = "ADAPTIVE_DROP_OLDEST" + DROP_LATEST = "DROP_LATEST" + ADAPTIVE_DROP_LATEST = "ADAPTIVE_DROP_LATEST" + + +ADAPTIVE_STRATEGIES = { + BufferFillingStrategy.ADAPTIVE_DROP_LATEST, + BufferFillingStrategy.ADAPTIVE_DROP_OLDEST, +} +DROP_OLDEST_STRATEGIES = { + BufferFillingStrategy.DROP_OLDEST, + BufferFillingStrategy.ADAPTIVE_DROP_OLDEST, +} + + +class BufferConsumptionStrategy(Enum): + LAZY = "LAZY" + EAGER = "EAGER" + + +@dataclass(frozen=True) +class SourceProperties: + width: int + height: int + total_frames: int + is_file: bool + fps: float + + +@dataclass(frozen=True) +class SourceMetadata: + source_properties: Optional[SourceProperties] + source_reference: str + buffer_size: int + state: StreamState + buffer_filling_strategy: Optional[BufferFillingStrategy] + buffer_consumption_strategy: Optional[BufferConsumptionStrategy] + + +class VideoSourceMethod(Protocol): + def __call__(self, video_source: "VideoSource", *args, **kwargs) -> None: ... + + +def lock_state_transition( + method: VideoSourceMethod, +) -> Callable[["VideoSource"], None]: + def locked_executor(video_source: "VideoSource", *args, **kwargs) -> None: + with video_source._state_change_lock: + return method(video_source, *args, **kwargs) + + return locked_executor + + +class VideoSource: + @classmethod + def init( + cls, + video_reference: Union[str, int], + buffer_size: int = DEFAULT_BUFFER_SIZE, + status_update_handlers: Optional[List[Callable[[StatusUpdate], None]]] = None, + buffer_filling_strategy: Optional[BufferFillingStrategy] = None, + buffer_consumption_strategy: Optional[BufferConsumptionStrategy] = None, + adaptive_mode_stream_pace_tolerance: float = DEFAULT_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE, + adaptive_mode_reader_pace_tolerance: float = DEFAULT_ADAPTIVE_MODE_READER_PACE_TOLERANCE, + minimum_adaptive_mode_samples: int = DEFAULT_MINIMUM_ADAPTIVE_MODE_SAMPLES, + maximum_adaptive_frames_dropped_in_row: int = DEFAULT_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW, + ): + """ + This class is meant to represent abstraction over video sources - both video files and + on-line streams that are possible to be consumed and used by other components of `inference` + library. + + Before digging into details of the class behaviour, it is advised to familiarise with the following + concepts and implementation assumptions: + + 1. Video file can be accessed from local (or remote) storage by the consumer in a pace dictated by + its processing capabilities. If processing is faster than the frame rate of video, operations + may be executed in a time shorter than the time of video playback. In the opposite case - consumer + may freely decode and process frames in its own pace, without risk for failures due to temporal + dependencies of processing - this is classical offline processing example. + 2. Video streams, on the other hand, usually need to be consumed in a pace near to their frame-rate - + in other words - this is on-line processing example. Consumer being faster than incoming stream + frames cannot utilise its resources to the full extent as not-yet-delivered data would be needed. + Slow consumer, however, may not be able to process everything on time and to keep up with the pace + of stream - some frames would need to be dropped. Otherwise - over time, consumer could go out of + sync with the stream causing decoding failures or unpredictable behavior. + + To fit those two types of video sources, `VideoSource` introduces the concept of buffered decoding of + video stream (like at the YouTube - player buffers some frames that are soon to be displayed). + The way on how buffer is filled and consumed dictates the behavior of `VideoSource`. + + Starting from `BufferFillingStrategy` - we have 3 basic options: + * WAIT: in case of slow video consumption, when buffer is full - `VideoSource` will wait for + the empty spot in buffer before next frame will be processed - this is suitable in cases when + we want to ensure EACH FRAME of the video to be processed + * DROP_OLDEST: when buffer is full, the frame that sits there for the longest time will be dropped - + this is suitable for cases when we want to process the most recent frames possible + * DROP_LATEST: when buffer is full, the newly decoded frame is dropped - useful in cases when + it is expected to have processing performance drops, but we would like to consume portions of + video that are locally smooth - but this is probably the least common use-case. + + On top of that - there are two ADAPTIVE strategies: ADAPTIVE_DROP_OLDEST and ADAPTIVE_DROP_LATEST, + which are equivalent to DROP_OLDEST and DROP_LATEST with adaptive decoding feature enabled. The notion + of that mode will be described later. + + Naturally, decoded frames must also be consumed. `VideoSource` provides a handy interface for reading + a video source frames by a SINGLE consumer. Consumption strategy can also be dictated via + `BufferConsumptionStrategy`: + * LAZY - consume all the frames from decoding buffer one-by-one + * EAGER - at each readout - take all frames already buffered, drop all of them apart from the most recent + + In consequence - there are various combinations of `BufferFillingStrategy` and `BufferConsumptionStrategy`. + The most popular would be: + * `BufferFillingStrategy.WAIT` and `BufferConsumptionStrategy.LAZY` - to always decode and process each and + every frame of the source (useful while processing video files - and default behaviour enforced by + `inference` if there is no explicit configuration) + * `BufferFillingStrategy.DROP_OLDEST` and `BufferConsumptionStrategy.EAGER` - to always process the most + recent frames of source (useful while processing video streams when low latency [real-time experience] + is required - ADAPTIVE version of this is default for streams) + + ADAPTIVE strategies were introduced to handle corner-cases, when consumer hardware is not capable to consume + video stream and process frames at the same time (for instance - Nvidia Jetson devices running processing + against hi-res streams with high FPS ratio). It acts with buffer in nearly the same way as `DROP_OLDEST` + and `DROP_LATEST` strategies, but there are two more conditions that may influence frame drop: + * announced rate of source - which in fact dictate the pace of frames grabbing from incoming stream that + MUST be met by consumer to avoid strange decoding issues causing decoder to fail - if the pace of frame grabbing + deviates too much - decoding will be postponed, and frames dropped to grab next ones sooner + * consumption rate - in resource constraints environment, not only decoding is problematic from the performance + perspective - but also heavy processing. If consumer is not quick enough - allocating more useful resources + for decoding frames that may never be processed is a waste. That's why - if decoding happens more frequently + than consumption of frame - ADAPTIVE mode causes decoding to be done in a slower pace and more frames are just + grabbed and dropped on the floor. + ADAPTIVE mode increases latency slightly, but may be the only way to operate in some cases. + Behaviour of adaptive mode, including the maximum acceptable deviations of frames grabbing pace from source, + reader pace and maximum number of consecutive frames dropped in ADAPTIVE mode are configurable by clients, + with reasonable defaults being set. + + `VideoSource` emits events regarding its activity - which can be intercepted by custom handlers. Take + into account that they are always executed in context of thread invoking them (and should be fast to complete, + otherwise may block the flow of stream consumption). All errors raised will be emitted as logger warnings only. + + `VideoSource` implementation is naturally multithreading, with different thread decoding video and different + one consuming it and manipulating source state. Implementation of user interface is thread-safe, although + stream it is meant to be consumed by a single thread only. + + ENV variables involved: + * VIDEO_SOURCE_BUFFER_SIZE - default: 64 + * VIDEO_SOURCE_ADAPTIVE_MODE_STREAM_PACE_TOLERANCE - default: 0.1 + * VIDEO_SOURCE_ADAPTIVE_MODE_READER_PACE_TOLERANCE - default: 5.0 + * VIDEO_SOURCE_MINIMUM_ADAPTIVE_MODE_SAMPLES - default: 10 + * VIDEO_SOURCE_MAXIMUM_ADAPTIVE_FRAMES_DROPPED_IN_ROW - default: 16 + + As an `inference` user, please use .init() method instead of constructor to instantiate objects. + + Args: + video_reference (Union[str, int]): Either str with file or stream reference, or int representing device ID + buffer_size (int): size of decoding buffer + status_update_handlers (Optional[List[Callable[[StatusUpdate], None]]]): List of handlers for status updates + buffer_filling_strategy (Optional[BufferFillingStrategy]): Settings for buffer filling strategy - if not + given - automatic choice regarding source type will be applied + buffer_consumption_strategy (Optional[BufferConsumptionStrategy]): Settings for buffer consumption strategy, + if not given - automatic choice regarding source type will be applied + adaptive_mode_stream_pace_tolerance (float): Maximum deviation between frames grabbing pace and stream pace + that will not trigger adaptive mode frame drop + adaptive_mode_reader_pace_tolerance (float): Maximum deviation between decoding pace and stream consumption + pace that will not trigger adaptive mode frame drop + minimum_adaptive_mode_samples (int): Minimal number of frames to be used to establish actual pace of + processing, before adaptive mode can drop any frame + maximum_adaptive_frames_dropped_in_row (int): Maximum number of frames dropped in row due to application of + adaptive strategy + + Returns: Instance of `VideoSource` class + """ + frames_buffer = Queue(maxsize=buffer_size) + if status_update_handlers is None: + status_update_handlers = [] + video_consumer = VideoConsumer.init( + buffer_filling_strategy=buffer_filling_strategy, + adaptive_mode_stream_pace_tolerance=adaptive_mode_stream_pace_tolerance, + adaptive_mode_reader_pace_tolerance=adaptive_mode_reader_pace_tolerance, + minimum_adaptive_mode_samples=minimum_adaptive_mode_samples, + maximum_adaptive_frames_dropped_in_row=maximum_adaptive_frames_dropped_in_row, + status_update_handlers=status_update_handlers, + ) + return cls( + stream_reference=video_reference, + frames_buffer=frames_buffer, + status_update_handlers=status_update_handlers, + buffer_consumption_strategy=buffer_consumption_strategy, + video_consumer=video_consumer, + ) + + def __init__( + self, + stream_reference: Union[str, int], + frames_buffer: Queue, + status_update_handlers: List[Callable[[StatusUpdate], None]], + buffer_consumption_strategy: Optional[BufferConsumptionStrategy], + video_consumer: "VideoConsumer", + ): + self._stream_reference = stream_reference + self._video: Optional[cv2.VideoCapture] = None + self._source_properties: Optional[SourceProperties] = None + self._frames_buffer = frames_buffer + self._status_update_handlers = status_update_handlers + self._buffer_consumption_strategy = buffer_consumption_strategy + self._video_consumer = video_consumer + self._state = StreamState.NOT_STARTED + self._playback_allowed = Event() + self._frames_buffering_allowed = True + self._stream_consumption_thread: Optional[Thread] = None + self._state_change_lock = Lock() + + @lock_state_transition + def restart(self, wait_on_frames_consumption: bool = True) -> None: + """ + Method to restart source consumption. Eligible to be used in states: + [MUTED, RUNNING, PAUSED, ENDED, ERROR]. + End state: + * INITIALISING - that should change into RUNNING once first frame is ready to be grabbed + * ERROR - if it was not possible to connect with source + + Thread safe - only one transition of states possible at the time. + + Args: + wait_on_frames_consumption (bool): Flag telling if all frames from buffer must be consumed before + completion of this operation. + + Returns: None + Throws: + * StreamOperationNotAllowedError: if executed in context of incorrect state of the source + * SourceConnectionError: if source cannot be connected + """ + if self._state not in RESTART_ELIGIBLE_STATES: + raise StreamOperationNotAllowedError( + f"Could not RESTART stream in state: {self._state}" + ) + self._restart(wait_on_frames_consumption=wait_on_frames_consumption) + + @lock_state_transition + def start(self) -> None: + """ + Method to be used to start source consumption. Eligible to be used in states: + [NOT_STARTED, ENDED, (RESTARTING - which is internal state only)] + End state: + * INITIALISING - that should change into RUNNING once first frame is ready to be grabbed + * ERROR - if it was not possible to connect with source + + Thread safe - only one transition of states possible at the time. + + Returns: None + Throws: + * StreamOperationNotAllowedError: if executed in context of incorrect state of the source + * SourceConnectionError: if source cannot be connected + """ + if self._state not in START_ELIGIBLE_STATES: + raise StreamOperationNotAllowedError( + f"Could not START stream in state: {self._state}" + ) + self._start() + + @lock_state_transition + def terminate(self, wait_on_frames_consumption: bool = True) -> None: + """ + Method to be used to terminate source consumption. Eligible to be used in states: + [MUTED, RUNNING, PAUSED, ENDED, ERROR, (RESTARTING - which is internal state only)] + End state: + * ENDED - indicating success of the process + * ERROR - if error with processing occurred + + Must be used to properly dispose resources at the end. + + Thread safe - only one transition of states possible at the time. + + Args: + wait_on_frames_consumption (bool): Flag telling if all frames from buffer must be consumed before + completion of this operation. + + Returns: None + Throws: + * StreamOperationNotAllowedError: if executed in context of incorrect state of the source + """ + if self._state not in TERMINATE_ELIGIBLE_STATES: + raise StreamOperationNotAllowedError( + f"Could not TERMINATE stream in state: {self._state}" + ) + self._terminate(wait_on_frames_consumption=wait_on_frames_consumption) + + @lock_state_transition + def pause(self) -> None: + """ + Method to be used to pause source consumption. During pause - no new frames are consumed. + Used on on-line streams for too long may cause stream disconnection. + Eligible to be used in states: + [RUNNING] + End state: + * PAUSED + + Thread safe - only one transition of states possible at the time. + + Returns: None + Throws: + * StreamOperationNotAllowedError: if executed in context of incorrect state of the source + """ + if self._state not in PAUSE_ELIGIBLE_STATES: + raise StreamOperationNotAllowedError( + f"Could not PAUSE stream in state: {self._state}" + ) + self._pause() + + @lock_state_transition + def mute(self) -> None: + """ + Method to be used to mute source consumption. Muting is an equivalent of pause for stream - where + frames grabbing is not put on hold, just new frames decoding and buffering is not allowed - causing + intermediate frames to be dropped. May be also used against files, although arguably less useful. + Eligible to be used in states: + [RUNNING] + End state: + * MUTED + + Thread safe - only one transition of states possible at the time. + + Returns: None + Throws: + * StreamOperationNotAllowedError: if executed in context of incorrect state of the source + """ + if self._state not in MUTE_ELIGIBLE_STATES: + raise StreamOperationNotAllowedError( + f"Could not MUTE stream in state: {self._state}" + ) + self._mute() + + @lock_state_transition + def resume(self) -> None: + """ + Method to recover from pause or mute into running state. + [PAUSED, MUTED] + End state: + * RUNNING + + Thread safe - only one transition of states possible at the time. + + Returns: None + Throws: + * StreamOperationNotAllowedError: if executed in context of incorrect state of the source + """ + if self._state not in RESUME_ELIGIBLE_STATES: + raise StreamOperationNotAllowedError( + f"Could not RESUME stream in state: {self._state}" + ) + self._resume() + + def get_state(self) -> StreamState: + """ + Method to get current state of the `VideoSource` + + Returns: StreamState + """ + return self._state + + def frame_ready(self) -> bool: + """ + Method to check if decoded frame is ready for consumer + + Returns: boolean flag indicating frame readiness + """ + return not self._frames_buffer.empty() + + def read_frame(self) -> VideoFrame: + """ + Method to be used by the consumer to get decoded source frame. + + Returns: VideoFrame object with decoded frame and its metadata. + Throws: + * EndOfStreamError: when trying to get the frame from closed source. + """ + if self._buffer_consumption_strategy is BufferConsumptionStrategy.EAGER: + video_frame: Optional[VideoFrame] = purge_queue( + queue=self._frames_buffer, + on_successful_read=self._video_consumer.notify_frame_consumed, + ) + else: + video_frame: Optional[VideoFrame] = self._frames_buffer.get() + self._frames_buffer.task_done() + self._video_consumer.notify_frame_consumed() + if video_frame is None: + raise EndOfStreamError( + "Attempted to retrieve frame from stream that already ended." + ) + send_video_source_status_update( + severity=UpdateSeverity.DEBUG, + event_type=FRAME_CONSUMED_EVENT, + payload={ + "frame_timestamp": video_frame.frame_timestamp, + "frame_id": video_frame.frame_id, + }, + status_update_handlers=self._status_update_handlers, + ) + return video_frame + + def describe_source(self) -> SourceMetadata: + return SourceMetadata( + source_properties=self._source_properties, + source_reference=self._stream_reference, + buffer_size=self._frames_buffer.maxsize, + state=self._state, + buffer_filling_strategy=self._video_consumer.buffer_filling_strategy, + buffer_consumption_strategy=self._buffer_consumption_strategy, + ) + + def _restart(self, wait_on_frames_consumption: bool = True) -> None: + self._terminate(wait_on_frames_consumption=wait_on_frames_consumption) + self._change_state(target_state=StreamState.RESTARTING) + self._playback_allowed = Event() + self._frames_buffering_allowed = True + self._video: Optional[cv2.VideoCapture] = None + self._source_properties: Optional[SourceProperties] = None + self._start() + + def _start(self) -> None: + self._change_state(target_state=StreamState.INITIALISING) + self._video = cv2.VideoCapture(self._stream_reference) + if not self._video.isOpened(): + self._change_state(target_state=StreamState.ERROR) + raise SourceConnectionError( + f"Cannot connect to video source under reference: {self._stream_reference}" + ) + self._source_properties = discover_source_properties(stream=self._video) + self._video_consumer.reset(source_properties=self._source_properties) + if self._source_properties.is_file: + self._set_file_mode_consumption_strategies() + else: + self._set_stream_mode_consumption_strategies() + self._playback_allowed.set() + self._stream_consumption_thread = Thread(target=self._consume_video) + self._stream_consumption_thread.start() + + def _terminate(self, wait_on_frames_consumption: bool) -> None: + if self._state in RESUME_ELIGIBLE_STATES: + self._resume() + previous_state = self._state + self._change_state(target_state=StreamState.TERMINATING) + if self._stream_consumption_thread is not None: + self._stream_consumption_thread.join() + if wait_on_frames_consumption: + self._frames_buffer.join() + if previous_state is not StreamState.ERROR: + self._change_state(target_state=StreamState.ENDED) + + def _pause(self) -> None: + self._playback_allowed.clear() + self._change_state(target_state=StreamState.PAUSED) + + def _mute(self) -> None: + self._frames_buffering_allowed = False + self._change_state(target_state=StreamState.MUTED) + + def _resume(self) -> None: + previous_state = self._state + self._change_state(target_state=StreamState.RUNNING) + if previous_state is StreamState.PAUSED: + self._video_consumer.reset_stream_consumption_pace() + self._playback_allowed.set() + if previous_state is StreamState.MUTED: + self._frames_buffering_allowed = True + + def _set_file_mode_consumption_strategies(self) -> None: + if self._buffer_consumption_strategy is None: + self._buffer_consumption_strategy = BufferConsumptionStrategy.LAZY + + def _set_stream_mode_consumption_strategies(self) -> None: + if self._buffer_consumption_strategy is None: + self._buffer_consumption_strategy = BufferConsumptionStrategy.EAGER + + def _consume_video(self) -> None: + send_video_source_status_update( + severity=UpdateSeverity.INFO, + event_type=VIDEO_CONSUMPTION_STARTED_EVENT, + status_update_handlers=self._status_update_handlers, + ) + logger.info(f"Video consumption started") + try: + self._change_state(target_state=StreamState.RUNNING) + declared_source_fps = None + if self._source_properties is not None: + declared_source_fps = self._source_properties.fps + while self._video.isOpened(): + if self._state is StreamState.TERMINATING: + break + self._playback_allowed.wait() + success = self._video_consumer.consume_frame( + video=self._video, + declared_source_fps=declared_source_fps, + buffer=self._frames_buffer, + frames_buffering_allowed=self._frames_buffering_allowed, + ) + if not success: + break + self._frames_buffer.put(None) + self._video.release() + self._change_state(target_state=StreamState.ENDED) + send_video_source_status_update( + severity=UpdateSeverity.INFO, + event_type=VIDEO_CONSUMPTION_FINISHED_EVENT, + status_update_handlers=self._status_update_handlers, + ) + logger.info(f"Video consumption finished") + except Exception as error: + self._change_state(target_state=StreamState.ERROR) + payload = { + "error_type": error.__class__.__name__, + "error_message": str(error), + "error_context": "stream_consumer_thread", + } + send_video_source_status_update( + severity=UpdateSeverity.ERROR, + event_type=SOURCE_ERROR_EVENT, + payload=payload, + status_update_handlers=self._status_update_handlers, + ) + logger.exception("Encountered error in video consumption thread") + + def _change_state(self, target_state: StreamState) -> None: + payload = { + "previous_state": self._state, + "new_state": target_state, + } + self._state = target_state + send_video_source_status_update( + severity=UpdateSeverity.INFO, + event_type=SOURCE_STATE_UPDATE_EVENT, + payload=payload, + status_update_handlers=self._status_update_handlers, + ) + + def __iter__(self) -> "VideoSource": + return self + + def __next__(self) -> VideoFrame: + """ + Method allowing to use `VideoSource` convenient to read frames + + Returns: VideoFrame + + Example: + ```python + source = VideoSource.init(video_reference="./some.mp4") + source.start() + + for frame in source: + pass + ``` + """ + try: + return self.read_frame() + except EndOfStreamError: + raise StopIteration() + + +class VideoConsumer: + """ + This class should be consumed as part of internal implementation. + It provides abstraction around stream consumption strategies. + """ + + @classmethod + def init( + cls, + buffer_filling_strategy: Optional[BufferFillingStrategy], + adaptive_mode_stream_pace_tolerance: float, + adaptive_mode_reader_pace_tolerance: float, + minimum_adaptive_mode_samples: int, + maximum_adaptive_frames_dropped_in_row: int, + status_update_handlers: List[Callable[[StatusUpdate], None]], + ) -> "VideoConsumer": + minimum_adaptive_mode_samples = max(minimum_adaptive_mode_samples, 2) + reader_pace_monitor = sv.FPSMonitor( + sample_size=10 * minimum_adaptive_mode_samples + ) + stream_consumption_pace_monitor = sv.FPSMonitor( + sample_size=10 * minimum_adaptive_mode_samples + ) + decoding_pace_monitor = sv.FPSMonitor( + sample_size=10 * minimum_adaptive_mode_samples + ) + return cls( + buffer_filling_strategy=buffer_filling_strategy, + adaptive_mode_stream_pace_tolerance=adaptive_mode_stream_pace_tolerance, + adaptive_mode_reader_pace_tolerance=adaptive_mode_reader_pace_tolerance, + minimum_adaptive_mode_samples=minimum_adaptive_mode_samples, + maximum_adaptive_frames_dropped_in_row=maximum_adaptive_frames_dropped_in_row, + status_update_handlers=status_update_handlers, + reader_pace_monitor=reader_pace_monitor, + stream_consumption_pace_monitor=stream_consumption_pace_monitor, + decoding_pace_monitor=decoding_pace_monitor, + ) + + def __init__( + self, + buffer_filling_strategy: Optional[BufferFillingStrategy], + adaptive_mode_stream_pace_tolerance: float, + adaptive_mode_reader_pace_tolerance: float, + minimum_adaptive_mode_samples: int, + maximum_adaptive_frames_dropped_in_row: int, + status_update_handlers: List[Callable[[StatusUpdate], None]], + reader_pace_monitor: sv.FPSMonitor, + stream_consumption_pace_monitor: sv.FPSMonitor, + decoding_pace_monitor: sv.FPSMonitor, + ): + self._buffer_filling_strategy = buffer_filling_strategy + self._frame_counter = 0 + self._adaptive_mode_stream_pace_tolerance = adaptive_mode_stream_pace_tolerance + self._adaptive_mode_reader_pace_tolerance = adaptive_mode_reader_pace_tolerance + self._minimum_adaptive_mode_samples = minimum_adaptive_mode_samples + self._maximum_adaptive_frames_dropped_in_row = ( + maximum_adaptive_frames_dropped_in_row + ) + self._adaptive_frames_dropped_in_row = 0 + self._reader_pace_monitor = reader_pace_monitor + self._stream_consumption_pace_monitor = stream_consumption_pace_monitor + self._decoding_pace_monitor = decoding_pace_monitor + self._status_update_handlers = status_update_handlers + + @property + def buffer_filling_strategy(self) -> Optional[BufferFillingStrategy]: + return self._buffer_filling_strategy + + def reset(self, source_properties: SourceProperties) -> None: + if source_properties.is_file: + self._set_file_mode_buffering_strategies() + else: + self._set_stream_mode_buffering_strategies() + self._reader_pace_monitor.reset() + self.reset_stream_consumption_pace() + self._decoding_pace_monitor.reset() + self._adaptive_frames_dropped_in_row = 0 + + def reset_stream_consumption_pace(self) -> None: + self._stream_consumption_pace_monitor.reset() + + def notify_frame_consumed(self) -> None: + self._reader_pace_monitor.tick() + + def consume_frame( + self, + video: cv2.VideoCapture, + declared_source_fps: Optional[float], + buffer: Queue, + frames_buffering_allowed: bool, + ) -> bool: + frame_timestamp = datetime.now() + success = video.grab() + self._stream_consumption_pace_monitor.tick() + if not success: + return False + self._frame_counter += 1 + send_video_source_status_update( + severity=UpdateSeverity.DEBUG, + event_type=FRAME_CAPTURED_EVENT, + payload={ + "frame_timestamp": frame_timestamp, + "frame_id": self._frame_counter, + }, + status_update_handlers=self._status_update_handlers, + ) + return self._consume_stream_frame( + video=video, + declared_source_fps=declared_source_fps, + frame_timestamp=frame_timestamp, + buffer=buffer, + frames_buffering_allowed=frames_buffering_allowed, + ) + + def _set_file_mode_buffering_strategies(self) -> None: + if self._buffer_filling_strategy is None: + self._buffer_filling_strategy = BufferFillingStrategy.WAIT + + def _set_stream_mode_buffering_strategies(self) -> None: + if self._buffer_filling_strategy is None: + self._buffer_filling_strategy = BufferFillingStrategy.ADAPTIVE_DROP_OLDEST + + def _consume_stream_frame( + self, + video: cv2.VideoCapture, + declared_source_fps: Optional[float], + frame_timestamp: datetime, + buffer: Queue, + frames_buffering_allowed: bool, + ) -> bool: + """ + Returns: boolean flag with success status + """ + if not frames_buffering_allowed: + send_frame_drop_update( + frame_timestamp=frame_timestamp, + frame_id=self._frame_counter, + cause="Buffering not allowed at the moment", + status_update_handlers=self._status_update_handlers, + ) + return True + if self._frame_should_be_adaptively_dropped( + declared_source_fps=declared_source_fps + ): + self._adaptive_frames_dropped_in_row += 1 + send_frame_drop_update( + frame_timestamp=frame_timestamp, + frame_id=self._frame_counter, + cause="ADAPTIVE strategy", + status_update_handlers=self._status_update_handlers, + ) + return True + self._adaptive_frames_dropped_in_row = 0 + if ( + not buffer.full() + or self._buffer_filling_strategy is BufferFillingStrategy.WAIT + ): + return decode_video_frame_to_buffer( + frame_timestamp=frame_timestamp, + frame_id=self._frame_counter, + video=video, + buffer=buffer, + decoding_pace_monitor=self._decoding_pace_monitor, + ) + if self._buffer_filling_strategy in DROP_OLDEST_STRATEGIES: + return self._process_stream_frame_dropping_oldest( + frame_timestamp=frame_timestamp, + video=video, + buffer=buffer, + ) + send_frame_drop_update( + frame_timestamp=frame_timestamp, + frame_id=self._frame_counter, + cause="DROP_LATEST strategy", + status_update_handlers=self._status_update_handlers, + ) + return True + + def _frame_should_be_adaptively_dropped( + self, declared_source_fps: Optional[float] + ) -> bool: + if self._buffer_filling_strategy not in ADAPTIVE_STRATEGIES: + return False + if ( + self._adaptive_frames_dropped_in_row + >= self._maximum_adaptive_frames_dropped_in_row + ): + return False + if ( + len(self._stream_consumption_pace_monitor.all_timestamps) + <= self._minimum_adaptive_mode_samples + ): + # not enough observations + return False + stream_consumption_pace = self._stream_consumption_pace_monitor() + announced_stream_fps = stream_consumption_pace + if declared_source_fps is not None and declared_source_fps > 0: + announced_stream_fps = declared_source_fps + if ( + announced_stream_fps - stream_consumption_pace + > self._adaptive_mode_stream_pace_tolerance + ): + # cannot keep up with stream emission + return True + if ( + len(self._reader_pace_monitor.all_timestamps) + <= self._minimum_adaptive_mode_samples + ) or ( + len(self._decoding_pace_monitor.all_timestamps) + <= self._minimum_adaptive_mode_samples + ): + # not enough observations + return False + actual_reader_pace = get_fps_if_tick_happens_now( + fps_monitor=self._reader_pace_monitor + ) + decoding_pace = self._decoding_pace_monitor() + if ( + decoding_pace - actual_reader_pace + > self._adaptive_mode_reader_pace_tolerance + ): + # we are too fast for the reader - time to save compute on decoding + return True + return False + + def _process_stream_frame_dropping_oldest( + self, + frame_timestamp: datetime, + video: cv2.VideoCapture, + buffer: Queue, + ) -> bool: + drop_single_frame_from_buffer( + buffer=buffer, + cause="DROP_OLDEST strategy", + status_update_handlers=self._status_update_handlers, + ) + return decode_video_frame_to_buffer( + frame_timestamp=frame_timestamp, + frame_id=self._frame_counter, + video=video, + buffer=buffer, + decoding_pace_monitor=self._decoding_pace_monitor, + ) + + +def discover_source_properties(stream: cv2.VideoCapture) -> SourceProperties: + width = int(stream.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(stream.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = stream.get(cv2.CAP_PROP_FPS) + total_frames = int(stream.get(cv2.CAP_PROP_FRAME_COUNT)) + return SourceProperties( + width=width, + height=height, + total_frames=total_frames, + is_file=total_frames > 0, + fps=fps, + ) + + +def purge_queue( + queue: Queue, + wait_on_empty: bool = True, + on_successful_read: Callable[[], None] = lambda: None, +) -> Optional[Any]: + result = None + if queue.empty() and wait_on_empty: + result = queue.get() + queue.task_done() + on_successful_read() + while not queue.empty(): + result = queue.get() + queue.task_done() + on_successful_read() + return result + + +def drop_single_frame_from_buffer( + buffer: Queue, + cause: str, + status_update_handlers: List[Callable[[StatusUpdate], None]], +) -> None: + try: + video_frame = buffer.get_nowait() + buffer.task_done() + send_frame_drop_update( + frame_timestamp=video_frame.frame_timestamp, + frame_id=video_frame.frame_id, + cause=cause, + status_update_handlers=status_update_handlers, + ) + except Empty: + # buffer may be emptied in the meantime, hence we ignore Empty + pass + + +def send_frame_drop_update( + frame_timestamp: datetime, + frame_id: int, + cause: str, + status_update_handlers: List[Callable[[StatusUpdate], None]], +) -> None: + send_video_source_status_update( + severity=UpdateSeverity.DEBUG, + event_type=FRAME_DROPPED_EVENT, + payload={ + "frame_timestamp": frame_timestamp, + "frame_id": frame_id, + "cause": cause, + }, + status_update_handlers=status_update_handlers, + sub_context=VIDEO_CONSUMER_CONTEXT, + ) + + +def send_video_source_status_update( + severity: UpdateSeverity, + event_type: str, + status_update_handlers: List[Callable[[StatusUpdate], None]], + sub_context: Optional[str] = None, + payload: Optional[dict] = None, +) -> None: + if payload is None: + payload = {} + context = VIDEO_SOURCE_CONTEXT + if sub_context is not None: + context = f"{context}.{sub_context}" + status_update = StatusUpdate( + timestamp=datetime.now(), + severity=severity, + event_type=event_type, + payload=payload, + context=context, + ) + for handler in status_update_handlers: + try: + handler(status_update) + except Exception as error: + logger.warning(f"Could not execute handler update. Cause: {error}") + + +def decode_video_frame_to_buffer( + frame_timestamp: datetime, + frame_id: int, + video: cv2.VideoCapture, + buffer: Queue, + decoding_pace_monitor: sv.FPSMonitor, +) -> bool: + success, image = video.retrieve() + if not success: + return False + decoding_pace_monitor.tick() + video_frame = VideoFrame( + image=image, frame_id=frame_id, frame_timestamp=frame_timestamp + ) + buffer.put(video_frame) + return True + + +def get_fps_if_tick_happens_now(fps_monitor: sv.FPSMonitor) -> float: + if len(fps_monitor.all_timestamps) == 0: + return 0.0 + min_reader_timestamp = fps_monitor.all_timestamps[0] + now = time.monotonic() + reader_taken_time = now - min_reader_timestamp + return (len(fps_monitor.all_timestamps) + 1) / reader_taken_time diff --git a/inference/core/interfaces/http/__init__.py b/inference/core/interfaces/http/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/interfaces/http/__pycache__/__init__.cpython-310.pyc b/inference/core/interfaces/http/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1661e667aa12ae3954536e0f802d1c166f293e2 Binary files /dev/null and b/inference/core/interfaces/http/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/interfaces/http/__pycache__/http_api.cpython-310.pyc b/inference/core/interfaces/http/__pycache__/http_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbe6ceac34288e6cd4650454e62d7b8140e81a53 Binary files /dev/null and b/inference/core/interfaces/http/__pycache__/http_api.cpython-310.pyc differ diff --git a/inference/core/interfaces/http/__pycache__/orjson_utils.cpython-310.pyc b/inference/core/interfaces/http/__pycache__/orjson_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ccd5fd7e43a5bb7712534f88f97fd83ef119c8f Binary files /dev/null and b/inference/core/interfaces/http/__pycache__/orjson_utils.cpython-310.pyc differ diff --git a/inference/core/interfaces/http/http_api.py b/inference/core/interfaces/http/http_api.py new file mode 100644 index 0000000000000000000000000000000000000000..0c44c1b65c622edfd7add0fb5e28adca3f5e2642 --- /dev/null +++ b/inference/core/interfaces/http/http_api.py @@ -0,0 +1,1456 @@ +import base64 +import traceback +from functools import partial, wraps +from time import sleep +from typing import Any, List, Optional, Union + +import uvicorn +from fastapi import BackgroundTasks, FastAPI, Path, Query, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, RedirectResponse, Response +from fastapi.staticfiles import StaticFiles +from fastapi_cprofile.profiler import CProfileMiddleware + +from inference.core import logger +from inference.core.cache import cache +from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID +from inference.core.entities.requests.clip import ( + ClipCompareRequest, + ClipImageEmbeddingRequest, + ClipTextEmbeddingRequest, +) +from inference.core.entities.requests.cogvlm import CogVLMInferenceRequest +from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest +from inference.core.entities.requests.gaze import GazeDetectionInferenceRequest +from inference.core.entities.requests.groundingdino import GroundingDINOInferenceRequest +from inference.core.entities.requests.inference import ( + ClassificationInferenceRequest, + InferenceRequest, + InferenceRequestImage, + InstanceSegmentationInferenceRequest, + KeypointsDetectionInferenceRequest, + ObjectDetectionInferenceRequest, +) +from inference.core.entities.requests.sam import ( + SamEmbeddingRequest, + SamSegmentationRequest, +) +from inference.core.entities.requests.server_state import ( + AddModelRequest, + ClearModelRequest, +) +from inference.core.entities.requests.workflows import ( + WorkflowInferenceRequest, + WorkflowSpecificationInferenceRequest, +) +from inference.core.entities.requests.yolo_world import YOLOWorldInferenceRequest +from inference.core.entities.responses.clip import ( + ClipCompareResponse, + ClipEmbeddingResponse, +) +from inference.core.entities.responses.cogvlm import CogVLMResponse +from inference.core.entities.responses.doctr import DoctrOCRInferenceResponse +from inference.core.entities.responses.gaze import GazeDetectionInferenceResponse +from inference.core.entities.responses.inference import ( + ClassificationInferenceResponse, + InferenceResponse, + InstanceSegmentationInferenceResponse, + KeypointsDetectionInferenceResponse, + MultiLabelClassificationInferenceResponse, + ObjectDetectionInferenceResponse, + StubResponse, +) +from inference.core.entities.responses.notebooks import NotebookStartResponse +from inference.core.entities.responses.sam import ( + SamEmbeddingResponse, + SamSegmentationResponse, +) +from inference.core.entities.responses.server_state import ( + ModelsDescriptions, + ServerVersionInfo, +) +from inference.core.entities.responses.workflows import WorkflowInferenceResponse +from inference.core.env import ( + ALLOW_ORIGINS, + CORE_MODEL_CLIP_ENABLED, + CORE_MODEL_COGVLM_ENABLED, + CORE_MODEL_DOCTR_ENABLED, + CORE_MODEL_GAZE_ENABLED, + CORE_MODEL_GROUNDINGDINO_ENABLED, + CORE_MODEL_SAM_ENABLED, + CORE_MODEL_YOLO_WORLD_ENABLED, + CORE_MODELS_ENABLED, + DISABLE_WORKFLOW_ENDPOINTS, + LAMBDA, + LEGACY_ROUTE_ENABLED, + METLO_KEY, + METRICS_ENABLED, + NOTEBOOK_ENABLED, + NOTEBOOK_PASSWORD, + NOTEBOOK_PORT, + PROFILE, + ROBOFLOW_SERVICE_SECRET, + WORKFLOWS_MAX_CONCURRENT_STEPS, + WORKFLOWS_STEP_EXECUTION_MODE, +) +from inference.core.exceptions import ( + ContentTypeInvalid, + ContentTypeMissing, + InferenceModelNotFound, + InputImageLoadError, + InvalidEnvironmentVariableError, + InvalidMaskDecodeArgument, + InvalidModelIDError, + MalformedRoboflowAPIResponseError, + MalformedWorkflowResponseError, + MissingApiKeyError, + MissingServiceSecretError, + ModelArtefactError, + OnnxProviderNotAvailable, + PostProcessingError, + PreProcessingError, + RoboflowAPIConnectionError, + RoboflowAPINotAuthorizedError, + RoboflowAPINotNotFoundError, + RoboflowAPIUnsuccessfulRequestError, + ServiceConfigurationError, + WorkspaceLoadError, +) +from inference.core.interfaces.base import BaseInterface +from inference.core.interfaces.http.orjson_utils import ( + orjson_response, + serialise_workflow_result, +) +from inference.core.managers.base import ModelManager +from inference.core.roboflow_api import ( + get_roboflow_workspace, + get_workflow_specification, +) +from inference.core.utils.notebooks import start_notebook +from inference.enterprise.workflows.complier.core import compile_and_execute_async +from inference.enterprise.workflows.complier.entities import StepExecutionMode +from inference.enterprise.workflows.complier.steps_executors.active_learning_middlewares import ( + WorkflowsActiveLearningMiddleware, +) +from inference.enterprise.workflows.errors import ( + ExecutionEngineError, + RuntimePayloadError, + WorkflowsCompilerError, +) +from inference.models.aliases import resolve_roboflow_model_alias + +if LAMBDA: + from inference.core.usage import trackUsage +if METLO_KEY: + from metlo.fastapi import ASGIMiddleware + +from inference.core.version import __version__ + + +def with_route_exceptions(route): + """ + A decorator that wraps a FastAPI route to handle specific exceptions. If an exception + is caught, it returns a JSON response with the error message. + + Args: + route (Callable): The FastAPI route to be wrapped. + + Returns: + Callable: The wrapped route. + """ + + @wraps(route) + async def wrapped_route(*args, **kwargs): + try: + return await route(*args, **kwargs) + except ( + ContentTypeInvalid, + ContentTypeMissing, + InputImageLoadError, + InvalidModelIDError, + InvalidMaskDecodeArgument, + MissingApiKeyError, + RuntimePayloadError, + ) as e: + resp = JSONResponse(status_code=400, content={"message": str(e)}) + traceback.print_exc() + except RoboflowAPINotAuthorizedError as e: + resp = JSONResponse(status_code=401, content={"message": str(e)}) + traceback.print_exc() + except (RoboflowAPINotNotFoundError, InferenceModelNotFound) as e: + resp = JSONResponse(status_code=404, content={"message": str(e)}) + traceback.print_exc() + except ( + InvalidEnvironmentVariableError, + MissingServiceSecretError, + WorkspaceLoadError, + PreProcessingError, + PostProcessingError, + ServiceConfigurationError, + ModelArtefactError, + MalformedWorkflowResponseError, + WorkflowsCompilerError, + ExecutionEngineError, + ) as e: + resp = JSONResponse(status_code=500, content={"message": str(e)}) + traceback.print_exc() + except OnnxProviderNotAvailable as e: + resp = JSONResponse(status_code=501, content={"message": str(e)}) + traceback.print_exc() + except ( + MalformedRoboflowAPIResponseError, + RoboflowAPIUnsuccessfulRequestError, + ) as e: + resp = JSONResponse(status_code=502, content={"message": str(e)}) + traceback.print_exc() + except RoboflowAPIConnectionError as e: + resp = JSONResponse(status_code=503, content={"message": str(e)}) + traceback.print_exc() + except Exception: + resp = JSONResponse(status_code=500, content={"message": "Internal error."}) + traceback.print_exc() + return resp + + return wrapped_route + + +class HttpInterface(BaseInterface): + """Roboflow defined HTTP interface for a general-purpose inference server. + + This class sets up the FastAPI application and adds necessary middleware, + as well as initializes the model manager and model registry for the inference server. + + Attributes: + app (FastAPI): The FastAPI application instance. + model_manager (ModelManager): The manager for handling different models. + """ + + def __init__( + self, + model_manager: ModelManager, + root_path: Optional[str] = None, + ): + """ + Initializes the HttpInterface with given model manager and model registry. + + Args: + model_manager (ModelManager): The manager for handling different models. + root_path (Optional[str]): The root path for the FastAPI application. + + Description: + Deploy Roboflow trained models to nearly any compute environment! + """ + description = "Roboflow inference server" + app = FastAPI( + title="Roboflow Inference Server", + description=description, + version=__version__, + terms_of_service="https://roboflow.com/terms", + contact={ + "name": "Roboflow Inc.", + "url": "https://roboflow.com/contact", + "email": "help@roboflow.com", + }, + license_info={ + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0.html", + }, + root_path=root_path, + ) + if METLO_KEY: + app.add_middleware( + ASGIMiddleware, host="https://app.metlo.com", api_key=METLO_KEY + ) + + if len(ALLOW_ORIGINS) > 0: + app.add_middleware( + CORSMiddleware, + allow_origins=ALLOW_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Optionally add middleware for profiling the FastAPI server and underlying inference API code + if PROFILE: + app.add_middleware( + CProfileMiddleware, + enable=True, + server_app=app, + filename="/profile/output.pstats", + strip_dirs=False, + sort_by="cumulative", + ) + + if METRICS_ENABLED: + + @app.middleware("http") + async def count_errors(request: Request, call_next): + """Middleware to count errors. + + Args: + request (Request): The incoming request. + call_next (Callable): The next middleware or endpoint to call. + + Returns: + Response: The response from the next middleware or endpoint. + """ + response = await call_next(request) + if response.status_code >= 400: + self.model_manager.num_errors += 1 + return response + + self.app = app + self.model_manager = model_manager + self.workflows_active_learning_middleware = WorkflowsActiveLearningMiddleware( + cache=cache, + ) + + async def process_inference_request( + inference_request: InferenceRequest, **kwargs + ) -> InferenceResponse: + """Processes an inference request by calling the appropriate model. + + Args: + inference_request (InferenceRequest): The request containing model ID and other inference details. + + Returns: + InferenceResponse: The response containing the inference results. + """ + de_aliased_model_id = resolve_roboflow_model_alias( + model_id=inference_request.model_id + ) + self.model_manager.add_model(de_aliased_model_id, inference_request.api_key) + resp = await self.model_manager.infer_from_request( + de_aliased_model_id, inference_request, **kwargs + ) + return orjson_response(resp) + + async def process_workflow_inference_request( + workflow_request: WorkflowInferenceRequest, + workflow_specification: dict, + background_tasks: Optional[BackgroundTasks], + ) -> WorkflowInferenceResponse: + step_execution_mode = StepExecutionMode(WORKFLOWS_STEP_EXECUTION_MODE) + result = await compile_and_execute_async( + workflow_specification=workflow_specification, + runtime_parameters=workflow_request.inputs, + model_manager=model_manager, + api_key=workflow_request.api_key, + max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS, + step_execution_mode=step_execution_mode, + active_learning_middleware=self.workflows_active_learning_middleware, + background_tasks=background_tasks, + ) + outputs = serialise_workflow_result( + result=result, + excluded_fields=workflow_request.excluded_fields, + ) + response = WorkflowInferenceResponse(outputs=outputs) + return orjson_response(response=response) + + def load_core_model( + inference_request: InferenceRequest, + api_key: Optional[str] = None, + core_model: str = None, + ) -> None: + """Loads a core model (e.g., "clip" or "sam") into the model manager. + + Args: + inference_request (InferenceRequest): The request containing version and other details. + api_key (Optional[str]): The API key for the request. + core_model (str): The core model type, e.g., "clip" or "sam". + + Returns: + str: The core model ID. + """ + if api_key: + inference_request.api_key = api_key + version_id_field = f"{core_model}_version_id" + core_model_id = ( + f"{core_model}/{inference_request.__getattribute__(version_id_field)}" + ) + self.model_manager.add_model(core_model_id, inference_request.api_key) + return core_model_id + + load_clip_model = partial(load_core_model, core_model="clip") + """Loads the CLIP model into the model manager. + + Args: + inference_request: The request containing version and other details. + api_key: The API key for the request. + + Returns: + The CLIP model ID. + """ + + load_sam_model = partial(load_core_model, core_model="sam") + """Loads the SAM model into the model manager. + + Args: + inference_request: The request containing version and other details. + api_key: The API key for the request. + + Returns: + The SAM model ID. + """ + + load_gaze_model = partial(load_core_model, core_model="gaze") + """Loads the GAZE model into the model manager. + + Args: + inference_request: The request containing version and other details. + api_key: The API key for the request. + + Returns: + The GAZE model ID. + """ + + load_doctr_model = partial(load_core_model, core_model="doctr") + """Loads the DocTR model into the model manager. + + Args: + inference_request: The request containing version and other details. + api_key: The API key for the request. + + Returns: + The DocTR model ID. + """ + load_cogvlm_model = partial(load_core_model, core_model="cogvlm") + + load_grounding_dino_model = partial( + load_core_model, core_model="grounding_dino" + ) + """Loads the Grounding DINO model into the model manager. + + Args: + inference_request: The request containing version and other details. + api_key: The API key for the request. + + Returns: + The Grounding DINO model ID. + """ + + load_yolo_world_model = partial(load_core_model, core_model="yolo_world") + """Loads the YOLO World model into the model manager. + + Args: + inference_request: The request containing version and other details. + api_key: The API key for the request. + + Returns: + The YOLO World model ID. + """ + + @app.get( + "/info", + response_model=ServerVersionInfo, + summary="Info", + description="Get the server name and version number", + ) + async def root(): + """Endpoint to get the server name and version number. + + Returns: + ServerVersionInfo: The server version information. + """ + return ServerVersionInfo( + name="Roboflow Inference Server", + version=__version__, + uuid=GLOBAL_INFERENCE_SERVER_ID, + ) + + # The current AWS Lambda authorizer only supports path parameters, therefore we can only use the legacy infer route. This case statement excludes routes which won't work for the current Lambda authorizer. + if not LAMBDA: + + @app.get( + "/model/registry", + response_model=ModelsDescriptions, + summary="Get model keys", + description="Get the ID of each loaded model", + ) + async def registry(): + """Get the ID of each loaded model in the registry. + + Returns: + ModelsDescriptions: The object containing models descriptions + """ + logger.debug(f"Reached /model/registry") + models_descriptions = self.model_manager.describe_models() + return ModelsDescriptions.from_models_descriptions( + models_descriptions=models_descriptions + ) + + @app.post( + "/model/add", + response_model=ModelsDescriptions, + summary="Load a model", + description="Load the model with the given model ID", + ) + @with_route_exceptions + async def model_add(request: AddModelRequest): + """Load the model with the given model ID into the model manager. + + Args: + request (AddModelRequest): The request containing the model ID and optional API key. + + Returns: + ModelsDescriptions: The object containing models descriptions + """ + logger.debug(f"Reached /model/add") + de_aliased_model_id = resolve_roboflow_model_alias( + model_id=request.model_id + ) + self.model_manager.add_model(de_aliased_model_id, request.api_key) + models_descriptions = self.model_manager.describe_models() + return ModelsDescriptions.from_models_descriptions( + models_descriptions=models_descriptions + ) + + @app.post( + "/model/remove", + response_model=ModelsDescriptions, + summary="Remove a model", + description="Remove the model with the given model ID", + ) + @with_route_exceptions + async def model_remove(request: ClearModelRequest): + """Remove the model with the given model ID from the model manager. + + Args: + request (ClearModelRequest): The request containing the model ID to be removed. + + Returns: + ModelsDescriptions: The object containing models descriptions + """ + logger.debug(f"Reached /model/remove") + de_aliased_model_id = resolve_roboflow_model_alias( + model_id=request.model_id + ) + self.model_manager.remove(de_aliased_model_id) + models_descriptions = self.model_manager.describe_models() + return ModelsDescriptions.from_models_descriptions( + models_descriptions=models_descriptions + ) + + @app.post( + "/model/clear", + response_model=ModelsDescriptions, + summary="Remove all models", + description="Remove all loaded models", + ) + @with_route_exceptions + async def model_clear(): + """Remove all loaded models from the model manager. + + Returns: + ModelsDescriptions: The object containing models descriptions + """ + logger.debug(f"Reached /model/clear") + self.model_manager.clear() + models_descriptions = self.model_manager.describe_models() + return ModelsDescriptions.from_models_descriptions( + models_descriptions=models_descriptions + ) + + @app.post( + "/infer/object_detection", + response_model=Union[ + ObjectDetectionInferenceResponse, + List[ObjectDetectionInferenceResponse], + StubResponse, + ], + summary="Object detection infer", + description="Run inference with the specified object detection model", + response_model_exclude_none=True, + ) + @with_route_exceptions + async def infer_object_detection( + inference_request: ObjectDetectionInferenceRequest, + background_tasks: BackgroundTasks, + ): + """Run inference with the specified object detection model. + + Args: + inference_request (ObjectDetectionInferenceRequest): The request containing the necessary details for object detection. + background_tasks: (BackgroundTasks) pool of fastapi background tasks + + Returns: + Union[ObjectDetectionInferenceResponse, List[ObjectDetectionInferenceResponse]]: The response containing the inference results. + """ + logger.debug(f"Reached /infer/object_detection") + return await process_inference_request( + inference_request, + active_learning_eligible=True, + background_tasks=background_tasks, + ) + + @app.post( + "/infer/instance_segmentation", + response_model=Union[ + InstanceSegmentationInferenceResponse, StubResponse + ], + summary="Instance segmentation infer", + description="Run inference with the specified instance segmentation model", + ) + @with_route_exceptions + async def infer_instance_segmentation( + inference_request: InstanceSegmentationInferenceRequest, + background_tasks: BackgroundTasks, + ): + """Run inference with the specified instance segmentation model. + + Args: + inference_request (InstanceSegmentationInferenceRequest): The request containing the necessary details for instance segmentation. + background_tasks: (BackgroundTasks) pool of fastapi background tasks + + Returns: + InstanceSegmentationInferenceResponse: The response containing the inference results. + """ + logger.debug(f"Reached /infer/instance_segmentation") + return await process_inference_request( + inference_request, + active_learning_eligible=True, + background_tasks=background_tasks, + ) + + @app.post( + "/infer/classification", + response_model=Union[ + ClassificationInferenceResponse, + MultiLabelClassificationInferenceResponse, + StubResponse, + ], + summary="Classification infer", + description="Run inference with the specified classification model", + ) + @with_route_exceptions + async def infer_classification( + inference_request: ClassificationInferenceRequest, + background_tasks: BackgroundTasks, + ): + """Run inference with the specified classification model. + + Args: + inference_request (ClassificationInferenceRequest): The request containing the necessary details for classification. + background_tasks: (BackgroundTasks) pool of fastapi background tasks + + Returns: + Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]: The response containing the inference results. + """ + logger.debug(f"Reached /infer/classification") + return await process_inference_request( + inference_request, + active_learning_eligible=True, + background_tasks=background_tasks, + ) + + @app.post( + "/infer/keypoints_detection", + response_model=Union[KeypointsDetectionInferenceResponse, StubResponse], + summary="Keypoints detection infer", + description="Run inference with the specified keypoints detection model", + ) + @with_route_exceptions + async def infer_keypoints( + inference_request: KeypointsDetectionInferenceRequest, + ): + """Run inference with the specified keypoints detection model. + + Args: + inference_request (KeypointsDetectionInferenceRequest): The request containing the necessary details for keypoints detection. + + Returns: + Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]: The response containing the inference results. + """ + logger.debug(f"Reached /infer/keypoints_detection") + return await process_inference_request(inference_request) + + if not DISABLE_WORKFLOW_ENDPOINTS: + + @app.post( + "/infer/workflows/{workspace_name}/{workflow_name}", + response_model=WorkflowInferenceResponse, + summary="Endpoint to trigger inference from predefined workflow", + description="Checks Roboflow API for workflow definition, once acquired - parses and executes injecting runtime parameters from request body", + ) + @with_route_exceptions + async def infer_from_predefined_workflow( + workspace_name: str, + workflow_name: str, + workflow_request: WorkflowInferenceRequest, + background_tasks: BackgroundTasks, + ) -> WorkflowInferenceResponse: + workflow_specification = get_workflow_specification( + api_key=workflow_request.api_key, + workspace_id=workspace_name, + workflow_name=workflow_name, + ) + return await process_workflow_inference_request( + workflow_request=workflow_request, + workflow_specification=workflow_specification, + background_tasks=background_tasks if not LAMBDA else None, + ) + + @app.post( + "/infer/workflows", + response_model=WorkflowInferenceResponse, + summary="Endpoint to trigger inference from workflow specification provided in payload", + description="Parses and executes workflow specification, injecting runtime parameters from request body", + ) + @with_route_exceptions + async def infer_from_workflow( + workflow_request: WorkflowSpecificationInferenceRequest, + background_tasks: BackgroundTasks, + ) -> WorkflowInferenceResponse: + workflow_specification = { + "specification": workflow_request.specification + } + return await process_workflow_inference_request( + workflow_request=workflow_request, + workflow_specification=workflow_specification, + background_tasks=background_tasks if not LAMBDA else None, + ) + + if CORE_MODELS_ENABLED: + if CORE_MODEL_CLIP_ENABLED: + + @app.post( + "/clip/embed_image", + response_model=ClipEmbeddingResponse, + summary="CLIP Image Embeddings", + description="Run the Open AI CLIP model to embed image data.", + ) + @with_route_exceptions + async def clip_embed_image( + inference_request: ClipImageEmbeddingRequest, + request: Request, + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + ): + """ + Embeds image data using the OpenAI CLIP model. + + Args: + inference_request (ClipImageEmbeddingRequest): The request containing the image to be embedded. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + request (Request, default Body()): The HTTP request. + + Returns: + ClipEmbeddingResponse: The response containing the embedded image. + """ + logger.debug(f"Reached /clip/embed_image") + clip_model_id = load_clip_model(inference_request, api_key=api_key) + response = await self.model_manager.infer_from_request( + clip_model_id, inference_request + ) + if LAMBDA: + actor = request.scope["aws.event"]["requestContext"][ + "authorizer" + ]["lambda"]["actor"] + trackUsage(clip_model_id, actor) + return response + + @app.post( + "/clip/embed_text", + response_model=ClipEmbeddingResponse, + summary="CLIP Text Embeddings", + description="Run the Open AI CLIP model to embed text data.", + ) + @with_route_exceptions + async def clip_embed_text( + inference_request: ClipTextEmbeddingRequest, + request: Request, + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + ): + """ + Embeds text data using the OpenAI CLIP model. + + Args: + inference_request (ClipTextEmbeddingRequest): The request containing the text to be embedded. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + request (Request, default Body()): The HTTP request. + + Returns: + ClipEmbeddingResponse: The response containing the embedded text. + """ + logger.debug(f"Reached /clip/embed_text") + clip_model_id = load_clip_model(inference_request, api_key=api_key) + response = await self.model_manager.infer_from_request( + clip_model_id, inference_request + ) + if LAMBDA: + actor = request.scope["aws.event"]["requestContext"][ + "authorizer" + ]["lambda"]["actor"] + trackUsage(clip_model_id, actor) + return response + + @app.post( + "/clip/compare", + response_model=ClipCompareResponse, + summary="CLIP Compare", + description="Run the Open AI CLIP model to compute similarity scores.", + ) + @with_route_exceptions + async def clip_compare( + inference_request: ClipCompareRequest, + request: Request, + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + ): + """ + Computes similarity scores using the OpenAI CLIP model. + + Args: + inference_request (ClipCompareRequest): The request containing the data to be compared. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + request (Request, default Body()): The HTTP request. + + Returns: + ClipCompareResponse: The response containing the similarity scores. + """ + logger.debug(f"Reached /clip/compare") + clip_model_id = load_clip_model(inference_request, api_key=api_key) + response = await self.model_manager.infer_from_request( + clip_model_id, inference_request + ) + if LAMBDA: + actor = request.scope["aws.event"]["requestContext"][ + "authorizer" + ]["lambda"]["actor"] + trackUsage(clip_model_id, actor, n=2) + return response + + if CORE_MODEL_GROUNDINGDINO_ENABLED: + + @app.post( + "/grounding_dino/infer", + response_model=ObjectDetectionInferenceResponse, + summary="Grounding DINO inference.", + description="Run the Grounding DINO zero-shot object detection model.", + ) + @with_route_exceptions + async def grounding_dino_infer( + inference_request: GroundingDINOInferenceRequest, + request: Request, + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + ): + """ + Embeds image data using the Grounding DINO model. + + Args: + inference_request GroundingDINOInferenceRequest): The request containing the image on which to run object detection. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + request (Request, default Body()): The HTTP request. + + Returns: + ObjectDetectionInferenceResponse: The object detection response. + """ + logger.debug(f"Reached /grounding_dino/infer") + grounding_dino_model_id = load_grounding_dino_model( + inference_request, api_key=api_key + ) + response = await self.model_manager.infer_from_request( + grounding_dino_model_id, inference_request + ) + if LAMBDA: + actor = request.scope["aws.event"]["requestContext"][ + "authorizer" + ]["lambda"]["actor"] + trackUsage(grounding_dino_model_id, actor) + return response + + if CORE_MODEL_YOLO_WORLD_ENABLED: + + @app.post( + "/yolo_world/infer", + response_model=ObjectDetectionInferenceResponse, + summary="YOLO-World inference.", + description="Run the YOLO-World zero-shot object detection model.", + response_model_exclude_none=True, + ) + @with_route_exceptions + async def yolo_world_infer( + inference_request: YOLOWorldInferenceRequest, + request: Request, + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + ): + """ + Runs the YOLO-World zero-shot object detection model. + + Args: + inference_request (YOLOWorldInferenceRequest): The request containing the image on which to run object detection. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + request (Request, default Body()): The HTTP request. + + Returns: + ObjectDetectionInferenceResponse: The object detection response. + """ + logger.debug(f"Reached /yolo_world/infer") + yolo_world_model_id = load_yolo_world_model( + inference_request, api_key=api_key + ) + response = await self.model_manager.infer_from_request( + yolo_world_model_id, inference_request + ) + if LAMBDA: + actor = request.scope["aws.event"]["requestContext"][ + "authorizer" + ]["lambda"]["actor"] + trackUsage(yolo_world_model_id, actor) + return response + + if CORE_MODEL_DOCTR_ENABLED: + + @app.post( + "/doctr/ocr", + response_model=DoctrOCRInferenceResponse, + summary="DocTR OCR response", + description="Run the DocTR OCR model to retrieve text in an image.", + ) + @with_route_exceptions + async def doctr_retrieve_text( + inference_request: DoctrOCRInferenceRequest, + request: Request, + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + ): + """ + Embeds image data using the DocTR model. + + Args: + inference_request (M.DoctrOCRInferenceRequest): The request containing the image from which to retrieve text. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + request (Request, default Body()): The HTTP request. + + Returns: + M.DoctrOCRInferenceResponse: The response containing the embedded image. + """ + logger.debug(f"Reached /doctr/ocr") + doctr_model_id = load_doctr_model( + inference_request, api_key=api_key + ) + response = await self.model_manager.infer_from_request( + doctr_model_id, inference_request + ) + if LAMBDA: + actor = request.scope["aws.event"]["requestContext"][ + "authorizer" + ]["lambda"]["actor"] + trackUsage(doctr_model_id, actor) + return response + + if CORE_MODEL_SAM_ENABLED: + + @app.post( + "/sam/embed_image", + response_model=SamEmbeddingResponse, + summary="SAM Image Embeddings", + description="Run the Meta AI Segmant Anything Model to embed image data.", + ) + @with_route_exceptions + async def sam_embed_image( + inference_request: SamEmbeddingRequest, + request: Request, + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + ): + """ + Embeds image data using the Meta AI Segmant Anything Model (SAM). + + Args: + inference_request (SamEmbeddingRequest): The request containing the image to be embedded. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + request (Request, default Body()): The HTTP request. + + Returns: + M.SamEmbeddingResponse or Response: The response containing the embedded image. + """ + logger.debug(f"Reached /sam/embed_image") + sam_model_id = load_sam_model(inference_request, api_key=api_key) + model_response = await self.model_manager.infer_from_request( + sam_model_id, inference_request + ) + if LAMBDA: + actor = request.scope["aws.event"]["requestContext"][ + "authorizer" + ]["lambda"]["actor"] + trackUsage(sam_model_id, actor) + if inference_request.format == "binary": + return Response( + content=model_response.embeddings, + headers={"Content-Type": "application/octet-stream"}, + ) + return model_response + + @app.post( + "/sam/segment_image", + response_model=SamSegmentationResponse, + summary="SAM Image Segmentation", + description="Run the Meta AI Segmant Anything Model to generate segmenations for image data.", + ) + @with_route_exceptions + async def sam_segment_image( + inference_request: SamSegmentationRequest, + request: Request, + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + ): + """ + Generates segmentations for image data using the Meta AI Segmant Anything Model (SAM). + + Args: + inference_request (SamSegmentationRequest): The request containing the image to be segmented. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + request (Request, default Body()): The HTTP request. + + Returns: + M.SamSegmentationResponse or Response: The response containing the segmented image. + """ + logger.debug(f"Reached /sam/segment_image") + sam_model_id = load_sam_model(inference_request, api_key=api_key) + model_response = await self.model_manager.infer_from_request( + sam_model_id, inference_request + ) + if LAMBDA: + actor = request.scope["aws.event"]["requestContext"][ + "authorizer" + ]["lambda"]["actor"] + trackUsage(sam_model_id, actor) + if inference_request.format == "binary": + return Response( + content=model_response, + headers={"Content-Type": "application/octet-stream"}, + ) + return model_response + + if CORE_MODEL_GAZE_ENABLED: + + @app.post( + "/gaze/gaze_detection", + response_model=List[GazeDetectionInferenceResponse], + summary="Gaze Detection", + description="Run the gaze detection model to detect gaze.", + ) + @with_route_exceptions + async def gaze_detection( + inference_request: GazeDetectionInferenceRequest, + request: Request, + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + ): + """ + Detect gaze using the gaze detection model. + + Args: + inference_request (M.GazeDetectionRequest): The request containing the image to be detected. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + request (Request, default Body()): The HTTP request. + + Returns: + M.GazeDetectionResponse: The response containing all the detected faces and the corresponding gazes. + """ + logger.debug(f"Reached /gaze/gaze_detection") + gaze_model_id = load_gaze_model(inference_request, api_key=api_key) + response = await self.model_manager.infer_from_request( + gaze_model_id, inference_request + ) + if LAMBDA: + actor = request.scope["aws.event"]["requestContext"][ + "authorizer" + ]["lambda"]["actor"] + trackUsage(gaze_model_id, actor) + return response + + if CORE_MODEL_COGVLM_ENABLED: + + @app.post( + "/llm/cogvlm", + response_model=CogVLMResponse, + summary="CogVLM", + description="Run the CogVLM model to chat or describe an image.", + ) + @with_route_exceptions + async def cog_vlm( + inference_request: CogVLMInferenceRequest, + request: Request, + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + ): + """ + Chat with CogVLM or ask it about an image. Multi-image requests not currently supported. + + Args: + inference_request (M.CogVLMInferenceRequest): The request containing the prompt and image to be described. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + request (Request, default Body()): The HTTP request. + + Returns: + M.CogVLMResponse: The model's text response + """ + logger.debug(f"Reached /llm/cogvlm") + cog_model_id = load_cogvlm_model(inference_request, api_key=api_key) + response = await self.model_manager.infer_from_request( + cog_model_id, inference_request + ) + if LAMBDA: + actor = request.scope["aws.event"]["requestContext"][ + "authorizer" + ]["lambda"]["actor"] + trackUsage(cog_model_id, actor) + return response + + if LEGACY_ROUTE_ENABLED: + # Legacy object detection inference path for backwards compatability + @app.post( + "/{dataset_id}/{version_id}", + # Order matters in this response model Union. It will use the first matching model. For example, Object Detection Inference Response is a subset of Instance segmentation inference response, so instance segmentation must come first in order for the matching logic to work. + response_model=Union[ + InstanceSegmentationInferenceResponse, + KeypointsDetectionInferenceResponse, + ObjectDetectionInferenceResponse, + ClassificationInferenceResponse, + MultiLabelClassificationInferenceResponse, + StubResponse, + Any, + ], + response_model_exclude_none=True, + ) + @with_route_exceptions + async def legacy_infer_from_request( + background_tasks: BackgroundTasks, + request: Request, + dataset_id: str = Path( + description="ID of a Roboflow dataset corresponding to the model to use for inference" + ), + version_id: str = Path( + description="ID of a Roboflow dataset version corresponding to the model to use for inference" + ), + api_key: Optional[str] = Query( + None, + description="Roboflow API Key that will be passed to the model during initialization for artifact retrieval", + ), + confidence: float = Query( + 0.4, + description="The confidence threshold used to filter out predictions", + ), + keypoint_confidence: float = Query( + 0.0, + description="The confidence threshold used to filter out keypoints that are not visible based on model confidence", + ), + format: str = Query( + "json", + description="One of 'json' or 'image'. If 'json' prediction data is return as a JSON string. If 'image' prediction data is visualized and overlayed on the original input image.", + ), + image: Optional[str] = Query( + None, + description="The publically accessible URL of an image to use for inference.", + ), + image_type: Optional[str] = Query( + "base64", + description="One of base64 or numpy. Note, numpy input is not supported for Roboflow Hosted Inference.", + ), + labels: Optional[bool] = Query( + False, + description="If true, labels will be include in any inference visualization.", + ), + mask_decode_mode: Optional[str] = Query( + "accurate", + description="One of 'accurate' or 'fast'. If 'accurate' the mask will be decoded using the original image size. If 'fast' the mask will be decoded using the original mask size. 'accurate' is slower but more accurate.", + ), + tradeoff_factor: Optional[float] = Query( + 0.0, + description="The amount to tradeoff between 0='fast' and 1='accurate'", + ), + max_detections: int = Query( + 300, + description="The maximum number of detections to return. This is used to limit the number of predictions returned by the model. The model may return more predictions than this number, but only the top `max_detections` predictions will be returned.", + ), + overlap: float = Query( + 0.3, + description="The IoU threhsold that must be met for a box pair to be considered duplicate during NMS", + ), + stroke: int = Query( + 1, description="The stroke width used when visualizing predictions" + ), + countinference: Optional[bool] = Query( + True, + description="If false, does not track inference against usage.", + include_in_schema=False, + ), + service_secret: Optional[str] = Query( + None, + description="Shared secret used to authenticate requests to the inference server from internal services (e.g. to allow disabling inference usage tracking via the `countinference` query parameter)", + include_in_schema=False, + ), + disable_preproc_auto_orient: Optional[bool] = Query( + False, description="If true, disables automatic image orientation" + ), + disable_preproc_contrast: Optional[bool] = Query( + False, description="If true, disables automatic contrast adjustment" + ), + disable_preproc_grayscale: Optional[bool] = Query( + False, + description="If true, disables automatic grayscale conversion", + ), + disable_preproc_static_crop: Optional[bool] = Query( + False, description="If true, disables automatic static crop" + ), + disable_active_learning: Optional[bool] = Query( + default=False, + description="If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled)", + ), + source: Optional[str] = Query( + "external", + description="The source of the inference request", + ), + source_info: Optional[str] = Query( + "external", + description="The detailed source information of the inference request", + ), + ): + """ + Legacy inference endpoint for object detection, instance segmentation, and classification. + + Args: + background_tasks: (BackgroundTasks) pool of fastapi background tasks + dataset_id (str): ID of a Roboflow dataset corresponding to the model to use for inference. + version_id (str): ID of a Roboflow dataset version corresponding to the model to use for inference. + api_key (Optional[str], default None): Roboflow API Key passed to the model during initialization for artifact retrieval. + # Other parameters described in the function signature... + + Returns: + Union[InstanceSegmentationInferenceResponse, KeypointsDetectionInferenceRequest, ObjectDetectionInferenceResponse, ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse, Any]: The response containing the inference results. + """ + logger.debug( + f"Reached legacy route /:dataset_id/:version_id with {dataset_id}/{version_id}" + ) + model_id = f"{dataset_id}/{version_id}" + + if confidence >= 1: + confidence /= 100 + elif confidence < 0.01: + confidence = 0.01 + + if overlap >= 1: + overlap /= 100 + + if image is not None: + request_image = InferenceRequestImage(type="url", value=image) + else: + if "Content-Type" not in request.headers: + raise ContentTypeMissing( + f"Request must include a Content-Type header" + ) + if "multipart/form-data" in request.headers["Content-Type"]: + form_data = await request.form() + base64_image_str = await form_data["file"].read() + base64_image_str = base64.b64encode(base64_image_str) + request_image = InferenceRequestImage( + type="base64", value=base64_image_str.decode("ascii") + ) + elif ( + "application/x-www-form-urlencoded" + in request.headers["Content-Type"] + or "application/json" in request.headers["Content-Type"] + ): + data = await request.body() + request_image = InferenceRequestImage( + type=image_type, value=data + ) + else: + raise ContentTypeInvalid( + f"Invalid Content-Type: {request.headers['Content-Type']}" + ) + + if LAMBDA: + request_model_id = ( + request.scope["aws.event"]["requestContext"]["authorizer"][ + "lambda" + ]["model"]["endpoint"] + .replace("--", "/") + .replace("rf-", "") + .replace("nu-", "") + ) + actor = request.scope["aws.event"]["requestContext"]["authorizer"][ + "lambda" + ]["actor"] + if countinference: + trackUsage(request_model_id, actor) + else: + if service_secret != ROBOFLOW_SERVICE_SECRET: + raise MissingServiceSecretError( + "Service secret is required to disable inference usage tracking" + ) + else: + request_model_id = model_id + self.model_manager.add_model( + request_model_id, api_key, model_id_alias=model_id + ) + + task_type = self.model_manager.get_task_type(model_id, api_key=api_key) + inference_request_type = ObjectDetectionInferenceRequest + args = dict() + if task_type == "instance-segmentation": + inference_request_type = InstanceSegmentationInferenceRequest + args = { + "mask_decode_mode": mask_decode_mode, + "tradeoff_factor": tradeoff_factor, + } + elif task_type == "classification": + inference_request_type = ClassificationInferenceRequest + elif task_type == "keypoint-detection": + inference_request_type = KeypointsDetectionInferenceRequest + args = {"keypoint_confidence": keypoint_confidence} + inference_request = inference_request_type( + api_key=api_key, + model_id=model_id, + image=request_image, + confidence=confidence, + iou_threshold=overlap, + max_detections=max_detections, + visualization_labels=labels, + visualization_stroke_width=stroke, + visualize_predictions=True if format == "image" else False, + disable_preproc_auto_orient=disable_preproc_auto_orient, + disable_preproc_contrast=disable_preproc_contrast, + disable_preproc_grayscale=disable_preproc_grayscale, + disable_preproc_static_crop=disable_preproc_static_crop, + disable_active_learning=disable_active_learning, + source=source, + source_info=source_info, + **args, + ) + + inference_response = await self.model_manager.infer_from_request( + inference_request.model_id, + inference_request, + active_learning_eligible=True, + background_tasks=background_tasks, + ) + logger.debug("Response ready.") + if format == "image": + return Response( + content=inference_response.visualization, + media_type="image/jpeg", + ) + else: + return orjson_response(inference_response) + + if not LAMBDA: + # Legacy clear cache endpoint for backwards compatability + @app.get("/clear_cache", response_model=str) + async def legacy_clear_cache(): + """ + Clears the model cache. + + This endpoint provides a way to clear the cache of loaded models. + + Returns: + str: A string indicating that the cache has been cleared. + """ + logger.debug(f"Reached /clear_cache") + await model_clear() + return "Cache Cleared" + + # Legacy add model endpoint for backwards compatability + @app.get("/start/{dataset_id}/{version_id}") + async def model_add(dataset_id: str, version_id: str, api_key: str = None): + """ + Starts a model inference session. + + This endpoint initializes and starts an inference session for the specified model version. + + Args: + dataset_id (str): ID of a Roboflow dataset corresponding to the model. + version_id (str): ID of a Roboflow dataset version corresponding to the model. + api_key (str, optional): Roboflow API Key for artifact retrieval. + + Returns: + JSONResponse: A response object containing the status and a success message. + """ + logger.debug( + f"Reached /start/{dataset_id}/{version_id} with {dataset_id}/{version_id}" + ) + model_id = f"{dataset_id}/{version_id}" + self.model_manager.add_model(model_id, api_key) + + return JSONResponse( + { + "status": 200, + "message": "inference session started from local memory.", + } + ) + + if not LAMBDA: + + @app.get( + "/notebook/start", + summary="Jupyter Lab Server Start", + description="Starts a jupyter lab server for running development code", + ) + @with_route_exceptions + async def notebook_start(browserless: bool = False): + """Starts a jupyter lab server for running development code. + + Args: + inference_request (NotebookStartRequest): The request containing the necessary details for starting a jupyter lab server. + background_tasks: (BackgroundTasks) pool of fastapi background tasks + + Returns: + NotebookStartResponse: The response containing the URL of the jupyter lab server. + """ + logger.debug(f"Reached /notebook/start") + if NOTEBOOK_ENABLED: + start_notebook() + if browserless: + return { + "success": True, + "message": f"Jupyter Lab server started at http://localhost:{NOTEBOOK_PORT}?token={NOTEBOOK_PASSWORD}", + } + else: + sleep(2) + return RedirectResponse( + f"http://localhost:{NOTEBOOK_PORT}/lab/tree/quickstart.ipynb?token={NOTEBOOK_PASSWORD}" + ) + else: + if browserless: + return { + "success": False, + "message": "Notebook server is not enabled. Enable notebooks via the NOTEBOOK_ENABLED environment variable.", + } + else: + return RedirectResponse(f"/notebook-instructions.html") + + app.mount( + "/", + StaticFiles(directory="./inference/landing/out", html=True), + name="static", + ) + + def run(self): + uvicorn.run(self.app, host="127.0.0.1", port=8080) diff --git a/inference/core/interfaces/http/orjson_utils.py b/inference/core/interfaces/http/orjson_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a651fefa8f211d142a24a17a7243ee06061ef276 --- /dev/null +++ b/inference/core/interfaces/http/orjson_utils.py @@ -0,0 +1,80 @@ +import base64 +from typing import Any, Dict, List, Optional, Union + +import orjson +from fastapi.responses import ORJSONResponse +from pydantic import BaseModel + +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.utils.image_utils import ImageType, encode_image_to_jpeg_bytes + + +class ORJSONResponseBytes(ORJSONResponse): + def render(self, content: Any) -> bytes: + return orjson.dumps( + content, + default=default, + option=orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY, + ) + + +JSON = Union[Dict[str, "JSON"], List["JSON"], str, int, float, bool, None] + + +def default(obj: Any) -> JSON: + if isinstance(obj, bytes): + return base64.b64encode(obj).decode("ascii") + return obj + + +def orjson_response( + response: Union[List[InferenceResponse], InferenceResponse, BaseModel] +) -> ORJSONResponseBytes: + if isinstance(response, list): + content = [r.dict(by_alias=True, exclude_none=True) for r in response] + else: + content = response.dict(by_alias=True, exclude_none=True) + return ORJSONResponseBytes(content=content) + + +def serialise_workflow_result( + result: Dict[str, Any], + excluded_fields: Optional[List[str]] = None, +) -> Dict[str, Any]: + if excluded_fields is None: + excluded_fields = [] + excluded_fields = set(excluded_fields) + serialised_result = {} + for key, value in result.items(): + if key in excluded_fields: + continue + if contains_image(element=value): + value = serialise_image(image=value) + elif issubclass(type(value), list): + value = serialise_list(elements=value) + serialised_result[key] = value + return serialised_result + + +def serialise_list(elements: List[Any]) -> List[Any]: + result = [] + for element in elements: + if contains_image(element=element): + element = serialise_image(image=element) + result.append(element) + return result + + +def contains_image(element: Any) -> bool: + return ( + issubclass(type(element), dict) + and element.get("type") == ImageType.NUMPY_OBJECT.value + ) + + +def serialise_image(image: Dict[str, Any]) -> Dict[str, Any]: + image["type"] = "base64" + image["value"] = base64.b64encode( + encode_image_to_jpeg_bytes(image["value"]) + ).decode("ascii") + return image diff --git a/inference/core/interfaces/stream/__init__.py b/inference/core/interfaces/stream/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/interfaces/stream/__pycache__/__init__.cpython-310.pyc b/inference/core/interfaces/stream/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01996cb282d1e9883724f66b2acd9a1b9ed29efd Binary files /dev/null and b/inference/core/interfaces/stream/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/interfaces/stream/__pycache__/entities.cpython-310.pyc b/inference/core/interfaces/stream/__pycache__/entities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8a3173b94394d63fc190bf1a73123af8bfc6b01 Binary files /dev/null and b/inference/core/interfaces/stream/__pycache__/entities.cpython-310.pyc differ diff --git a/inference/core/interfaces/stream/__pycache__/inference_pipeline.cpython-310.pyc b/inference/core/interfaces/stream/__pycache__/inference_pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ed5f7a8650e6c846fd7972eb89a1927d2c93767 Binary files /dev/null and b/inference/core/interfaces/stream/__pycache__/inference_pipeline.cpython-310.pyc differ diff --git a/inference/core/interfaces/stream/__pycache__/sinks.cpython-310.pyc b/inference/core/interfaces/stream/__pycache__/sinks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..948eec849818012bf47f8b56cb641b3328888763 Binary files /dev/null and b/inference/core/interfaces/stream/__pycache__/sinks.cpython-310.pyc differ diff --git a/inference/core/interfaces/stream/__pycache__/stream.cpython-310.pyc b/inference/core/interfaces/stream/__pycache__/stream.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84b87617ebac9c571b18778a5a092b4806df9a84 Binary files /dev/null and b/inference/core/interfaces/stream/__pycache__/stream.cpython-310.pyc differ diff --git a/inference/core/interfaces/stream/__pycache__/watchdog.cpython-310.pyc b/inference/core/interfaces/stream/__pycache__/watchdog.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03bde1b2d5889c4c2891e1e98671d6702304cbed Binary files /dev/null and b/inference/core/interfaces/stream/__pycache__/watchdog.cpython-310.pyc differ diff --git a/inference/core/interfaces/stream/entities.py b/inference/core/interfaces/stream/entities.py new file mode 100644 index 0000000000000000000000000000000000000000..95a1e3f4a02b86027330ffa6c392ade059b54607 --- /dev/null +++ b/inference/core/interfaces/stream/entities.py @@ -0,0 +1,121 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Dict, List, Optional, Union + +from inference.core.env import ( + CLASS_AGNOSTIC_NMS_ENV, + DEFAULT_CLASS_AGNOSTIC_NMS, + DEFAULT_CONFIDENCE, + DEFAULT_IOU_THRESHOLD, + DEFAULT_MAX_CANDIDATES, + DEFAULT_MAX_DETECTIONS, + IOU_THRESHOLD_ENV, + MAX_CANDIDATES_ENV, + MAX_DETECTIONS_ENV, +) +from inference.core.interfaces.camera.entities import StatusUpdate +from inference.core.interfaces.camera.video_source import SourceMetadata +from inference.core.utils.environment import safe_env_to_type, str2bool + +ObjectDetectionPrediction = dict + + +@dataclass(frozen=True) +class ModelConfig: + class_agnostic_nms: Optional[bool] + confidence: Optional[float] + iou_threshold: Optional[float] + max_candidates: Optional[int] + max_detections: Optional[int] + mask_decode_mode: Optional[str] + tradeoff_factor: Optional[float] + + @classmethod + def init( + cls, + class_agnostic_nms: Optional[bool] = None, + confidence: Optional[float] = None, + iou_threshold: Optional[float] = None, + max_candidates: Optional[int] = None, + max_detections: Optional[int] = None, + mask_decode_mode: Optional[str] = None, + tradeoff_factor: Optional[float] = None, + ) -> "ModelConfig": + if class_agnostic_nms is None: + class_agnostic_nms = safe_env_to_type( + variable_name=CLASS_AGNOSTIC_NMS_ENV, + default_value=DEFAULT_CLASS_AGNOSTIC_NMS, + type_constructor=str2bool, + ) + if confidence is None: + confidence = safe_env_to_type( + variable_name=CLASS_AGNOSTIC_NMS_ENV, + default_value=DEFAULT_CONFIDENCE, + type_constructor=float, + ) + if iou_threshold is None: + iou_threshold = safe_env_to_type( + variable_name=IOU_THRESHOLD_ENV, + default_value=DEFAULT_IOU_THRESHOLD, + type_constructor=float, + ) + if max_candidates is None: + max_candidates = safe_env_to_type( + variable_name=MAX_CANDIDATES_ENV, + default_value=DEFAULT_MAX_CANDIDATES, + type_constructor=int, + ) + if max_detections is None: + max_detections = safe_env_to_type( + variable_name=MAX_DETECTIONS_ENV, + default_value=DEFAULT_MAX_DETECTIONS, + type_constructor=int, + ) + return ModelConfig( + class_agnostic_nms=class_agnostic_nms, + confidence=confidence, + iou_threshold=iou_threshold, + max_candidates=max_candidates, + max_detections=max_detections, + mask_decode_mode=mask_decode_mode, + tradeoff_factor=tradeoff_factor, + ) + + def to_postprocessing_params(self) -> Dict[str, Union[bool, float, int]]: + result = {} + for field in [ + "class_agnostic_nms", + "confidence", + "iou_threshold", + "max_candidates", + "max_detections", + "mask_decode_mode", + "tradeoff_factor", + ]: + result[field] = getattr(self, field, None) + return {name: value for name, value in result.items() if value is not None} + + +@dataclass(frozen=True) +class ModelActivityEvent: + frame_decoding_timestamp: datetime + event_timestamp: datetime + frame_id: int + + +@dataclass(frozen=True) +class LatencyMonitorReport: + frame_decoding_latency: Optional[float] = None + pre_processing_latency: Optional[float] = None + inference_latency: Optional[float] = None + post_processing_latency: Optional[float] = None + model_latency: Optional[float] = None + e2e_latency: Optional[float] = None + + +@dataclass(frozen=True) +class PipelineStateReport: + video_source_status_updates: List[StatusUpdate] + latency_report: LatencyMonitorReport + inference_throughput: float + source_metadata: Optional[SourceMetadata] diff --git a/inference/core/interfaces/stream/inference_pipeline.py b/inference/core/interfaces/stream/inference_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..2af306d64614d3f892692359767a9ff5270989ee --- /dev/null +++ b/inference/core/interfaces/stream/inference_pipeline.py @@ -0,0 +1,457 @@ +import time +from datetime import datetime +from functools import partial +from queue import Queue +from threading import Thread +from typing import Callable, Generator, List, Optional, Tuple, Union + +from inference.core import logger +from inference.core.active_learning.middlewares import ( + NullActiveLearningMiddleware, + ThreadingActiveLearningMiddleware, +) +from inference.core.cache import cache +from inference.core.env import ( + ACTIVE_LEARNING_ENABLED, + API_KEY, + API_KEY_ENV_NAMES, + DISABLE_PREPROC_AUTO_ORIENT, + PREDICTIONS_QUEUE_SIZE, + RESTART_ATTEMPT_DELAY, +) +from inference.core.exceptions import MissingApiKeyError +from inference.core.interfaces.camera.entities import ( + StatusUpdate, + UpdateSeverity, + VideoFrame, +) +from inference.core.interfaces.camera.exceptions import SourceConnectionError +from inference.core.interfaces.camera.utils import get_video_frames_generator +from inference.core.interfaces.camera.video_source import ( + BufferConsumptionStrategy, + BufferFillingStrategy, + VideoSource, +) +from inference.core.interfaces.stream.entities import ( + ModelConfig, + ObjectDetectionPrediction, +) +from inference.core.interfaces.stream.sinks import active_learning_sink, multi_sink +from inference.core.interfaces.stream.watchdog import ( + NullPipelineWatchdog, + PipelineWatchDog, +) +from inference.core.models.roboflow import OnnxRoboflowInferenceModel +from inference.models.utils import get_roboflow_model + +INFERENCE_PIPELINE_CONTEXT = "inference_pipeline" +SOURCE_CONNECTION_ATTEMPT_FAILED_EVENT = "SOURCE_CONNECTION_ATTEMPT_FAILED" +SOURCE_CONNECTION_LOST_EVENT = "SOURCE_CONNECTION_LOST" +INFERENCE_RESULTS_DISPATCHING_ERROR_EVENT = "INFERENCE_RESULTS_DISPATCHING_ERROR" +INFERENCE_THREAD_STARTED_EVENT = "INFERENCE_THREAD_STARTED" +INFERENCE_THREAD_FINISHED_EVENT = "INFERENCE_THREAD_FINISHED" +INFERENCE_COMPLETED_EVENT = "INFERENCE_COMPLETED" +INFERENCE_ERROR_EVENT = "INFERENCE_ERROR" + + +class InferencePipeline: + @classmethod + def init( + cls, + model_id: str, + video_reference: Union[str, int], + on_prediction: Callable[[ObjectDetectionPrediction, VideoFrame], None], + api_key: Optional[str] = None, + max_fps: Optional[Union[float, int]] = None, + watchdog: Optional[PipelineWatchDog] = None, + status_update_handlers: Optional[List[Callable[[StatusUpdate], None]]] = None, + source_buffer_filling_strategy: Optional[BufferFillingStrategy] = None, + source_buffer_consumption_strategy: Optional[BufferConsumptionStrategy] = None, + class_agnostic_nms: Optional[bool] = None, + confidence: Optional[float] = None, + iou_threshold: Optional[float] = None, + max_candidates: Optional[int] = None, + max_detections: Optional[int] = None, + mask_decode_mode: Optional[str] = "accurate", + tradeoff_factor: Optional[float] = 0.0, + active_learning_enabled: Optional[bool] = None, + ) -> "InferencePipeline": + """ + This class creates the abstraction for making inferences from CV models against video stream. + It allows to choose Object Detection model from Roboflow platform and run predictions against + video streams - just by the price of specifying which model to use and what to do with predictions. + + It allows to set the model post-processing parameters (via .init() or env) and intercept updates + related to state of pipeline via `PipelineWatchDog` abstraction (although that is something probably + useful only for advanced use-cases). + + For maximum efficiency, all separate chunks of processing: video decoding, inference, results dispatching + are handled by separate threads. + + Given that reference to stream is passed and connectivity is lost - it attempts to re-connect with delay. + + Since version 0.9.11 it works not only for object detection models but is also compatible with stubs, + classification, instance-segmentation and keypoint-detection models. + + Args: + model_id (str): Name and version of model at Roboflow platform (example: "my-model/3") + video_reference (Union[str, int]): Reference of source to be used to make predictions against. + It can be video file path, stream URL and device (like camera) id (we handle whatever cv2 handles). + on_prediction (Callable[ObjectDetectionPrediction, VideoFrame], None]): Function to be called + once prediction is ready - passing both decoded frame, their metadata and dict with standard + Roboflow Object Detection prediction. + api_key (Optional[str]): Roboflow API key - if not passed - will be looked in env under "ROBOFLOW_API_KEY" + and "API_KEY" variables. API key, passed in some form is required. + max_fps (Optional[Union[float, int]]): Specific value passed as this parameter will be used to + dictate max FPS of processing. It can be useful if we wanted to run concurrent inference pipelines + on single machine making tradeoff between number of frames and number of streams handled. Disabled + by default. + watchdog (Optional[PipelineWatchDog]): Implementation of class that allows profiling of + inference pipeline - if not given null implementation (doing nothing) will be used. + status_update_handlers (Optional[List[Callable[[StatusUpdate], None]]]): List of handlers to intercept + status updates of all elements of the pipeline. Should be used only if detailed inspection of + pipeline behaviour in time is needed. Please point out that handlers should be possible to be executed + fast - otherwise they will impair pipeline performance. All errors will be logged as warnings + without re-raising. Default: None. + source_buffer_filling_strategy (Optional[BufferFillingStrategy]): Parameter dictating strategy for + video stream decoding behaviour. By default - tweaked to the type of source given. + Please find detailed explanation in docs of [`VideoSource`](../camera/video_source.py) + source_buffer_consumption_strategy (Optional[BufferConsumptionStrategy]): Parameter dictating strategy for + video stream frames consumption. By default - tweaked to the type of source given. + Please find detailed explanation in docs of [`VideoSource`](../camera/video_source.py) + class_agnostic_nms (Optional[bool]): Parameter of model post-processing. If not given - value checked in + env variable "CLASS_AGNOSTIC_NMS" with default "False" + confidence (Optional[float]): Parameter of model post-processing. If not given - value checked in + env variable "CONFIDENCE" with default "0.5" + iou_threshold (Optional[float]): Parameter of model post-processing. If not given - value checked in + env variable "IOU_THRESHOLD" with default "0.5" + max_candidates (Optional[int]): Parameter of model post-processing. If not given - value checked in + env variable "MAX_CANDIDATES" with default "3000" + max_detections (Optional[int]): Parameter of model post-processing. If not given - value checked in + env variable "MAX_DETECTIONS" with default "300" + mask_decode_mode: (Optional[str]): Parameter of model post-processing. If not given - model "accurate" is + used. Applicable for instance segmentation models + tradeoff_factor (Optional[float]): Parameter of model post-processing. If not 0.0 - model default is used. + Applicable for instance segmentation models + active_learning_enabled (Optional[bool]): Flag to enable / disable Active Learning middleware (setting it + true does not guarantee any data to be collected, as data collection is controlled by Roboflow backend - + it just enables middleware intercepting predictions). If not given, env variable + `ACTIVE_LEARNING_ENABLED` will be used. Please point out that Active Learning will be forcefully + disabled in a scenario when Roboflow API key is not given, as Roboflow account is required + for this feature to be operational. + + Other ENV variables involved in low-level configuration: + * INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions that are ready for dispatching + * INFERENCE_PIPELINE_RESTART_ATTEMPT_DELAY - delay for restarts on stream connection drop + * ACTIVE_LEARNING_ENABLED - controls Active Learning middleware if explicit parameter not given + + Returns: Instance of InferencePipeline + + Throws: + * SourceConnectionError if source cannot be connected at start, however it attempts to reconnect + always if connection to stream is lost. + """ + if api_key is None: + api_key = API_KEY + if status_update_handlers is None: + status_update_handlers = [] + inference_config = ModelConfig.init( + class_agnostic_nms=class_agnostic_nms, + confidence=confidence, + iou_threshold=iou_threshold, + max_candidates=max_candidates, + max_detections=max_detections, + mask_decode_mode=mask_decode_mode, + tradeoff_factor=tradeoff_factor, + ) + model = get_roboflow_model(model_id=model_id, api_key=api_key) + if watchdog is None: + watchdog = NullPipelineWatchdog() + status_update_handlers.append(watchdog.on_status_update) + video_source = VideoSource.init( + video_reference=video_reference, + status_update_handlers=status_update_handlers, + buffer_filling_strategy=source_buffer_filling_strategy, + buffer_consumption_strategy=source_buffer_consumption_strategy, + ) + watchdog.register_video_source(video_source=video_source) + predictions_queue = Queue(maxsize=PREDICTIONS_QUEUE_SIZE) + active_learning_middleware = NullActiveLearningMiddleware() + if active_learning_enabled is None: + logger.info( + f"`active_learning_enabled` parameter not set - using env `ACTIVE_LEARNING_ENABLED` " + f"with value: {ACTIVE_LEARNING_ENABLED}" + ) + active_learning_enabled = ACTIVE_LEARNING_ENABLED + if api_key is None: + logger.info( + f"Roboflow API key not given - Active Learning is forced to be disabled." + ) + active_learning_enabled = False + if active_learning_enabled is True: + active_learning_middleware = ThreadingActiveLearningMiddleware.init( + api_key=api_key, + model_id=model_id, + cache=cache, + ) + al_sink = partial( + active_learning_sink, + active_learning_middleware=active_learning_middleware, + model_type=model.task_type, + disable_preproc_auto_orient=DISABLE_PREPROC_AUTO_ORIENT, + ) + logger.info( + "AL enabled - wrapping `on_prediction` with multi_sink() and active_learning_sink()" + ) + on_prediction = partial(multi_sink, sinks=[on_prediction, al_sink]) + return cls( + model=model, + video_source=video_source, + on_prediction=on_prediction, + max_fps=max_fps, + predictions_queue=predictions_queue, + watchdog=watchdog, + status_update_handlers=status_update_handlers, + inference_config=inference_config, + active_learning_middleware=active_learning_middleware, + ) + + def __init__( + self, + model: OnnxRoboflowInferenceModel, + video_source: VideoSource, + on_prediction: Callable[[ObjectDetectionPrediction, VideoFrame], None], + max_fps: Optional[float], + predictions_queue: Queue, + watchdog: PipelineWatchDog, + status_update_handlers: List[Callable[[StatusUpdate], None]], + inference_config: ModelConfig, + active_learning_middleware: Union[ + NullActiveLearningMiddleware, ThreadingActiveLearningMiddleware + ], + ): + self._model = model + self._video_source = video_source + self._on_prediction = on_prediction + self._max_fps = max_fps + self._predictions_queue = predictions_queue + self._watchdog = watchdog + self._command_handler_thread: Optional[Thread] = None + self._inference_thread: Optional[Thread] = None + self._dispatching_thread: Optional[Thread] = None + self._stop = False + self._camera_restart_ongoing = False + self._status_update_handlers = status_update_handlers + self._inference_config = inference_config + self._active_learning_middleware = active_learning_middleware + + def start(self, use_main_thread: bool = True) -> None: + self._stop = False + self._inference_thread = Thread(target=self._execute_inference) + self._inference_thread.start() + if self._active_learning_middleware is not None: + self._active_learning_middleware.start_registration_thread() + if use_main_thread: + self._dispatch_inference_results() + else: + self._dispatching_thread = Thread(target=self._dispatch_inference_results) + self._dispatching_thread.start() + + def terminate(self) -> None: + self._stop = True + self._video_source.terminate() + + def pause_stream(self) -> None: + self._video_source.pause() + + def mute_stream(self) -> None: + self._video_source.mute() + + def resume_stream(self) -> None: + self._video_source.resume() + + def join(self) -> None: + if self._inference_thread is not None: + self._inference_thread.join() + self._inference_thread = None + if self._dispatching_thread is not None: + self._dispatching_thread.join() + self._dispatching_thread = None + if self._active_learning_middleware is not None: + self._active_learning_middleware.stop_registration_thread() + + def _execute_inference(self) -> None: + send_inference_pipeline_status_update( + severity=UpdateSeverity.INFO, + event_type=INFERENCE_THREAD_STARTED_EVENT, + status_update_handlers=self._status_update_handlers, + ) + logger.info(f"Inference thread started") + try: + for video_frame in self._generate_frames(): + self._watchdog.on_model_preprocessing_started( + frame_timestamp=video_frame.frame_timestamp, + frame_id=video_frame.frame_id, + ) + preprocessed_image, preprocessing_metadata = self._model.preprocess( + video_frame.image + ) + self._watchdog.on_model_inference_started( + frame_timestamp=video_frame.frame_timestamp, + frame_id=video_frame.frame_id, + ) + predictions = self._model.predict(preprocessed_image) + self._watchdog.on_model_postprocessing_started( + frame_timestamp=video_frame.frame_timestamp, + frame_id=video_frame.frame_id, + ) + postprocessing_args = self._inference_config.to_postprocessing_params() + predictions = self._model.postprocess( + predictions, + preprocessing_metadata, + **postprocessing_args, + ) + if issubclass(type(predictions), list): + predictions = predictions[0].dict( + by_alias=True, + exclude_none=True, + ) + self._watchdog.on_model_prediction_ready( + frame_timestamp=video_frame.frame_timestamp, + frame_id=video_frame.frame_id, + ) + self._predictions_queue.put((predictions, video_frame)) + send_inference_pipeline_status_update( + severity=UpdateSeverity.DEBUG, + event_type=INFERENCE_COMPLETED_EVENT, + payload={ + "frame_id": video_frame.frame_id, + "frame_timestamp": video_frame.frame_timestamp, + }, + status_update_handlers=self._status_update_handlers, + ) + except Exception as error: + payload = { + "error_type": error.__class__.__name__, + "error_message": str(error), + "error_context": "inference_thread", + } + send_inference_pipeline_status_update( + severity=UpdateSeverity.ERROR, + event_type=INFERENCE_ERROR_EVENT, + payload=payload, + status_update_handlers=self._status_update_handlers, + ) + logger.exception(f"Encountered inference error: {error}") + finally: + self._predictions_queue.put(None) + send_inference_pipeline_status_update( + severity=UpdateSeverity.INFO, + event_type=INFERENCE_THREAD_FINISHED_EVENT, + status_update_handlers=self._status_update_handlers, + ) + logger.info(f"Inference thread finished") + + def _dispatch_inference_results(self) -> None: + while True: + inference_results: Optional[Tuple[dict, VideoFrame]] = ( + self._predictions_queue.get() + ) + if inference_results is None: + self._predictions_queue.task_done() + break + predictions, video_frame = inference_results + try: + self._on_prediction(predictions, video_frame) + except Exception as error: + payload = { + "error_type": error.__class__.__name__, + "error_message": str(error), + "error_context": "inference_results_dispatching", + } + send_inference_pipeline_status_update( + severity=UpdateSeverity.ERROR, + event_type=INFERENCE_RESULTS_DISPATCHING_ERROR_EVENT, + payload=payload, + status_update_handlers=self._status_update_handlers, + ) + logger.warning(f"Error in results dispatching - {error}") + finally: + self._predictions_queue.task_done() + + def _generate_frames( + self, + ) -> Generator[VideoFrame, None, None]: + self._video_source.start() + while True: + source_properties = self._video_source.describe_source().source_properties + if source_properties is None: + break + allow_reconnect = not source_properties.is_file + yield from get_video_frames_generator( + video=self._video_source, max_fps=self._max_fps + ) + if not allow_reconnect: + self.terminate() + break + if self._stop: + break + logger.warning(f"Lost connection with video source.") + send_inference_pipeline_status_update( + severity=UpdateSeverity.WARNING, + event_type=SOURCE_CONNECTION_LOST_EVENT, + payload={ + "source_reference": self._video_source.describe_source().source_reference + }, + status_update_handlers=self._status_update_handlers, + ) + self._attempt_restart() + + def _attempt_restart(self) -> None: + succeeded = False + while not self._stop and not succeeded: + try: + self._video_source.restart() + succeeded = True + except SourceConnectionError as error: + payload = { + "error_type": error.__class__.__name__, + "error_message": str(error), + "error_context": "video_frames_generator", + } + send_inference_pipeline_status_update( + severity=UpdateSeverity.WARNING, + event_type=SOURCE_CONNECTION_ATTEMPT_FAILED_EVENT, + payload=payload, + status_update_handlers=self._status_update_handlers, + ) + logger.warning( + f"Could not connect to video source. Retrying in {RESTART_ATTEMPT_DELAY}s..." + ) + time.sleep(RESTART_ATTEMPT_DELAY) + + +def send_inference_pipeline_status_update( + severity: UpdateSeverity, + event_type: str, + status_update_handlers: List[Callable[[StatusUpdate], None]], + payload: Optional[dict] = None, + sub_context: Optional[str] = None, +) -> None: + if payload is None: + payload = {} + context = INFERENCE_PIPELINE_CONTEXT + if sub_context is not None: + context = f"{context}.{sub_context}" + status_update = StatusUpdate( + timestamp=datetime.now(), + severity=severity, + event_type=event_type, + payload=payload, + context=context, + ) + for handler in status_update_handlers: + try: + handler(status_update) + except Exception as error: + logger.warning(f"Could not execute handler update. Cause: {error}") diff --git a/inference/core/interfaces/stream/sinks.py b/inference/core/interfaces/stream/sinks.py new file mode 100644 index 0000000000000000000000000000000000000000..ece7cd90302900550baf19ec99a2ff683c483784 --- /dev/null +++ b/inference/core/interfaces/stream/sinks.py @@ -0,0 +1,387 @@ +import json +import socket +from datetime import datetime +from functools import partial +from typing import Callable, List, Optional, Tuple + +import cv2 +import numpy as np +import supervision as sv + +from inference.core import logger +from inference.core.active_learning.middlewares import ActiveLearningMiddleware +from inference.core.interfaces.camera.entities import VideoFrame +from inference.core.utils.preprocess import letterbox_image + +DEFAULT_ANNOTATOR = sv.BoxAnnotator() +DEFAULT_FPS_MONITOR = sv.FPSMonitor() + + +def display_image(image: np.ndarray) -> None: + cv2.imshow("Predictions", image) + cv2.waitKey(1) + + +def render_boxes( + predictions: dict, + video_frame: VideoFrame, + annotator: sv.BoxAnnotator = DEFAULT_ANNOTATOR, + display_size: Optional[Tuple[int, int]] = (1280, 720), + fps_monitor: Optional[sv.FPSMonitor] = DEFAULT_FPS_MONITOR, + display_statistics: bool = False, + on_frame_rendered: Callable[[np.ndarray], None] = display_image, +) -> None: + """ + Helper tool to render object detection predictions on top of video frame. It is designed + to be used with `InferencePipeline`, as sink for predictions. By default, it uses standard `sv.BoxAnnotator()` + to draw bounding boxes and resizes prediction to 1280x720 (keeping aspect ratio and adding black padding). + One may configure default behaviour, for instance to display latency and throughput statistics. + + This sink is only partially compatible with stubs and classification models (it will not fail, + although predictions will not be displayed). + + Args: + predictions (dict): Roboflow object detection predictions with Bounding Boxes + video_frame (VideoFrame): frame of video with its basic metadata emitted by `VideoSource` + annotator (sv.BoxAnnotator): Annotator used to draw Bounding Boxes - if custom object is not passed, + default is used. + display_size (Tuple[int, int]): tuple in format (width, height) to resize visualisation output + fps_monitor (Optional[sv.FPSMonitor]): FPS monitor used to monitor throughput + display_statistics (bool): Flag to decide if throughput and latency can be displayed in the result image, + if enabled, throughput will only be presented if `fps_monitor` is not None + on_frame_rendered (Callable[[np.ndarray], None]): callback to be called once frame is rendered - by default, + function will display OpenCV window. + + Returns: None + Side effects: on_frame_rendered() is called against the np.ndarray produced from video frame + and predictions. + + Example: + ```python + from functools import partial + import cv2 + from inference import InferencePipeline + from inference.core.interfaces.stream.sinks import render_boxes + + output_size = (640, 480) + video_sink = cv2.VideoWriter("output.avi", cv2.VideoWriter_fourcc(*"MJPG"), 25.0, output_size) + on_prediction = partial(render_boxes, display_size=output_size, on_frame_rendered=video_sink.write) + + pipeline = InferencePipeline.init( + model_id="your-model/3", + video_reference="./some_file.mp4", + on_prediction=on_prediction, + ) + pipeline.start() + pipeline.join() + video_sink.release() + ``` + + In this example, `render_boxes()` is used as a sink for `InferencePipeline` predictions - making frames with + predictions displayed to be saved into video file. + """ + fps_value = None + if fps_monitor is not None: + fps_monitor.tick() + fps_value = fps_monitor() + try: + labels = [p["class"] for p in predictions["predictions"]] + detections = sv.Detections.from_roboflow(predictions) + image = annotator.annotate( + scene=video_frame.image.copy(), detections=detections, labels=labels + ) + except (TypeError, KeyError): + logger.warning( + f"Used `render_boxes(...)` sink, but predictions that were provided do not match the expected format " + f"of object detection prediction that could be accepted by `supervision.Detection.from_roboflow(...)" + ) + image = video_frame.image.copy() + if display_size is not None: + image = letterbox_image(image, desired_size=display_size) + if display_statistics: + image = render_statistics( + image=image, frame_timestamp=video_frame.frame_timestamp, fps=fps_value + ) + on_frame_rendered(image) + + +def render_statistics( + image: np.ndarray, frame_timestamp: datetime, fps: Optional[float] +) -> np.ndarray: + latency = round((datetime.now() - frame_timestamp).total_seconds() * 1000, 2) + image_height = image.shape[0] + image = cv2.putText( + image, + f"LATENCY: {latency} ms", + (10, image_height - 10), + cv2.FONT_HERSHEY_SIMPLEX, + 0.8, + (0, 255, 0), + 2, + ) + if fps is not None: + fps = round(fps, 2) + image = cv2.putText( + image, + f"THROUGHPUT: {fps}", + (10, image_height - 50), + cv2.FONT_HERSHEY_SIMPLEX, + 0.8, + (0, 255, 0), + 2, + ) + return image + + +class UDPSink: + @classmethod + def init(cls, ip_address: str, port: int) -> "UDPSink": + """ + Creates `InferencePipeline` predictions sink capable of sending model predictions over network + using UDP socket. + + As an `inference` user, please use .init() method instead of constructor to instantiate objects. + Args: + ip_address (str): IP address to send predictions + port (int): Port to send predictions + + Returns: Initialised object of `UDPSink` class. + """ + udp_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) + udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1) + udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536) + return cls( + ip_address=ip_address, + port=port, + udp_socket=udp_socket, + ) + + def __init__(self, ip_address: str, port: int, udp_socket: socket.socket): + self._ip_address = ip_address + self._port = port + self._socket = udp_socket + + def send_predictions( + self, + predictions: dict, + video_frame: VideoFrame, + ) -> None: + """ + Method to send predictions via UDP socket. Useful in combination with `InferencePipeline` as + a sink for predictions. + + Args: + predictions (dict): Roboflow object detection predictions with Bounding Boxes + video_frame (VideoFrame): frame of video with its basic metadata emitted by `VideoSource` + + Returns: None + Side effects: Sends serialised `predictions` and `video_frame` metadata via the UDP socket as + JSON string. It adds key named "inference_metadata" into `predictions` dict (mutating its + state). "inference_metadata" contain id of the frame, frame grabbing timestamp and message + emission time in datetime iso format. + + Example: + ```python + import cv2 + from inference.core.interfaces.stream.inference_pipeline import InferencePipeline + from inference.core.interfaces.stream.sinks import UDPSink + + udp_sink = UDPSink.init(ip_address="127.0.0.1", port=9090) + + pipeline = InferencePipeline.init( + model_id="your-model/3", + video_reference="./some_file.mp4", + on_prediction=udp_sink.send_predictions, + ) + pipeline.start() + pipeline.join() + ``` + `UDPSink` used in this way will emit predictions to receiver automatically. + """ + inference_metadata = { + "frame_id": video_frame.frame_id, + "frame_decoding_time": video_frame.frame_timestamp.isoformat(), + "emission_time": datetime.now().isoformat(), + } + predictions["inference_metadata"] = inference_metadata + serialised_predictions = json.dumps(predictions).encode("utf-8") + self._socket.sendto( + serialised_predictions, + ( + self._ip_address, + self._port, + ), + ) + + +def multi_sink( + predictions: dict, + video_frame: VideoFrame, + sinks: List[Callable[[dict, VideoFrame], None]], +) -> None: + """ + Helper util useful to combine multiple sinks together, while using `InferencePipeline`. + + Args: + video_frame (VideoFrame): frame of video with its basic metadata emitted by `VideoSource` + predictions (dict): Roboflow object detection predictions with Bounding Boxes + sinks (List[Callable[[VideoFrame, dict], None]]): list of sinks to be used. Each will be executed + one-by-one in the order pointed in input list, all errors will be caught and reported via logger, + without re-raising. + + Returns: None + Side effects: Uses all sinks in context if (video_frame, predictions) input. + + Example: + ```python + from functools import partial + import cv2 + from inference import InferencePipeline + from inference.core.interfaces.stream.sinks import UDPSink, render_boxes + + udp_sink = UDPSink(ip_address="127.0.0.1", port=9090) + on_prediction = partial(multi_sink, sinks=[udp_sink.send_predictions, render_boxes]) + + pipeline = InferencePipeline.init( + model_id="your-model/3", + video_reference="./some_file.mp4", + on_prediction=on_prediction, + ) + pipeline.start() + pipeline.join() + ``` + + As a result, predictions will both be sent via UDP socket and displayed in the screen. + """ + for sink in sinks: + try: + sink(predictions, video_frame) + except Exception as error: + logger.error( + f"Could not sent prediction with frame_id={video_frame.frame_id} to sink " + f"due to error: {error}." + ) + + +def active_learning_sink( + predictions: dict, + video_frame: VideoFrame, + active_learning_middleware: ActiveLearningMiddleware, + model_type: str, + disable_preproc_auto_orient: bool = False, +) -> None: + active_learning_middleware.register( + inference_input=video_frame.image, + prediction=predictions, + prediction_type=model_type, + disable_preproc_auto_orient=disable_preproc_auto_orient, + ) + + +class VideoFileSink: + @classmethod + def init( + cls, + video_file_name: str, + annotator: sv.BoxAnnotator = DEFAULT_ANNOTATOR, + display_size: Optional[Tuple[int, int]] = (1280, 720), + fps_monitor: Optional[sv.FPSMonitor] = DEFAULT_FPS_MONITOR, + display_statistics: bool = False, + output_fps: int = 25, + quiet: bool = False, + ) -> "VideoFileSink": + """ + Creates `InferencePipeline` predictions sink capable of saving model predictions into video file. + + As an `inference` user, please use .init() method instead of constructor to instantiate objects. + Args: + video_file_name (str): name of the video file to save predictions + render_boxes (Callable[[dict, VideoFrame], None]): callable to render predictions on top of video frame + + Attributes: + on_prediction (Callable[[dict, VideoFrame], None]): callable to be used as a sink for predictions + + Returns: Initialized object of `VideoFileSink` class. + + Example: + ```python + import cv2 + from inference import InferencePipeline + from inference.core.interfaces.stream.sinks import VideoFileSink + + video_sink = VideoFileSink.init(video_file_name="output.avi") + + pipeline = InferencePipeline.init( + model_id="your-model/3", + video_reference="./some_file.mp4", + on_prediction=video_sink.on_prediction, + ) + pipeline.start() + pipeline.join() + video_sink.release() + ``` + + `VideoFileSink` used in this way will save predictions to video file automatically. + """ + return cls( + video_file_name=video_file_name, + annotator=annotator, + display_size=display_size, + fps_monitor=fps_monitor, + display_statistics=display_statistics, + output_fps=output_fps, + quiet=quiet, + ) + + def __init__( + self, + video_file_name: str, + annotator: sv.BoxAnnotator, + display_size: Optional[Tuple[int, int]], + fps_monitor: Optional[sv.FPSMonitor], + display_statistics: bool, + output_fps: int, + quiet: bool, + ): + self._video_file_name = video_file_name + self._annotator = annotator + self._display_size = display_size + self._fps_monitor = fps_monitor + self._display_statistics = display_statistics + self._output_fps = output_fps + self._quiet = quiet + self._frame_idx = 0 + + self._video_writer = cv2.VideoWriter( + self._video_file_name, + cv2.VideoWriter_fourcc(*"MJPG"), + self._output_fps, + self._display_size, + ) + + self.on_prediction = partial( + render_boxes, + annotator=self._annotator, + display_size=self._display_size, + fps_monitor=self._fps_monitor, + display_statistics=self._display_statistics, + on_frame_rendered=self._save_predictions, + ) + + def _save_predictions( + self, + frame: np.ndarray, + ) -> None: + """ """ + self._video_writer.write(frame) + if not self._quiet: + print(f"Writing frame {self._frame_idx}", end="\r") + self._frame_idx += 1 + + def release(self) -> None: + """ + Releases VideoWriter object. + """ + self._video_writer.release() diff --git a/inference/core/interfaces/stream/stream.py b/inference/core/interfaces/stream/stream.py new file mode 100644 index 0000000000000000000000000000000000000000..bee4258fe4b348cc15d13aa5a76264b13ffcaa1d --- /dev/null +++ b/inference/core/interfaces/stream/stream.py @@ -0,0 +1,323 @@ +import json +import threading +import time +import traceback +from typing import Callable, Union + +import cv2 +import numpy as np +import supervision as sv +from PIL import Image + +import inference.core.entities.requests.inference +from inference.core.active_learning.middlewares import ( + NullActiveLearningMiddleware, + ThreadingActiveLearningMiddleware, +) +from inference.core.cache import cache +from inference.core.env import ( + ACTIVE_LEARNING_ENABLED, + API_KEY, + API_KEY_ENV_NAMES, + CLASS_AGNOSTIC_NMS, + CONFIDENCE, + ENABLE_BYTE_TRACK, + ENFORCE_FPS, + IOU_THRESHOLD, + JSON_RESPONSE, + MAX_CANDIDATES, + MAX_DETECTIONS, + MODEL_ID, + STREAM_ID, +) +from inference.core.interfaces.base import BaseInterface +from inference.core.interfaces.camera.camera import WebcamStream +from inference.core.logger import logger +from inference.core.registries.roboflow import get_model_type +from inference.models.utils import get_roboflow_model + + +class Stream(BaseInterface): + """Roboflow defined stream interface for a general-purpose inference server. + + Attributes: + model_manager (ModelManager): The manager that handles model inference tasks. + model_registry (RoboflowModelRegistry): The registry to fetch model instances. + api_key (str): The API key for accessing models. + class_agnostic_nms (bool): Flag for class-agnostic non-maximum suppression. + confidence (float): Confidence threshold for inference. + iou_threshold (float): The intersection-over-union threshold for detection. + json_response (bool): Flag to toggle JSON response format. + max_candidates (float): The maximum number of candidates for detection. + max_detections (float): The maximum number of detections. + model (str|Callable): The model to be used. + stream_id (str): The ID of the stream to be used. + use_bytetrack (bool): Flag to use bytetrack, + + Methods: + init_infer: Initialize the inference with a test frame. + preprocess_thread: Preprocess incoming frames for inference. + inference_request_thread: Manage the inference requests. + run_thread: Run the preprocessing and inference threads. + """ + + def __init__( + self, + api_key: str = API_KEY, + class_agnostic_nms: bool = CLASS_AGNOSTIC_NMS, + confidence: float = CONFIDENCE, + enforce_fps: bool = ENFORCE_FPS, + iou_threshold: float = IOU_THRESHOLD, + max_candidates: float = MAX_CANDIDATES, + max_detections: float = MAX_DETECTIONS, + model: Union[str, Callable] = MODEL_ID, + source: Union[int, str] = STREAM_ID, + use_bytetrack: bool = ENABLE_BYTE_TRACK, + use_main_thread: bool = False, + output_channel_order: str = "RGB", + on_prediction: Callable = None, + on_start: Callable = None, + on_stop: Callable = None, + ): + """Initialize the stream with the given parameters. + Prints the server settings and initializes the inference with a test frame. + """ + logger.info("Initializing server") + + self.frame_count = 0 + self.byte_tracker = sv.ByteTrack() if use_bytetrack else None + self.use_bytetrack = use_bytetrack + + if source == "webcam": + stream_id = 0 + else: + stream_id = source + + self.stream_id = stream_id + if self.stream_id is None: + raise ValueError("STREAM_ID is not defined") + self.model_id = model + if not self.model_id: + raise ValueError("MODEL_ID is not defined") + self.api_key = api_key + + self.active_learning_middleware = NullActiveLearningMiddleware() + if isinstance(model, str): + self.model = get_roboflow_model(model, self.api_key) + if ACTIVE_LEARNING_ENABLED: + self.active_learning_middleware = ( + ThreadingActiveLearningMiddleware.init( + api_key=self.api_key, + model_id=self.model_id, + cache=cache, + ) + ) + self.task_type = get_model_type( + model_id=self.model_id, api_key=self.api_key + )[0] + else: + self.model = model + self.task_type = "unknown" + + self.class_agnostic_nms = class_agnostic_nms + self.confidence = confidence + self.iou_threshold = iou_threshold + self.max_candidates = max_candidates + self.max_detections = max_detections + self.use_main_thread = use_main_thread + self.output_channel_order = output_channel_order + + self.inference_request_type = ( + inference.core.entities.requests.inference.ObjectDetectionInferenceRequest + ) + + self.webcam_stream = WebcamStream( + stream_id=self.stream_id, enforce_fps=enforce_fps + ) + logger.info( + f"Streaming from device with resolution: {self.webcam_stream.width} x {self.webcam_stream.height}" + ) + + self.on_start_callbacks = [] + self.on_stop_callbacks = [ + lambda: self.active_learning_middleware.stop_registration_thread() + ] + self.on_prediction_callbacks = [] + + if on_prediction: + self.on_prediction_callbacks.append(on_prediction) + + if on_start: + self.on_start_callbacks.append(on_start) + + if on_stop: + self.on_stop_callbacks.append(on_stop) + + self.init_infer() + self.preproc_result = None + self.inference_request_obj = None + self.queue_control = False + self.inference_response = None + self.stop = False + + self.frame = None + self.frame_cv = None + self.frame_id = None + logger.info("Server initialized with settings:") + logger.info(f"Stream ID: {self.stream_id}") + logger.info(f"Model ID: {self.model_id}") + logger.info(f"Enforce FPS: {enforce_fps}") + logger.info(f"Confidence: {self.confidence}") + logger.info(f"Class Agnostic NMS: {self.class_agnostic_nms}") + logger.info(f"IOU Threshold: {self.iou_threshold}") + logger.info(f"Max Candidates: {self.max_candidates}") + logger.info(f"Max Detections: {self.max_detections}") + + self.run_thread() + + def on_start(self, callback): + self.on_start_callbacks.append(callback) + + unsubscribe = lambda: self.on_start_callbacks.remove(callback) + return unsubscribe + + def on_stop(self, callback): + self.on_stop_callbacks.append(callback) + + unsubscribe = lambda: self.on_stop_callbacks.remove(callback) + return unsubscribe + + def on_prediction(self, callback): + self.on_prediction_callbacks.append(callback) + + unsubscribe = lambda: self.on_prediction_callbacks.remove(callback) + return unsubscribe + + def init_infer(self): + """Initialize the inference with a test frame. + + Creates a test frame and runs it through the entire inference process to ensure everything is working. + """ + frame = Image.new("RGB", (640, 640), color="black") + self.model.infer( + frame, confidence=self.confidence, iou_threshold=self.iou_threshold + ) + self.active_learning_middleware.start_registration_thread() + + def preprocess_thread(self): + """Preprocess incoming frames for inference. + + Reads frames from the webcam stream, converts them into the proper format, and preprocesses them for + inference. + """ + webcam_stream = self.webcam_stream + webcam_stream.start() + # processing frames in input stream + try: + while True: + if webcam_stream.stopped is True or self.stop: + break + else: + self.frame_cv, frame_id = webcam_stream.read_opencv() + if frame_id > 0 and frame_id != self.frame_id: + self.frame_id = frame_id + self.frame = cv2.cvtColor(self.frame_cv, cv2.COLOR_BGR2RGB) + self.preproc_result = self.model.preprocess(self.frame_cv) + self.img_in, self.img_dims = self.preproc_result + self.queue_control = True + + except Exception as e: + traceback.print_exc() + logger.error(e) + + def inference_request_thread(self): + """Manage the inference requests. + + Processes preprocessed frames for inference, post-processes the predictions, and sends the results + to registered callbacks. + """ + last_print = time.perf_counter() + print_ind = 0 + while True: + if self.webcam_stream.stopped is True or self.stop: + while len(self.on_stop_callbacks) > 0: + # run each onStop callback only once from this thread + cb = self.on_stop_callbacks.pop() + cb() + break + if self.queue_control: + while len(self.on_start_callbacks) > 0: + # run each onStart callback only once from this thread + cb = self.on_start_callbacks.pop() + cb() + + self.queue_control = False + frame_id = self.frame_id + inference_input = np.copy(self.frame_cv) + start = time.perf_counter() + predictions = self.model.predict( + self.img_in, + ) + predictions = self.model.postprocess( + predictions, + self.img_dims, + class_agnostic_nms=self.class_agnostic_nms, + confidence=self.confidence, + iou_threshold=self.iou_threshold, + max_candidates=self.max_candidates, + max_detections=self.max_detections, + )[0] + + self.active_learning_middleware.register( + inference_input=inference_input, + prediction=predictions.dict(by_alias=True, exclude_none=True), + prediction_type=self.task_type, + ) + if self.use_bytetrack: + detections = sv.Detections.from_roboflow( + predictions.dict(by_alias=True, exclude_none=True) + ) + detections = self.byte_tracker.update_with_detections(detections) + + if detections.tracker_id is None: + detections.tracker_id = np.array([], dtype=int) + + for pred, detect in zip(predictions.predictions, detections): + pred.tracker_id = int(detect[4]) + predictions.frame_id = frame_id + predictions = predictions.dict(by_alias=True, exclude_none=True) + + self.inference_response = predictions + self.frame_count += 1 + + for cb in self.on_prediction_callbacks: + if self.output_channel_order == "BGR": + cb(predictions, self.frame_cv) + else: + cb(predictions, np.asarray(self.frame)) + + current = time.perf_counter() + self.webcam_stream.max_fps = 1 / (current - start) + logger.debug(f"FPS: {self.webcam_stream.max_fps:.2f}") + + if time.perf_counter() - last_print > 1: + print_ind = (print_ind + 1) % 4 + last_print = time.perf_counter() + + def run_thread(self): + """Run the preprocessing and inference threads. + + Starts the preprocessing and inference threads, and handles graceful shutdown on KeyboardInterrupt. + """ + preprocess_thread = threading.Thread(target=self.preprocess_thread) + preprocess_thread.start() + + if self.use_main_thread: + self.inference_request_thread() + else: + # start a thread that looks for the predictions + # and call the callbacks + inference_request_thread = threading.Thread( + target=self.inference_request_thread + ) + inference_request_thread.start() diff --git a/inference/core/interfaces/stream/watchdog.py b/inference/core/interfaces/stream/watchdog.py new file mode 100644 index 0000000000000000000000000000000000000000..ee1b0b89260129074646c0527aaab801de8f257a --- /dev/null +++ b/inference/core/interfaces/stream/watchdog.py @@ -0,0 +1,319 @@ +""" +This module contains component intended to use in combination with `InferencePipeline` to ensure +observability. Please consider them internal details of implementation. +""" + +from abc import ABC, abstractmethod +from collections import deque +from datetime import datetime +from typing import Any, Deque, Iterable, List, Optional, TypeVar + +import supervision as sv + +from inference.core.interfaces.camera.entities import StatusUpdate, UpdateSeverity +from inference.core.interfaces.camera.video_source import VideoSource +from inference.core.interfaces.stream.entities import ( + LatencyMonitorReport, + ModelActivityEvent, + PipelineStateReport, +) + +T = TypeVar("T") + +MAX_LATENCY_CONTEXT = 64 +MAX_UPDATES_CONTEXT = 512 + + +class PipelineWatchDog(ABC): + def __init__(self): + pass + + @abstractmethod + def register_video_source(self, video_source: VideoSource) -> None: + pass + + @abstractmethod + def on_status_update(self, status_update: StatusUpdate) -> None: + pass + + @abstractmethod + def on_model_preprocessing_started( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + pass + + @abstractmethod + def on_model_inference_started( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + pass + + @abstractmethod + def on_model_postprocessing_started( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + pass + + @abstractmethod + def on_model_prediction_ready( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + pass + + @abstractmethod + def get_report(self) -> Optional[PipelineStateReport]: + pass + + +class NullPipelineWatchdog(PipelineWatchDog): + def register_video_source(self, video_source: VideoSource) -> None: + pass + + def on_status_update(self, status_update: StatusUpdate) -> None: + pass + + def on_model_preprocessing_started( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + pass + + def on_model_inference_started( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + pass + + def on_model_postprocessing_started( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + pass + + def on_model_prediction_ready( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + pass + + def get_report(self) -> Optional[PipelineStateReport]: + return None + + +class LatencyMonitor: + def __init__(self): + self._preprocessing_start_event: Optional[ModelActivityEvent] = None + self._inference_start_event: Optional[ModelActivityEvent] = None + self._postprocessing_start_event: Optional[ModelActivityEvent] = None + self._prediction_ready_event: Optional[ModelActivityEvent] = None + self._reports: Deque[LatencyMonitorReport] = deque(maxlen=MAX_LATENCY_CONTEXT) + + def register_preprocessing_start( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + self._preprocessing_start_event = ModelActivityEvent( + event_timestamp=datetime.now(), + frame_id=frame_id, + frame_decoding_timestamp=frame_timestamp, + ) + + def register_inference_start( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + self._inference_start_event = ModelActivityEvent( + event_timestamp=datetime.now(), + frame_id=frame_id, + frame_decoding_timestamp=frame_timestamp, + ) + + def register_postprocessing_start( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + self._postprocessing_start_event = ModelActivityEvent( + event_timestamp=datetime.now(), + frame_id=frame_id, + frame_decoding_timestamp=frame_timestamp, + ) + + def register_prediction_ready( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + self._prediction_ready_event = ModelActivityEvent( + event_timestamp=datetime.now(), + frame_id=frame_id, + frame_decoding_timestamp=frame_timestamp, + ) + self._generate_report() + + def summarise_reports(self) -> LatencyMonitorReport: + avg_frame_decoding_latency = average_property_values( + examined_objects=self._reports, property_name="frame_decoding_latency" + ) + avg_pre_processing_latency = average_property_values( + examined_objects=self._reports, property_name="pre_processing_latency" + ) + avg_inference_latency = average_property_values( + examined_objects=self._reports, property_name="inference_latency" + ) + avg_pos_processing_latency = average_property_values( + examined_objects=self._reports, property_name="post_processing_latency" + ) + avg_model_latency = average_property_values( + examined_objects=self._reports, property_name="model_latency" + ) + avg_e2e_latency = average_property_values( + examined_objects=self._reports, property_name="e2e_latency" + ) + return LatencyMonitorReport( + frame_decoding_latency=avg_frame_decoding_latency, + pre_processing_latency=avg_pre_processing_latency, + inference_latency=avg_inference_latency, + post_processing_latency=avg_pos_processing_latency, + model_latency=avg_model_latency, + e2e_latency=avg_e2e_latency, + ) + + def _generate_report(self) -> None: + frame_decoding_latency = None + if self._preprocessing_start_event is not None: + frame_decoding_latency = ( + self._preprocessing_start_event.event_timestamp + - self._preprocessing_start_event.frame_decoding_timestamp + ).total_seconds() + event_pairs = [ + (self._preprocessing_start_event, self._inference_start_event), + (self._inference_start_event, self._postprocessing_start_event), + (self._postprocessing_start_event, self._prediction_ready_event), + (self._preprocessing_start_event, self._prediction_ready_event), + ] + event_pairs_results = [] + for earlier_event, later_event in event_pairs: + latency = compute_events_latency( + earlier_event=earlier_event, + later_event=later_event, + ) + event_pairs_results.append(latency) + ( + pre_processing_latency, + inference_latency, + post_processing_latency, + model_latency, + ) = event_pairs_results + e2e_latency = None + if self._prediction_ready_event is not None: + e2e_latency = ( + self._prediction_ready_event.event_timestamp + - self._prediction_ready_event.frame_decoding_timestamp + ).total_seconds() + self._reports.append( + LatencyMonitorReport( + frame_decoding_latency=frame_decoding_latency, + pre_processing_latency=pre_processing_latency, + inference_latency=inference_latency, + post_processing_latency=post_processing_latency, + model_latency=model_latency, + e2e_latency=e2e_latency, + ) + ) + + +def average_property_values( + examined_objects: Iterable, property_name: str +) -> Optional[float]: + values = get_not_empty_properties( + examined_objects=examined_objects, property_name=property_name + ) + return safe_average(values=values) + + +def get_not_empty_properties( + examined_objects: Iterable, property_name: str +) -> List[Any]: + results = [ + getattr(examined_object, property_name, None) + for examined_object in examined_objects + ] + return [e for e in results if e is not None] + + +def safe_average(values: List[float]) -> Optional[float]: + if len(values) == 0: + return None + return sum(values) / len(values) + + +def compute_events_latency( + earlier_event: Optional[ModelActivityEvent], + later_event: Optional[ModelActivityEvent], +) -> Optional[float]: + if not are_events_compatible(events=[earlier_event, later_event]): + return None + return (later_event.event_timestamp - earlier_event.event_timestamp).total_seconds() + + +def are_events_compatible(events: List[Optional[ModelActivityEvent]]) -> bool: + if any(e is None for e in events): + return False + if len(events) == 0: + return False + frame_ids = [e.frame_id for e in events] + return all(e == frame_ids[0] for e in frame_ids) + + +class BasePipelineWatchDog(PipelineWatchDog): + """ + Implementation to be used from single inference thread, as it keeps + state assumed to represent status of consecutive stage of prediction process + in latency monitor. + """ + + def __init__(self): + super().__init__() + self._video_source: Optional[VideoSource] = None + self._inference_throughput_monitor = sv.FPSMonitor() + self._latency_monitor = LatencyMonitor() + self._stream_updates = deque(maxlen=MAX_UPDATES_CONTEXT) + + def register_video_source(self, video_source: VideoSource) -> None: + self._video_source = video_source + + def on_status_update(self, status_update: StatusUpdate) -> None: + if status_update.severity.value <= UpdateSeverity.DEBUG.value: + return None + self._stream_updates.append(status_update) + + def on_model_preprocessing_started( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + self._latency_monitor.register_preprocessing_start( + frame_timestamp=frame_timestamp, frame_id=frame_id + ) + + def on_model_inference_started( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + self._latency_monitor.register_inference_start( + frame_timestamp=frame_timestamp, frame_id=frame_id + ) + + def on_model_postprocessing_started( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + self._latency_monitor.register_postprocessing_start( + frame_timestamp=frame_timestamp, frame_id=frame_id + ) + + def on_model_prediction_ready( + self, frame_timestamp: datetime, frame_id: int + ) -> None: + self._latency_monitor.register_prediction_ready( + frame_timestamp=frame_timestamp, frame_id=frame_id + ) + self._inference_throughput_monitor.tick() + + def get_report(self) -> PipelineStateReport: + source_metadata = None + if self._video_source is not None: + source_metadata = self._video_source.describe_source() + return PipelineStateReport( + video_source_status_updates=list(self._stream_updates), + latency_report=self._latency_monitor.summarise_reports(), + inference_throughput=self._inference_throughput_monitor(), + source_metadata=source_metadata, + ) diff --git a/inference/core/interfaces/udp/__init__.py b/inference/core/interfaces/udp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/interfaces/udp/__pycache__/__init__.cpython-310.pyc b/inference/core/interfaces/udp/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c613b53e08331661b3fac71cfe73cec234a4900 Binary files /dev/null and b/inference/core/interfaces/udp/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/interfaces/udp/__pycache__/udp_stream.cpython-310.pyc b/inference/core/interfaces/udp/__pycache__/udp_stream.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..123a3a1144d91a2cf22a57932ee25cb8719f55fe Binary files /dev/null and b/inference/core/interfaces/udp/__pycache__/udp_stream.cpython-310.pyc differ diff --git a/inference/core/interfaces/udp/udp_stream.py b/inference/core/interfaces/udp/udp_stream.py new file mode 100644 index 0000000000000000000000000000000000000000..c990bf20978153aa686ee8052af79c00aaa5f3c4 --- /dev/null +++ b/inference/core/interfaces/udp/udp_stream.py @@ -0,0 +1,276 @@ +import json +import socket +import sys +import threading +import time +from typing import Union + +import cv2 +import numpy as np +import supervision as sv +from PIL import Image + +import inference.core.entities.requests.inference +from inference.core.active_learning.middlewares import ( + NullActiveLearningMiddleware, + ThreadingActiveLearningMiddleware, +) +from inference.core.cache import cache +from inference.core.env import ( + ACTIVE_LEARNING_ENABLED, + API_KEY, + API_KEY_ENV_NAMES, + CLASS_AGNOSTIC_NMS, + CONFIDENCE, + ENABLE_BYTE_TRACK, + ENFORCE_FPS, + IOU_THRESHOLD, + IP_BROADCAST_ADDR, + IP_BROADCAST_PORT, + MAX_CANDIDATES, + MAX_DETECTIONS, + MODEL_ID, + STREAM_ID, +) +from inference.core.interfaces.base import BaseInterface +from inference.core.interfaces.camera.camera import WebcamStream +from inference.core.logger import logger +from inference.core.registries.roboflow import get_model_type +from inference.core.version import __version__ +from inference.models.utils import get_roboflow_model + + +class UdpStream(BaseInterface): + """Roboflow defined UDP interface for a general-purpose inference server. + + Attributes: + model_manager (ModelManager): The manager that handles model inference tasks. + model_registry (RoboflowModelRegistry): The registry to fetch model instances. + api_key (str): The API key for accessing models. + class_agnostic_nms (bool): Flag for class-agnostic non-maximum suppression. + confidence (float): Confidence threshold for inference. + ip_broadcast_addr (str): The IP address to broadcast to. + ip_broadcast_port (int): The port to broadcast on. + iou_threshold (float): The intersection-over-union threshold for detection. + max_candidates (float): The maximum number of candidates for detection. + max_detections (float): The maximum number of detections. + model_id (str): The ID of the model to be used. + stream_id (str): The ID of the stream to be used. + use_bytetrack (bool): Flag to use bytetrack, + + Methods: + init_infer: Initialize the inference with a test frame. + preprocess_thread: Preprocess incoming frames for inference. + inference_request_thread: Manage the inference requests. + run_thread: Run the preprocessing and inference threads. + """ + + def __init__( + self, + api_key: str = API_KEY, + class_agnostic_nms: bool = CLASS_AGNOSTIC_NMS, + confidence: float = CONFIDENCE, + enforce_fps: bool = ENFORCE_FPS, + ip_broadcast_addr: str = IP_BROADCAST_ADDR, + ip_broadcast_port: int = IP_BROADCAST_PORT, + iou_threshold: float = IOU_THRESHOLD, + max_candidates: float = MAX_CANDIDATES, + max_detections: float = MAX_DETECTIONS, + model_id: str = MODEL_ID, + stream_id: Union[int, str] = STREAM_ID, + use_bytetrack: bool = ENABLE_BYTE_TRACK, + ): + """Initialize the UDP stream with the given parameters. + Prints the server settings and initializes the inference with a test frame. + """ + logger.info("Initializing server") + + self.frame_count = 0 + self.byte_tracker = sv.ByteTrack() if use_bytetrack else None + self.use_bytetrack = use_bytetrack + + self.stream_id = stream_id + if self.stream_id is None: + raise ValueError("STREAM_ID is not defined") + self.model_id = model_id + if not self.model_id: + raise ValueError("MODEL_ID is not defined") + self.api_key = api_key + if not self.api_key: + raise ValueError( + f"API key is missing. Either pass it explicitly to constructor, or use one of env variables: " + f"{API_KEY_ENV_NAMES}. Visit " + f"https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key to learn how to generate " + f"the key." + ) + + self.model = get_roboflow_model(self.model_id, self.api_key) + self.task_type = get_model_type(model_id=self.model_id, api_key=self.api_key)[0] + self.active_learning_middleware = NullActiveLearningMiddleware() + if ACTIVE_LEARNING_ENABLED: + self.active_learning_middleware = ThreadingActiveLearningMiddleware.init( + api_key=self.api_key, + model_id=self.model_id, + cache=cache, + ) + self.class_agnostic_nms = class_agnostic_nms + self.confidence = confidence + self.iou_threshold = iou_threshold + self.max_candidates = max_candidates + self.max_detections = max_detections + self.ip_broadcast_addr = ip_broadcast_addr + self.ip_broadcast_port = ip_broadcast_port + + self.inference_request_type = ( + inference.core.entities.requests.inference.ObjectDetectionInferenceRequest + ) + + self.UDPServerSocket = socket.socket( + family=socket.AF_INET, type=socket.SOCK_DGRAM + ) + self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) + self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1) + self.UDPServerSocket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536) + + self.webcam_stream = WebcamStream( + stream_id=self.stream_id, enforce_fps=enforce_fps + ) + logger.info( + f"Streaming from device with resolution: {self.webcam_stream.width} x {self.webcam_stream.height}" + ) + + self.init_infer() + self.preproc_result = None + self.inference_request_obj = None + self.queue_control = False + self.inference_response = None + self.stop = False + + self.frame_cv = None + self.frame_id = None + logger.info("Server initialized with settings:") + logger.info(f"Stream ID: {self.stream_id}") + logger.info(f"Model ID: {self.model_id}") + logger.info(f"Confidence: {self.confidence}") + logger.info(f"Class Agnostic NMS: {self.class_agnostic_nms}") + logger.info(f"IOU Threshold: {self.iou_threshold}") + logger.info(f"Max Candidates: {self.max_candidates}") + logger.info(f"Max Detections: {self.max_detections}") + + def init_infer(self): + """Initialize the inference with a test frame. + + Creates a test frame and runs it through the entire inference process to ensure everything is working. + """ + frame = Image.new("RGB", (640, 640), color="black") + self.model.infer( + frame, confidence=self.confidence, iou_threshold=self.iou_threshold + ) + self.active_learning_middleware.start_registration_thread() + + def preprocess_thread(self): + """Preprocess incoming frames for inference. + + Reads frames from the webcam stream, converts them into the proper format, and preprocesses them for + inference. + """ + webcam_stream = self.webcam_stream + webcam_stream.start() + # processing frames in input stream + try: + while True: + if webcam_stream.stopped is True or self.stop: + break + else: + self.frame_cv, frame_id = webcam_stream.read_opencv() + if frame_id != self.frame_id: + self.frame_id = frame_id + self.preproc_result = self.model.preprocess(self.frame_cv) + self.img_in, self.img_dims = self.preproc_result + self.queue_control = True + + except Exception as e: + logger.error(e) + + def inference_request_thread(self): + """Manage the inference requests. + + Processes preprocessed frames for inference, post-processes the predictions, and sends the results + as a UDP broadcast. + """ + last_print = time.perf_counter() + print_ind = 0 + print_chars = ["|", "/", "-", "\\"] + while True: + if self.stop: + break + if self.queue_control: + self.queue_control = False + frame_id = self.frame_id + inference_input = np.copy(self.frame_cv) + predictions = self.model.predict( + self.img_in, + ) + predictions = self.model.postprocess( + predictions, + self.img_dims, + class_agnostic_nms=self.class_agnostic_nms, + confidence=self.confidence, + iou_threshold=self.iou_threshold, + max_candidates=self.max_candidates, + max_detections=self.max_detections, + )[0] + self.active_learning_middleware.register( + inference_input=inference_input, + prediction=predictions.dict(by_alias=True, exclude_none=True), + prediction_type=self.task_type, + ) + if self.use_bytetrack: + detections = sv.Detections.from_roboflow( + predictions.dict(by_alias=True), self.model.class_names + ) + detections = self.byte_tracker.update_with_detections(detections) + for pred, detect in zip(predictions.predictions, detections): + pred.tracker_id = int(detect[4]) + predictions.frame_id = frame_id + predictions = predictions.json(exclude_none=True, by_alias=True) + + self.inference_response = predictions + self.frame_count += 1 + + bytesToSend = predictions.encode("utf-8") + self.UDPServerSocket.sendto( + bytesToSend, + ( + self.ip_broadcast_addr, + self.ip_broadcast_port, + ), + ) + if time.perf_counter() - last_print > 1: + print(f"Streaming {print_chars[print_ind]}", end="\r") + print_ind = (print_ind + 1) % 4 + last_print = time.perf_counter() + + def run_thread(self): + """Run the preprocessing and inference threads. + + Starts the preprocessing and inference threads, and handles graceful shutdown on KeyboardInterrupt. + """ + preprocess_thread = threading.Thread(target=self.preprocess_thread) + inference_request_thread = threading.Thread( + target=self.inference_request_thread + ) + + preprocess_thread.start() + inference_request_thread.start() + + while True: + try: + time.sleep(10) + except KeyboardInterrupt: + logger.info("Stopping server...") + self.stop = True + self.active_learning_middleware.stop_registration_thread() + time.sleep(3) + sys.exit(0) diff --git a/inference/core/logger.py b/inference/core/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..4e885453122ac99ad3a4024ca541b9bddc07668e --- /dev/null +++ b/inference/core/logger.py @@ -0,0 +1,14 @@ +import logging +import warnings + +from rich.logging import RichHandler + +from inference.core.env import LOG_LEVEL + +logger = logging.getLogger("inference") +logger.setLevel(LOG_LEVEL) +logger.addHandler(RichHandler()) +logger.propagate = False + +if LOG_LEVEL == "ERROR" or LOG_LEVEL == "FATAL": + warnings.filterwarnings("ignore", category=UserWarning, module="onnxruntime.*") diff --git a/inference/core/managers/__init__.py b/inference/core/managers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/inference/core/managers/__init__.py @@ -0,0 +1 @@ + diff --git a/inference/core/managers/__pycache__/__init__.cpython-310.pyc b/inference/core/managers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..607ecd1fbf76756af027a9f94ab2d2d88405e549 Binary files /dev/null and b/inference/core/managers/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/managers/__pycache__/active_learning.cpython-310.pyc b/inference/core/managers/__pycache__/active_learning.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd0f07d17342b6acb2842fd4002cf219e624d190 Binary files /dev/null and b/inference/core/managers/__pycache__/active_learning.cpython-310.pyc differ diff --git a/inference/core/managers/__pycache__/base.cpython-310.pyc b/inference/core/managers/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ecd13e40f955b4f11f922f5a44ac2fea890a8fb Binary files /dev/null and b/inference/core/managers/__pycache__/base.cpython-310.pyc differ diff --git a/inference/core/managers/__pycache__/entities.cpython-310.pyc b/inference/core/managers/__pycache__/entities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b865f9566609623ba243298a87ccfd67d6b64177 Binary files /dev/null and b/inference/core/managers/__pycache__/entities.cpython-310.pyc differ diff --git a/inference/core/managers/__pycache__/metrics.cpython-310.pyc b/inference/core/managers/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ccf54b6219e1e84db3b72207d0c2491caf88f5f Binary files /dev/null and b/inference/core/managers/__pycache__/metrics.cpython-310.pyc differ diff --git a/inference/core/managers/__pycache__/pingback.cpython-310.pyc b/inference/core/managers/__pycache__/pingback.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59443488e3a04aa81bf099623fd8e60a459ccc1b Binary files /dev/null and b/inference/core/managers/__pycache__/pingback.cpython-310.pyc differ diff --git a/inference/core/managers/__pycache__/stub_loader.cpython-310.pyc b/inference/core/managers/__pycache__/stub_loader.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0768f0c4cfcb2a31d9780887ed94bdee29d9f855 Binary files /dev/null and b/inference/core/managers/__pycache__/stub_loader.cpython-310.pyc differ diff --git a/inference/core/managers/active_learning.py b/inference/core/managers/active_learning.py new file mode 100644 index 0000000000000000000000000000000000000000..09e405c8a68c0c709f6b87a7293f0dc88fd353b0 --- /dev/null +++ b/inference/core/managers/active_learning.py @@ -0,0 +1,144 @@ +import time +from typing import Dict, Optional + +from fastapi import BackgroundTasks + +from inference.core import logger +from inference.core.active_learning.middlewares import ActiveLearningMiddleware +from inference.core.cache.base import BaseCache +from inference.core.entities.requests.inference import InferenceRequest +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT +from inference.core.managers.base import ModelManager +from inference.core.registries.base import ModelRegistry + +ACTIVE_LEARNING_ELIGIBLE_PARAM = "active_learning_eligible" +DISABLE_ACTIVE_LEARNING_PARAM = "disable_active_learning" +BACKGROUND_TASKS_PARAM = "background_tasks" + + +class ActiveLearningManager(ModelManager): + def __init__( + self, + model_registry: ModelRegistry, + cache: BaseCache, + middlewares: Optional[Dict[str, ActiveLearningMiddleware]] = None, + ): + super().__init__(model_registry=model_registry) + self._cache = cache + self._middlewares = middlewares if middlewares is not None else {} + + async def infer_from_request( + self, model_id: str, request: InferenceRequest, **kwargs + ) -> InferenceResponse: + prediction = await super().infer_from_request( + model_id=model_id, request=request, **kwargs + ) + active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False) + active_learning_disabled_for_request = getattr( + request, DISABLE_ACTIVE_LEARNING_PARAM, False + ) + if ( + not active_learning_eligible + or active_learning_disabled_for_request + or request.api_key is None + ): + return prediction + self.register(prediction=prediction, model_id=model_id, request=request) + return prediction + + def register( + self, prediction: InferenceResponse, model_id: str, request: InferenceRequest + ) -> None: + try: + self.ensure_middleware_initialised(model_id=model_id, request=request) + self.register_datapoint( + prediction=prediction, + model_id=model_id, + request=request, + ) + except Exception as error: + # Error handling to be decided + logger.warning( + f"Error in datapoint registration for Active Learning. Details: {error}. " + f"Error is suppressed in favour of normal operations of API." + ) + + def ensure_middleware_initialised( + self, model_id: str, request: InferenceRequest + ) -> None: + if model_id in self._middlewares: + return None + start = time.perf_counter() + logger.debug(f"Initialising AL middleware for {model_id}") + self._middlewares[model_id] = ActiveLearningMiddleware.init( + api_key=request.api_key, + model_id=model_id, + cache=self._cache, + ) + end = time.perf_counter() + logger.debug(f"Middleware init latency: {(end - start) * 1000} ms") + + def register_datapoint( + self, prediction: InferenceResponse, model_id: str, request: InferenceRequest + ) -> None: + start = time.perf_counter() + inference_inputs = getattr(request, "image", None) + if inference_inputs is None: + logger.warning( + "Could not register datapoint, as inference input has no `image` field." + ) + return None + if not issubclass(type(inference_inputs), list): + inference_inputs = [inference_inputs] + if not issubclass(type(prediction), list): + results_dicts = [prediction.dict(by_alias=True, exclude={"visualization"})] + else: + results_dicts = [ + e.dict(by_alias=True, exclude={"visualization"}) for e in prediction + ] + prediction_type = self.get_task_type(model_id=model_id) + disable_preproc_auto_orient = ( + getattr(request, "disable_preproc_auto_orient", False) + or DISABLE_PREPROC_AUTO_ORIENT + ) + self._middlewares[model_id].register_batch( + inference_inputs=inference_inputs, + predictions=results_dicts, + prediction_type=prediction_type, + disable_preproc_auto_orient=disable_preproc_auto_orient, + ) + end = time.perf_counter() + logger.debug(f"Registration: {(end - start) * 1000} ms") + + +class BackgroundTaskActiveLearningManager(ActiveLearningManager): + async def infer_from_request( + self, model_id: str, request: InferenceRequest, **kwargs + ) -> InferenceResponse: + active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False) + active_learning_disabled_for_request = getattr( + request, DISABLE_ACTIVE_LEARNING_PARAM, False + ) + kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False # disabling AL in super-classes + prediction = await super().infer_from_request( + model_id=model_id, request=request, **kwargs + ) + if ( + not active_learning_eligible + or active_learning_disabled_for_request + or request.api_key is None + ): + return prediction + if BACKGROUND_TASKS_PARAM not in kwargs: + logger.warning( + "BackgroundTaskActiveLearningManager used against rules - `background_tasks` argument not " + "provided making Active Learning registration running sequentially." + ) + self.register(prediction=prediction, model_id=model_id, request=request) + else: + background_tasks: BackgroundTasks = kwargs["background_tasks"] + background_tasks.add_task( + self.register, prediction=prediction, model_id=model_id, request=request + ) + return prediction diff --git a/inference/core/managers/base.py b/inference/core/managers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e0e1f8d7c3c3be9879882f1fa3b69e48a51a85 --- /dev/null +++ b/inference/core/managers/base.py @@ -0,0 +1,325 @@ +import time +from typing import Dict, List, Optional, Tuple + +import numpy as np +from fastapi.encoders import jsonable_encoder + +from inference.core.cache import cache +from inference.core.cache.serializers import to_cachable_inference_item +from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID +from inference.core.entities.requests.inference import InferenceRequest +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.env import ( + DISABLE_INFERENCE_CACHE, + METRICS_ENABLED, + METRICS_INTERVAL, + ROBOFLOW_SERVER_UUID, +) +from inference.core.exceptions import InferenceModelNotFound +from inference.core.logger import logger +from inference.core.managers.entities import ModelDescription +from inference.core.managers.pingback import PingbackInfo +from inference.core.models.base import Model, PreprocessReturnMetadata +from inference.core.registries.base import ModelRegistry + + +class ModelManager: + """Model managers keep track of a dictionary of Model objects and is responsible for passing requests to the right model using the infer method.""" + + def __init__(self, model_registry: ModelRegistry, models: Optional[dict] = None): + self.model_registry = model_registry + self._models: Dict[str, Model] = models if models is not None else {} + + def init_pingback(self): + """Initializes pingback mechanism.""" + self.num_errors = 0 # in the device + self.uuid = ROBOFLOW_SERVER_UUID + if METRICS_ENABLED: + self.pingback = PingbackInfo(self) + self.pingback.start() + + def add_model( + self, model_id: str, api_key: str, model_id_alias: Optional[str] = None + ) -> None: + """Adds a new model to the manager. + + Args: + model_id (str): The identifier of the model. + model (Model): The model instance. + """ + logger.debug( + f"ModelManager - Adding model with model_id={model_id}, model_id_alias={model_id_alias}" + ) + if model_id in self._models: + logger.debug( + f"ModelManager - model with model_id={model_id} is already loaded." + ) + return + logger.debug("ModelManager - model initialisation...") + model = self.model_registry.get_model( + model_id if model_id_alias is None else model_id_alias, api_key + )( + model_id=model_id, + api_key=api_key, + ) + logger.debug("ModelManager - model successfully loaded.") + self._models[model_id if model_id_alias is None else model_id_alias] = model + + def check_for_model(self, model_id: str) -> None: + """Checks whether the model with the given ID is in the manager. + + Args: + model_id (str): The identifier of the model. + + Raises: + InferenceModelNotFound: If the model is not found in the manager. + """ + if model_id not in self: + raise InferenceModelNotFound(f"Model with id {model_id} not loaded.") + + async def infer_from_request( + self, model_id: str, request: InferenceRequest, **kwargs + ) -> InferenceResponse: + """Runs inference on the specified model with the given request. + + Args: + model_id (str): The identifier of the model. + request (InferenceRequest): The request to process. + + Returns: + InferenceResponse: The response from the inference. + """ + logger.debug( + f"ModelManager - inference from request started for model_id={model_id}." + ) + try: + rtn_val = await self.model_infer( + model_id=model_id, request=request, **kwargs + ) + logger.debug( + f"ModelManager - inference from request finished for model_id={model_id}." + ) + finish_time = time.time() + if not DISABLE_INFERENCE_CACHE: + logger.debug( + f"ModelManager - caching inference request started for model_id={model_id}" + ) + cache.zadd( + f"models", + value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}", + score=finish_time, + expire=METRICS_INTERVAL * 2, + ) + if ( + hasattr(request, "image") + and hasattr(request.image, "type") + and request.image.type == "numpy" + ): + request.image.value = str(request.image.value) + cache.zadd( + f"inference:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}", + value=to_cachable_inference_item(request, rtn_val), + score=finish_time, + expire=METRICS_INTERVAL * 2, + ) + logger.debug( + f"ModelManager - caching inference request finished for model_id={model_id}" + ) + return rtn_val + except Exception as e: + finish_time = time.time() + if not DISABLE_INFERENCE_CACHE: + cache.zadd( + f"models", + value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}", + score=finish_time, + expire=METRICS_INTERVAL * 2, + ) + cache.zadd( + f"error:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}", + value={ + "request": jsonable_encoder( + request.dict(exclude={"image", "subject", "prompt"}) + ), + "error": str(e), + }, + score=finish_time, + expire=METRICS_INTERVAL * 2, + ) + raise + + async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs): + self.check_for_model(model_id) + return self._models[model_id].infer_from_request(request) + + def make_response( + self, model_id: str, predictions: List[List[float]], *args, **kwargs + ) -> InferenceResponse: + """Creates a response object from the model's predictions. + + Args: + model_id (str): The identifier of the model. + predictions (List[List[float]]): The model's predictions. + + Returns: + InferenceResponse: The created response object. + """ + self.check_for_model(model_id) + return self._models[model_id].make_response(predictions, *args, **kwargs) + + def postprocess( + self, + model_id: str, + predictions: Tuple[np.ndarray, ...], + preprocess_return_metadata: PreprocessReturnMetadata, + *args, + **kwargs, + ) -> List[List[float]]: + """Processes the model's predictions after inference. + + Args: + model_id (str): The identifier of the model. + predictions (np.ndarray): The model's predictions. + + Returns: + List[List[float]]: The post-processed predictions. + """ + self.check_for_model(model_id) + return self._models[model_id].postprocess( + predictions, preprocess_return_metadata, *args, **kwargs + ) + + def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]: + """Runs prediction on the specified model. + + Args: + model_id (str): The identifier of the model. + + Returns: + np.ndarray: The predictions from the model. + """ + self.check_for_model(model_id) + self._models[model_id].metrics["num_inferences"] += 1 + tic = time.perf_counter() + res = self._models[model_id].predict(*args, **kwargs) + toc = time.perf_counter() + self._models[model_id].metrics["avg_inference_time"] += toc - tic + return res + + def preprocess( + self, model_id: str, request: InferenceRequest + ) -> Tuple[np.ndarray, PreprocessReturnMetadata]: + """Preprocesses the request before inference. + + Args: + model_id (str): The identifier of the model. + request (InferenceRequest): The request to preprocess. + + Returns: + Tuple[np.ndarray, List[Tuple[int, int]]]: The preprocessed data. + """ + self.check_for_model(model_id) + return self._models[model_id].preprocess(**request.dict()) + + def get_class_names(self, model_id): + """Retrieves the class names for a given model. + + Args: + model_id (str): The identifier of the model. + + Returns: + List[str]: The class names of the model. + """ + self.check_for_model(model_id) + return self._models[model_id].class_names + + def get_task_type(self, model_id: str, api_key: str = None) -> str: + """Retrieves the task type for a given model. + + Args: + model_id (str): The identifier of the model. + + Returns: + str: The task type of the model. + """ + self.check_for_model(model_id) + return self._models[model_id].task_type + + def remove(self, model_id: str) -> None: + """Removes a model from the manager. + + Args: + model_id (str): The identifier of the model. + """ + try: + self.check_for_model(model_id) + self._models[model_id].clear_cache() + del self._models[model_id] + except InferenceModelNotFound: + logger.warning( + f"Attempted to remove model with id {model_id}, but it is not loaded. Skipping..." + ) + + def clear(self) -> None: + """Removes all models from the manager.""" + for model_id in list(self.keys()): + self.remove(model_id) + + def __contains__(self, model_id: str) -> bool: + """Checks if the model is contained in the manager. + + Args: + model_id (str): The identifier of the model. + + Returns: + bool: Whether the model is in the manager. + """ + return model_id in self._models + + def __getitem__(self, key: str) -> Model: + """Retrieve a model from the manager by key. + + Args: + key (str): The identifier of the model. + + Returns: + Model: The model corresponding to the key. + """ + self.check_for_model(model_id=key) + return self._models[key] + + def __len__(self) -> int: + """Retrieve the number of models in the manager. + + Returns: + int: The number of models in the manager. + """ + return len(self._models) + + def keys(self): + """Retrieve the keys (model identifiers) from the manager. + + Returns: + List[str]: The keys of the models in the manager. + """ + return self._models.keys() + + def models(self) -> Dict[str, Model]: + """Retrieve the models dictionary from the manager. + + Returns: + Dict[str, Model]: The keys of the models in the manager. + """ + return self._models + + def describe_models(self) -> List[ModelDescription]: + return [ + ModelDescription( + model_id=model_id, + task_type=model.task_type, + batch_size=getattr(model, "batch_size", None), + input_width=getattr(model, "img_size_w", None), + input_height=getattr(model, "img_size_h", None), + ) + for model_id, model in self._models.items() + ] diff --git a/inference/core/managers/decorators/__init__.py b/inference/core/managers/decorators/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/managers/decorators/__pycache__/__init__.cpython-310.pyc b/inference/core/managers/decorators/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28b3c289b28c60e484d98b719ffca48fb45ffd29 Binary files /dev/null and b/inference/core/managers/decorators/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/managers/decorators/__pycache__/base.cpython-310.pyc b/inference/core/managers/decorators/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46888597eb3963bfe4339bf5de9ba6e996f38dd9 Binary files /dev/null and b/inference/core/managers/decorators/__pycache__/base.cpython-310.pyc differ diff --git a/inference/core/managers/decorators/__pycache__/fixed_size_cache.cpython-310.pyc b/inference/core/managers/decorators/__pycache__/fixed_size_cache.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c7efbd09ca6a4cdb72f97613cc67c62cc84f7fb Binary files /dev/null and b/inference/core/managers/decorators/__pycache__/fixed_size_cache.cpython-310.pyc differ diff --git a/inference/core/managers/decorators/__pycache__/locked_load.cpython-310.pyc b/inference/core/managers/decorators/__pycache__/locked_load.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4683435300c5ba88a5e70dd0a15b15cf987a7362 Binary files /dev/null and b/inference/core/managers/decorators/__pycache__/locked_load.cpython-310.pyc differ diff --git a/inference/core/managers/decorators/__pycache__/logger.cpython-310.pyc b/inference/core/managers/decorators/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32b775ed0e19f84243c0431001c5c151dc9ba322 Binary files /dev/null and b/inference/core/managers/decorators/__pycache__/logger.cpython-310.pyc differ diff --git a/inference/core/managers/decorators/base.py b/inference/core/managers/decorators/base.py new file mode 100644 index 0000000000000000000000000000000000000000..75b80e2d84fbcbe072d2b8adefd77ec25b1a9ebd --- /dev/null +++ b/inference/core/managers/decorators/base.py @@ -0,0 +1,191 @@ +from typing import List, Optional, Tuple + +import numpy as np + +from inference.core.entities.requests.inference import InferenceRequest +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.env import API_KEY +from inference.core.managers.base import Model, ModelManager +from inference.core.models.types import PreprocessReturnMetadata + + +class ModelManagerDecorator(ModelManager): + """Basic decorator, it acts like a `ModelManager` and contains a `ModelManager`. + + Args: + model_manager (ModelManager): Instance of a ModelManager. + + Methods: + add_model: Adds a model to the manager. + infer: Processes a complete inference request. + infer_only: Performs only the inference part of a request. + preprocess: Processes the preprocessing part of a request. + get_task_type: Gets the task type associated with a model. + get_class_names: Gets the class names for a given model. + remove: Removes a model from the manager. + __len__: Returns the number of models in the manager. + __getitem__: Retrieves a model by its ID. + __contains__: Checks if a model exists in the manager. + keys: Returns the keys (model IDs) from the manager. + """ + + @property + def _models(self): + raise ValueError("Should only be accessing self.model_manager._models") + + @property + def model_registry(self): + raise ValueError("Should only be accessing self.model_manager.model_registry") + + def __init__(self, model_manager: ModelManager): + """Initializes the decorator with an instance of a ModelManager.""" + self.model_manager = model_manager + + def add_model( + self, model_id: str, api_key: str, model_id_alias: Optional[str] = None + ): + """Adds a model to the manager. + + Args: + model_id (str): The identifier of the model. + model (Model): The model instance. + """ + if model_id in self: + return + self.model_manager.add_model(model_id, api_key, model_id_alias=model_id_alias) + + async def infer_from_request( + self, model_id: str, request: InferenceRequest, **kwargs + ) -> InferenceResponse: + """Processes a complete inference request. + + Args: + model_id (str): The identifier of the model. + request (InferenceRequest): The request to process. + + Returns: + InferenceResponse: The response from the inference. + """ + return await self.model_manager.infer_from_request(model_id, request, **kwargs) + + def infer_only(self, model_id: str, request, img_in, img_dims, batch_size=None): + """Performs only the inference part of a request. + + Args: + model_id (str): The identifier of the model. + request: The request to process. + img_in: Input image. + img_dims: Image dimensions. + batch_size (int, optional): Batch size. + + Returns: + Response from the inference-only operation. + """ + return self.model_manager.infer_only( + model_id, request, img_in, img_dims, batch_size + ) + + def preprocess(self, model_id: str, request: InferenceRequest): + """Processes the preprocessing part of a request. + + Args: + model_id (str): The identifier of the model. + request (InferenceRequest): The request to preprocess. + """ + return self.model_manager.preprocess(model_id, request) + + def get_task_type(self, model_id: str, api_key: str = None) -> str: + """Gets the task type associated with a model. + + Args: + model_id (str): The identifier of the model. + + Returns: + str: The task type. + """ + if api_key is None: + api_key = API_KEY + return self.model_manager.get_task_type(model_id, api_key=api_key) + + def get_class_names(self, model_id): + """Gets the class names for a given model. + + Args: + model_id: The identifier of the model. + + Returns: + List of class names. + """ + return self.model_manager.get_class_names(model_id) + + def remove(self, model_id: str) -> Model: + """Removes a model from the manager. + + Args: + model_id (str): The identifier of the model. + + Returns: + Model: The removed model. + """ + return self.model_manager.remove(model_id) + + def __len__(self) -> int: + """Returns the number of models in the manager. + + Returns: + int: Number of models. + """ + return len(self.model_manager) + + def __getitem__(self, key: str) -> Model: + """Retrieves a model by its ID. + + Args: + key (str): The identifier of the model. + + Returns: + Model: The model instance. + """ + return self.model_manager[key] + + def __contains__(self, model_id: str): + """Checks if a model exists in the manager. + + Args: + model_id (str): The identifier of the model. + + Returns: + bool: True if the model exists, False otherwise. + """ + return model_id in self.model_manager + + def keys(self): + """Returns the keys (model IDs) from the manager. + + Returns: + List of keys (model IDs). + """ + return self.model_manager.keys() + + def models(self): + return self.model_manager.models() + + def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]: + return self.model_manager.predict(model_id, *args, **kwargs) + + def postprocess( + self, + model_id: str, + predictions: Tuple[np.ndarray, ...], + preprocess_return_metadata: PreprocessReturnMetadata, + *args, + **kwargs + ) -> List[List[float]]: + return self.model_manager.postprocess( + model_id, predictions, preprocess_return_metadata, *args, **kwargs + ) + + def make_response( + self, model_id: str, predictions: List[List[float]], *args, **kwargs + ) -> InferenceResponse: + return self.model_manager.make_response(model_id, predictions, *args, **kwargs) diff --git a/inference/core/managers/decorators/fixed_size_cache.py b/inference/core/managers/decorators/fixed_size_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..113a7ca813e987991158269775b36e463ae6616c --- /dev/null +++ b/inference/core/managers/decorators/fixed_size_cache.py @@ -0,0 +1,114 @@ +from collections import deque +from typing import List, Optional + +from inference.core.entities.requests.inference import InferenceRequest +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.managers.base import Model, ModelManager +from inference.core.managers.decorators.base import ModelManagerDecorator +from inference.core.managers.entities import ModelDescription + + +class WithFixedSizeCache(ModelManagerDecorator): + def __init__(self, model_manager: ModelManager, max_size: int = 8): + """Cache decorator, models will be evicted based on the last utilization (`.infer` call). Internally, a [double-ended queue](https://docs.python.org/3/library/collections.html#collections.deque) is used to keep track of model utilization. + + Args: + model_manager (ModelManager): Instance of a ModelManager. + max_size (int, optional): Max number of models at the same time. Defaults to 8. + """ + super().__init__(model_manager) + self.max_size = max_size + self._key_queue = deque(self.model_manager.keys()) + + def add_model( + self, model_id: str, api_key: str, model_id_alias: Optional[str] = None + ): + """Adds a model to the manager and evicts the least recently used if the cache is full. + + Args: + model_id (str): The identifier of the model. + model (Model): The model instance. + """ + queue_id = self._resolve_queue_id( + model_id=model_id, model_id_alias=model_id_alias + ) + if model_id in self: + self._key_queue.remove(queue_id) + self._key_queue.append(queue_id) + return + + should_pop = len(self) == self.max_size + if should_pop: + to_remove_model_id = self._key_queue.popleft() + self.remove(to_remove_model_id) + + self._key_queue.append(queue_id) + try: + return super().add_model(model_id, api_key, model_id_alias=model_id_alias) + except Exception as error: + self._key_queue.remove(model_id) + raise error + + def clear(self) -> None: + """Removes all models from the manager.""" + for model_id in list(self.keys()): + self.remove(model_id) + + def remove(self, model_id: str) -> Model: + try: + self._key_queue.remove(model_id) + except ValueError: + pass + return super().remove(model_id) + + async def infer_from_request( + self, model_id: str, request: InferenceRequest, **kwargs + ) -> InferenceResponse: + """Processes a complete inference request and updates the cache. + + Args: + model_id (str): The identifier of the model. + request (InferenceRequest): The request to process. + + Returns: + InferenceResponse: The response from the inference. + """ + self._key_queue.remove(model_id) + self._key_queue.append(model_id) + return await super().infer_from_request(model_id, request, **kwargs) + + def infer_only(self, model_id: str, request, img_in, img_dims, batch_size=None): + """Performs only the inference part of a request and updates the cache. + + Args: + model_id (str): The identifier of the model. + request: The request to process. + img_in: Input image. + img_dims: Image dimensions. + batch_size (int, optional): Batch size. + + Returns: + Response from the inference-only operation. + """ + self._key_queue.remove(model_id) + self._key_queue.append(model_id) + return super().infer_only(model_id, request, img_in, img_dims, batch_size) + + def preprocess(self, model_id: str, request): + """Processes the preprocessing part of a request and updates the cache. + + Args: + model_id (str): The identifier of the model. + request (InferenceRequest): The request to preprocess. + """ + self._key_queue.remove(model_id) + self._key_queue.append(model_id) + return super().preprocess(model_id, request) + + def describe_models(self) -> List[ModelDescription]: + return self.model_manager.describe_models() + + def _resolve_queue_id( + self, model_id: str, model_id_alias: Optional[str] = None + ) -> str: + return model_id if model_id_alias is None else model_id_alias diff --git a/inference/core/managers/decorators/locked_load.py b/inference/core/managers/decorators/locked_load.py new file mode 100644 index 0000000000000000000000000000000000000000..326ffb88828e59c57562bc8d1811aae5a07e6d82 --- /dev/null +++ b/inference/core/managers/decorators/locked_load.py @@ -0,0 +1,12 @@ +from inference.core.cache import cache +from inference.core.managers.decorators.base import ModelManagerDecorator + +lock_str = lambda z: f"locks:model-load:{z}" + + +class LockedLoadModelManagerDecorator(ModelManagerDecorator): + """Must acquire lock to load model""" + + def add_model(self, model_id: str, api_key: str, model_id_alias=None): + with cache.lock(lock_str(model_id), expire=180.0): + return super().add_model(model_id, api_key, model_id_alias=model_id_alias) diff --git a/inference/core/managers/decorators/logger.py b/inference/core/managers/decorators/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..4bb4245c0ee3ea1c39b8fbc09fe78b95a22a7c13 --- /dev/null +++ b/inference/core/managers/decorators/logger.py @@ -0,0 +1,56 @@ +from typing import Optional + +from inference.core.entities.requests.inference import InferenceRequest +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.logger import logger +from inference.core.managers.base import Model +from inference.core.managers.decorators.base import ModelManagerDecorator + + +class WithLogger(ModelManagerDecorator): + """Logger Decorator, it logs what's going on inside the manager.""" + + def add_model( + self, model_id: str, api_key: str, model_id_alias: Optional[str] = None + ): + """Adds a model to the manager and logs the action. + + Args: + model_id (str): The identifier of the model. + model (Model): The model instance. + + Returns: + The result of the add_model method from the superclass. + """ + logger.info(f"🤖 {model_id} added.") + return super().add_model(model_id, api_key, model_id_alias=model_id_alias) + + async def infer_from_request( + self, model_id: str, request: InferenceRequest, **kwargs + ) -> InferenceResponse: + """Processes a complete inference request and logs both the request and response. + + Args: + model_id (str): The identifier of the model. + request (InferenceRequest): The request to process. + + Returns: + InferenceResponse: The response from the inference. + """ + logger.info(f"📥 [{model_id}] request={request}.") + res = await super().infer_from_request(model_id, request, **kwargs) + logger.info(f"📥 [{model_id}] res={res}.") + return res + + def remove(self, model_id: str) -> Model: + """Removes a model from the manager and logs the action. + + Args: + model_id (str): The identifier of the model to remove. + + Returns: + Model: The removed model. + """ + res = super().remove(model_id) + logger.info(f"❌ removed {model_id}") + return res diff --git a/inference/core/managers/entities.py b/inference/core/managers/entities.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0f9e30a17ff9ef7468ad34fbd3f759f9b977f5 --- /dev/null +++ b/inference/core/managers/entities.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass(frozen=True) +class ModelDescription: + model_id: str + task_type: str + batch_size: Optional[int] + input_height: Optional[int] + input_width: Optional[int] diff --git a/inference/core/managers/metrics.py b/inference/core/managers/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..01af47c60cc828bef88517705bb884eb0aeab453 --- /dev/null +++ b/inference/core/managers/metrics.py @@ -0,0 +1,101 @@ +import platform +import re +import socket +import time +import uuid + +from inference.core.cache import cache +from inference.core.logger import logger +from inference.core.version import __version__ + + +def get_model_metrics( + inference_server_id: str, model_id: str, min: float = -1, max: float = float("inf") +) -> dict: + """ + Gets the metrics for a given model between a specified time range. + + Args: + device_id (str): The identifier of the device. + model_id (str): The identifier of the model. + start (float, optional): The starting timestamp of the time range. Defaults to -1. + stop (float, optional): The ending timestamp of the time range. Defaults to float("inf"). + + Returns: + dict: A dictionary containing the metrics of the model: + - num_inferences (int): The number of inferences made. + - avg_inference_time (float): The average inference time. + - num_errors (int): The number of errors occurred. + """ + now = time.time() + inferences_with_times = cache.zrangebyscore( + f"inference:{inference_server_id}:{model_id}", min=min, max=max, withscores=True + ) + num_inferences = len(inferences_with_times) + inference_times = [] + for inference, t in inferences_with_times: + response = inference["response"] + if isinstance(response, list): + times = [r["time"] for r in response if "time" in r] + inference_times.extend(times) + else: + if "time" in response: + inference_times.append(response["time"]) + avg_inference_time = ( + sum(inference_times) / len(inference_times) if len(inference_times) > 0 else 0 + ) + errors_with_times = cache.zrangebyscore( + f"error:{inference_server_id}:{model_id}", min=min, max=max, withscores=True + ) + num_errors = len(errors_with_times) + return { + "num_inferences": num_inferences, + "avg_inference_time": avg_inference_time, + "num_errors": num_errors, + } + + +def get_system_info() -> dict: + """Collects system information such as platform, architecture, hostname, IP address, MAC address, and processor details. + + Returns: + dict: A dictionary containing detailed system information. + """ + info = {} + try: + info["platform"] = platform.system() + info["platform_release"] = platform.release() + info["platform_version"] = platform.version() + info["architecture"] = platform.machine() + info["hostname"] = socket.gethostname() + info["ip_address"] = socket.gethostbyname(socket.gethostname()) + info["mac_address"] = ":".join(re.findall("..", "%012x" % uuid.getnode())) + info["processor"] = platform.processor() + return info + except Exception as e: + logger.exception(e) + finally: + return info + + +def get_inference_results_for_model( + inference_server_id: str, model_id: str, min: float = -1, max: float = float("inf") +): + inferences_with_times = cache.zrangebyscore( + f"inference:{inference_server_id}:{model_id}", min=min, max=max, withscores=True + ) + inference_results = [] + for result, score in inferences_with_times: + # Don't send large image files + if result.get("request", {}).get("image"): + del result["request"]["image"] + responses = result.get("response") + if responses: + if not isinstance(responses, list): + responses = [responses] + for resp in responses: + if resp.get("image"): + del resp["image"] + inference_results.append({"request_time": score, "inference": result}) + + return inference_results diff --git a/inference/core/managers/pingback.py b/inference/core/managers/pingback.py new file mode 100644 index 0000000000000000000000000000000000000000..cbe939294d9721ef8b52775b73fa6f26e12dafec --- /dev/null +++ b/inference/core/managers/pingback.py @@ -0,0 +1,135 @@ +import time +import traceback + +import requests +from apscheduler.schedulers.background import BackgroundScheduler + +from inference.core.devices.utils import GLOBAL_DEVICE_ID, GLOBAL_INFERENCE_SERVER_ID +from inference.core.env import ( + API_KEY, + METRICS_ENABLED, + METRICS_INTERVAL, + METRICS_URL, + TAGS, +) +from inference.core.logger import logger +from inference.core.managers.metrics import ( + get_inference_results_for_model, + get_system_info, +) +from inference.core.utils.requests import api_key_safe_raise_for_status +from inference.core.utils.url_utils import wrap_url +from inference.core.version import __version__ + + +class PingbackInfo: + """Class responsible for managing pingback information for Roboflow. + + This class initializes a scheduler to periodically post data to Roboflow, containing information about the models, + container, and device. + + Attributes: + scheduler (BackgroundScheduler): A scheduler for running jobs in the background. + model_manager (ModelManager): Reference to the model manager object. + process_startup_time (str): Unix timestamp indicating when the process started. + METRICS_URL (str): URL to send the pingback data to. + system_info (dict): Information about the system. + window_start_timestamp (str): Unix timestamp indicating the start of the current window. + """ + + def __init__(self, manager): + """Initializes PingbackInfo with the given manager. + + Args: + manager (ModelManager): Reference to the model manager object. + """ + try: + self.scheduler = BackgroundScheduler() + self.model_manager = manager + self.process_startup_time = str(int(time.time())) + logger.debug( + "UUID: " + self.model_manager.uuid + ) # To correlate with UI container view + self.window_start_timestamp = str(int(time.time())) + context = { + "api_key": API_KEY, + "timestamp": str(int(time.time())), + "device_id": GLOBAL_DEVICE_ID, + "inference_server_id": GLOBAL_INFERENCE_SERVER_ID, + "inference_server_version": __version__, + "tags": TAGS, + } + self.environment_info = context | get_system_info() + except Exception as e: + logger.debug( + "Error sending pingback to Roboflow, if you want to disable this feature unset the ROBOFLOW_ENABLED environment variable. " + + str(e) + ) + + def start(self): + """Starts the scheduler to periodically post data to Roboflow. + + If METRICS_ENABLED is False, a warning is logged, and the method returns without starting the scheduler. + """ + if METRICS_ENABLED == False: + logger.warning( + "Metrics reporting to Roboflow is disabled; not sending back stats to Roboflow." + ) + return + try: + self.scheduler.add_job( + self.post_data, + "interval", + seconds=METRICS_INTERVAL, + args=[self.model_manager], + ) + self.scheduler.start() + except Exception as e: + logger.debug(e) + + def stop(self): + """Stops the scheduler.""" + self.scheduler.shutdown() + + def post_data(self, model_manager): + """Posts data to Roboflow about the models, container, device, and other relevant metrics. + + Args: + model_manager (ModelManager): Reference to the model manager object. + + The data is collected and reset for the next window, and a POST request is made to the pingback URL. + """ + all_data = self.environment_info.copy() + all_data["inference_results"] = [] + try: + now = time.time() + start = now - METRICS_INTERVAL + for model_id in model_manager.models(): + results = get_inference_results_for_model( + GLOBAL_INFERENCE_SERVER_ID, model_id, min=start, max=now + ) + all_data["inference_results"] = all_data["inference_results"] + results + res = requests.post(wrap_url(METRICS_URL), json=all_data) + try: + api_key_safe_raise_for_status(response=res) + logger.debug( + "Sent metrics to Roboflow {} at {}.".format( + METRICS_URL, str(all_data) + ) + ) + except Exception as e: + logger.debug( + f"Error sending metrics to Roboflow, if you want to disable this feature unset the METRICS_ENABLED environment variable." + ) + + except Exception as e: + try: + logger.debug( + f"Error sending metrics to Roboflow, if you want to disable this feature unset the METRICS_ENABLED environment variable. Error was: {e}. Data was: {all_data}" + ) + traceback.print_exc() + + except Exception as e2: + logger.debug( + f"Error sending metrics to Roboflow, if you want to disable this feature unset the METRICS_ENABLED environment variable. Error was: {e}." + ) diff --git a/inference/core/managers/stub_loader.py b/inference/core/managers/stub_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..f877b0371c1ad9f19af40ae284a51c054e62e896 --- /dev/null +++ b/inference/core/managers/stub_loader.py @@ -0,0 +1,18 @@ +from inference.core.managers.base import ModelManager + + +class StubLoaderManager(ModelManager): + def add_model(self, model_id: str, api_key: str, model_id_alias=None) -> None: + """Adds a new model to the manager. + + Args: + model_id (str): The identifier of the model. + model (Model): The model instance. + """ + if model_id in self._models: + return + model_class = self.model_registry.get_model( + model_id_alias if model_id_alias is not None else model_id, api_key + ) + model = model_class(model_id=model_id, api_key=api_key, load_weights=False) + self._models[model_id] = model diff --git a/inference/core/models/__init__.py b/inference/core/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/models/__pycache__/__init__.cpython-310.pyc b/inference/core/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b58c38feb280af4c04a9d79f4e25ebcb742d00f Binary files /dev/null and b/inference/core/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/models/__pycache__/base.cpython-310.pyc b/inference/core/models/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41670607471e23c8a1d22fb5fb15727d41b2710c Binary files /dev/null and b/inference/core/models/__pycache__/base.cpython-310.pyc differ diff --git a/inference/core/models/__pycache__/classification_base.cpython-310.pyc b/inference/core/models/__pycache__/classification_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..001a69d5cd53988b1aa64ac961b192d3e8e191be Binary files /dev/null and b/inference/core/models/__pycache__/classification_base.cpython-310.pyc differ diff --git a/inference/core/models/__pycache__/defaults.cpython-310.pyc b/inference/core/models/__pycache__/defaults.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e450ce503a6b60ef0c3ff2dca1d30b24b6a6dfe Binary files /dev/null and b/inference/core/models/__pycache__/defaults.cpython-310.pyc differ diff --git a/inference/core/models/__pycache__/instance_segmentation_base.cpython-310.pyc b/inference/core/models/__pycache__/instance_segmentation_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..001a3c4d3703b4a8c80142e6e031539ea62eb348 Binary files /dev/null and b/inference/core/models/__pycache__/instance_segmentation_base.cpython-310.pyc differ diff --git a/inference/core/models/__pycache__/keypoints_detection_base.cpython-310.pyc b/inference/core/models/__pycache__/keypoints_detection_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5671e74732965a407fc05e5212035f883e7e43e0 Binary files /dev/null and b/inference/core/models/__pycache__/keypoints_detection_base.cpython-310.pyc differ diff --git a/inference/core/models/__pycache__/object_detection_base.cpython-310.pyc b/inference/core/models/__pycache__/object_detection_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..919f0fdabbe165e42f6867e7ed796c3003e0e87f Binary files /dev/null and b/inference/core/models/__pycache__/object_detection_base.cpython-310.pyc differ diff --git a/inference/core/models/__pycache__/roboflow.cpython-310.pyc b/inference/core/models/__pycache__/roboflow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f40a12e3504402c729212702b1f4aaf8e796743 Binary files /dev/null and b/inference/core/models/__pycache__/roboflow.cpython-310.pyc differ diff --git a/inference/core/models/__pycache__/stubs.cpython-310.pyc b/inference/core/models/__pycache__/stubs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..444c095cbaa5bd9829eca84eba383860de19df2b Binary files /dev/null and b/inference/core/models/__pycache__/stubs.cpython-310.pyc differ diff --git a/inference/core/models/__pycache__/types.cpython-310.pyc b/inference/core/models/__pycache__/types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03f6022288a1550e1646abfb8d9dc1abcbd085cd Binary files /dev/null and b/inference/core/models/__pycache__/types.cpython-310.pyc differ diff --git a/inference/core/models/base.py b/inference/core/models/base.py new file mode 100644 index 0000000000000000000000000000000000000000..285742a4d4dabd7ec7dd3131822d7d3091b3042d --- /dev/null +++ b/inference/core/models/base.py @@ -0,0 +1,146 @@ +from time import perf_counter +from typing import Any, List, Tuple, Union + +import numpy as np + +from inference.core.entities.requests.inference import InferenceRequest +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.models.types import PreprocessReturnMetadata + + +class BaseInference: + """General inference class. + + This class provides a basic interface for inference tasks. + """ + + def infer(self, image: Any, **kwargs) -> Any: + """Runs inference on given data.""" + preproc_image, returned_metadata = self.preprocess(image, **kwargs) + predicted_arrays = self.predict(preproc_image, **kwargs) + postprocessed = self.postprocess(predicted_arrays, returned_metadata, **kwargs) + + return postprocessed + + def preprocess( + self, image: Any, **kwargs + ) -> Tuple[np.ndarray, PreprocessReturnMetadata]: + raise NotImplementedError + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, ...]: + raise NotImplementedError + + def postprocess( + self, + predictions: Tuple[np.ndarray, ...], + preprocess_return_metadata: PreprocessReturnMetadata, + **kwargs + ) -> Any: + raise NotImplementedError + + def infer_from_request( + self, request: InferenceRequest + ) -> Union[InferenceResponse, List[InferenceResponse]]: + """Runs inference on a request + + Args: + request (InferenceRequest): The request object. + + Returns: + Union[CVInferenceResponse, List[CVInferenceResponse]]: The response object(s). + + Raises: + NotImplementedError: This method must be implemented by a subclass. + """ + raise NotImplementedError + + def make_response( + self, *args, **kwargs + ) -> Union[InferenceResponse, List[InferenceResponse]]: + """Constructs an object detection response. + + Raises: + NotImplementedError: This method must be implemented by a subclass. + """ + raise NotImplementedError + + +class Model(BaseInference): + """Base Inference Model (Inherits from BaseInference to define the needed methods) + + This class provides the foundational methods for inference and logging, and can be extended by specific models. + + Methods: + log(m): Print the given message. + clear_cache(): Clears any cache if necessary. + """ + + def log(self, m): + """Prints the given message. + + Args: + m (str): The message to print. + """ + print(m) + + def clear_cache(self): + """Clears any cache if necessary. This method should be implemented in derived classes as needed.""" + pass + + def infer_from_request( + self, + request: InferenceRequest, + ) -> Union[List[InferenceResponse], InferenceResponse]: + """ + Perform inference based on the details provided in the request, and return the associated responses. + The function can handle both single and multiple image inference requests. Optionally, it also provides + a visualization of the predictions if requested. + + Args: + request (InferenceRequest): The request object containing details for inference, such as the image or + images to process, any classes to filter by, and whether or not to visualize the predictions. + + Returns: + Union[List[InferenceResponse], InferenceResponse]: A list of response objects if the request contains + multiple images, or a single response object if the request contains one image. Each response object + contains details about the segmented instances, the time taken for inference, and optionally, a visualization. + + Examples: + >>> request = InferenceRequest(image=my_image, visualize_predictions=True) + >>> response = infer_from_request(request) + >>> print(response.time) # Prints the time taken for inference + 0.125 + >>> print(response.visualization) # Accesses the visualization of the prediction if available + + Notes: + - The processing time for each response is included within the response itself. + - If `visualize_predictions` is set to True in the request, a visualization of the prediction + is also included in the response. + """ + t1 = perf_counter() + responses = self.infer(**request.dict(), return_image_dims=False) + for response in responses: + response.time = perf_counter() - t1 + + if request.visualize_predictions: + for response in responses: + response.visualization = self.draw_predictions(request, response) + + if not isinstance(request.image, list) and len(responses) > 0: + responses = responses[0] + + return responses + + def make_response( + self, *args, **kwargs + ) -> Union[InferenceResponse, List[InferenceResponse]]: + """Makes an inference response from the given arguments. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + InferenceResponse: The inference response. + """ + raise NotImplementedError(self.__class__.__name__ + ".make_response") diff --git a/inference/core/models/classification_base.py b/inference/core/models/classification_base.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5ba0cda7c37f8e2d1b22de434047ca7b1258de --- /dev/null +++ b/inference/core/models/classification_base.py @@ -0,0 +1,365 @@ +from io import BytesIO +from time import perf_counter +from typing import Any, List, Tuple, Union + +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +from inference.core.entities.requests.inference import ClassificationInferenceRequest +from inference.core.entities.responses.inference import ( + ClassificationInferenceResponse, + InferenceResponse, + InferenceResponseImage, + MultiLabelClassificationInferenceResponse, +) +from inference.core.models.roboflow import OnnxRoboflowInferenceModel +from inference.core.models.types import PreprocessReturnMetadata +from inference.core.models.utils.validate import ( + get_num_classes_from_model_prediction_shape, +) +from inference.core.utils.image_utils import load_image_rgb + + +class ClassificationBaseOnnxRoboflowInferenceModel(OnnxRoboflowInferenceModel): + """Base class for ONNX models for Roboflow classification inference. + + Attributes: + multiclass (bool): Whether the classification is multi-class or not. + + Methods: + get_infer_bucket_file_list() -> list: Get the list of required files for inference. + softmax(x): Compute softmax values for a given set of scores. + infer(request: ClassificationInferenceRequest) -> Union[List[Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]], Union[ClassificationInferenceResponse, MultiLabelClassificationInferenceResponse]]: Perform inference on a given request and return the response. + draw_predictions(inference_request, inference_response): Draw prediction visuals on an image. + """ + + task_type = "classification" + + def __init__(self, *args, **kwargs): + """Initialize the model, setting whether it is multiclass or not.""" + super().__init__(*args, **kwargs) + self.multiclass = self.environment.get("MULTICLASS", False) + + def draw_predictions(self, inference_request, inference_response): + """Draw prediction visuals on an image. + + This method overlays the predictions on the input image, including drawing rectangles and text to visualize the predicted classes. + + Args: + inference_request: The request object containing the image and parameters. + inference_response: The response object containing the predictions and other details. + + Returns: + bytes: The bytes of the visualized image in JPEG format. + """ + image = load_image_rgb(inference_request.image) + image = Image.fromarray(image) + draw = ImageDraw.Draw(image) + font = ImageFont.load_default() + if isinstance(inference_response.predictions, list): + prediction = inference_response.predictions[0] + color = self.colors.get(prediction.class_name, "#4892EA") + draw.rectangle( + [0, 0, image.size[1], image.size[0]], + outline=color, + width=inference_request.visualization_stroke_width, + ) + text = f"{prediction.class_id} - {prediction.class_name} {prediction.confidence:.2f}" + text_size = font.getbbox(text) + + # set button size + 10px margins + button_size = (text_size[2] + 20, text_size[3] + 20) + button_img = Image.new("RGBA", button_size, color) + # put text on button with 10px margins + button_draw = ImageDraw.Draw(button_img) + button_draw.text((10, 10), text, font=font, fill=(255, 255, 255, 255)) + + # put button on source image in position (0, 0) + image.paste(button_img, (0, 0)) + else: + if len(inference_response.predictions) > 0: + box_color = "#4892EA" + draw.rectangle( + [0, 0, image.size[1], image.size[0]], + outline=box_color, + width=inference_request.visualization_stroke_width, + ) + row = 0 + predictions = [ + (cls_name, pred) + for cls_name, pred in inference_response.predictions.items() + ] + predictions = sorted( + predictions, key=lambda x: x[1].confidence, reverse=True + ) + for i, (cls_name, pred) in enumerate(predictions): + color = self.colors.get(cls_name, "#4892EA") + text = f"{cls_name} {pred.confidence:.2f}" + text_size = font.getbbox(text) + + # set button size + 10px margins + button_size = (text_size[2] + 20, text_size[3] + 20) + button_img = Image.new("RGBA", button_size, color) + # put text on button with 10px margins + button_draw = ImageDraw.Draw(button_img) + button_draw.text((10, 10), text, font=font, fill=(255, 255, 255, 255)) + + # put button on source image in position (0, 0) + image.paste(button_img, (0, row)) + row += button_size[1] + + buffered = BytesIO() + image = image.convert("RGB") + image.save(buffered, format="JPEG") + return buffered.getvalue() + + def get_infer_bucket_file_list(self) -> list: + """Get the list of required files for inference. + + Returns: + list: A list of required files for inference, e.g., ["environment.json"]. + """ + return ["environment.json"] + + def infer( + self, + image: Any, + disable_preproc_auto_orient: bool = False, + disable_preproc_contrast: bool = False, + disable_preproc_grayscale: bool = False, + disable_preproc_static_crop: bool = False, + return_image_dims: bool = False, + **kwargs, + ): + """ + Perform inference on the provided image(s) and return the predictions. + + Args: + image (Any): The image or list of images to be processed. + disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False. + disable_preproc_contrast (bool, optional): If true, the auto contrast preprocessing step is disabled for this call. Default is False. + disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False. + disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False. + return_image_dims (bool, optional): If set to True, the function will also return the dimensions of the image. Defaults to False. + **kwargs: Additional parameters to customize the inference process. + + Returns: + Union[List[np.array], np.array, Tuple[List[np.array], List[Tuple[int, int]]], Tuple[np.array, Tuple[int, int]]]: + If `return_image_dims` is True and a list of images is provided, a tuple containing a list of prediction arrays and a list of image dimensions (width, height) is returned. + If `return_image_dims` is True and a single image is provided, a tuple containing the prediction array and image dimensions (width, height) is returned. + If `return_image_dims` is False and a list of images is provided, only the list of prediction arrays is returned. + If `return_image_dims` is False and a single image is provided, only the prediction array is returned. + + Notes: + - The input image(s) will be preprocessed (normalized and reshaped) before inference. + - This function uses an ONNX session to perform inference on the input image(s). + """ + return super().infer( + image, + disable_preproc_auto_orient=disable_preproc_auto_orient, + disable_preproc_contrast=disable_preproc_contrast, + disable_preproc_grayscale=disable_preproc_grayscale, + disable_preproc_static_crop=disable_preproc_static_crop, + return_image_dims=return_image_dims, + ) + + def postprocess( + self, + predictions: Tuple[np.ndarray], + preprocess_return_metadata: PreprocessReturnMetadata, + return_image_dims=False, + **kwargs, + ) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]: + predictions = predictions[0] + return self.make_response( + predictions, preprocess_return_metadata["img_dims"], **kwargs + ) + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]: + predictions = self.onnx_session.run(None, {self.input_name: img_in}) + return (predictions,) + + def preprocess( + self, image: Any, **kwargs + ) -> Tuple[np.ndarray, PreprocessReturnMetadata]: + if isinstance(image, list): + imgs_with_dims = [ + self.preproc_image( + i, + disable_preproc_auto_orient=kwargs.get( + "disable_preproc_auto_orient", False + ), + disable_preproc_contrast=kwargs.get( + "disable_preproc_contrast", False + ), + disable_preproc_grayscale=kwargs.get( + "disable_preproc_grayscale", False + ), + disable_preproc_static_crop=kwargs.get( + "disable_preproc_static_crop", False + ), + ) + for i in image + ] + imgs, img_dims = zip(*imgs_with_dims) + img_in = np.concatenate(imgs, axis=0) + else: + img_in, img_dims = self.preproc_image( + image, + disable_preproc_auto_orient=kwargs.get( + "disable_preproc_auto_orient", False + ), + disable_preproc_contrast=kwargs.get("disable_preproc_contrast", False), + disable_preproc_grayscale=kwargs.get( + "disable_preproc_grayscale", False + ), + disable_preproc_static_crop=kwargs.get( + "disable_preproc_static_crop", False + ), + ) + img_dims = [img_dims] + + img_in /= 255.0 + + mean = (0.5, 0.5, 0.5) + std = (0.5, 0.5, 0.5) + + img_in = img_in.astype(np.float32) + + img_in[:, 0, :, :] = (img_in[:, 0, :, :] - mean[0]) / std[0] + img_in[:, 1, :, :] = (img_in[:, 1, :, :] - mean[1]) / std[1] + img_in[:, 2, :, :] = (img_in[:, 2, :, :] - mean[2]) / std[2] + return img_in, PreprocessReturnMetadata({"img_dims": img_dims}) + + def infer_from_request( + self, + request: ClassificationInferenceRequest, + ) -> Union[List[InferenceResponse], InferenceResponse]: + """ + Handle an inference request to produce an appropriate response. + + Args: + request (ClassificationInferenceRequest): The request object encapsulating the image(s) and relevant parameters. + + Returns: + Union[List[InferenceResponse], InferenceResponse]: The response object(s) containing the predictions, visualization, and other pertinent details. If a list of images was provided, a list of responses is returned. Otherwise, a single response is returned. + + Notes: + - Starts a timer at the beginning to calculate inference time. + - Processes the image(s) through the `infer` method. + - Generates the appropriate response object(s) using `make_response`. + - Calculates and sets the time taken for inference. + - If visualization is requested, the predictions are drawn on the image. + """ + t1 = perf_counter() + responses = self.infer(**request.dict(), return_image_dims=True) + for response in responses: + response.time = perf_counter() - t1 + + if request.visualize_predictions: + for response in responses: + response.visualization = self.draw_predictions(request, response) + + if not isinstance(request.image, list): + responses = responses[0] + + return responses + + def make_response( + self, + predictions, + img_dims, + confidence: float = 0.5, + **kwargs, + ) -> Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]: + """ + Create response objects for the given predictions and image dimensions. + + Args: + predictions (list): List of prediction arrays from the inference process. + img_dims (list): List of tuples indicating the dimensions (width, height) of each image. + confidence (float, optional): Confidence threshold for filtering predictions. Defaults to 0.5. + **kwargs: Additional parameters to influence the response creation process. + + Returns: + Union[ClassificationInferenceResponse, List[ClassificationInferenceResponse]]: A response object or a list of response objects encapsulating the prediction details. + + Notes: + - If the model is multiclass, a `MultiLabelClassificationInferenceResponse` is generated for each image. + - If the model is not multiclass, a `ClassificationInferenceResponse` is generated for each image. + - Predictions below the confidence threshold are filtered out. + """ + responses = [] + confidence_threshold = float(confidence) + for ind, prediction in enumerate(predictions): + if self.multiclass: + preds = prediction[0] + results = dict() + predicted_classes = [] + for i, o in enumerate(preds): + cls_name = self.class_names[i] + score = float(o) + results[cls_name] = {"confidence": score, "class_id": i} + if score > confidence_threshold: + predicted_classes.append(cls_name) + response = MultiLabelClassificationInferenceResponse( + image=InferenceResponseImage( + width=img_dims[ind][0], height=img_dims[ind][1] + ), + predicted_classes=predicted_classes, + predictions=results, + ) + else: + preds = prediction[0] + preds = self.softmax(preds) + results = [] + for i, cls_name in enumerate(self.class_names): + score = float(preds[i]) + pred = { + "class_id": i, + "class": cls_name, + "confidence": round(score, 4), + } + results.append(pred) + results = sorted(results, key=lambda x: x["confidence"], reverse=True) + + response = ClassificationInferenceResponse( + image=InferenceResponseImage( + width=img_dims[ind][1], height=img_dims[ind][0] + ), + predictions=results, + top=results[0]["class"], + confidence=results[0]["confidence"], + ) + responses.append(response) + + return responses + + @staticmethod + def softmax(x): + """Compute softmax values for each set of scores in x. + + Args: + x (np.array): The input array containing the scores. + + Returns: + np.array: The softmax values for each set of scores. + """ + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum() + + def get_model_output_shape(self) -> Tuple[int, int, int]: + test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8) + test_image, _ = self.preprocess(test_image) + output = np.array(self.predict(test_image)) + return output.shape + + def validate_model_classes(self) -> None: + output_shape = self.get_model_output_shape() + num_classes = output_shape[3] + try: + assert num_classes == self.num_classes + except AssertionError: + raise ValueError( + f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})" + ) diff --git a/inference/core/models/defaults.py b/inference/core/models/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..209e446ca317fb19f7aaa610acefdf5a3e0121b1 --- /dev/null +++ b/inference/core/models/defaults.py @@ -0,0 +1,5 @@ +DEFAULT_CONFIDENCE = 0.4 +DEFAULT_IOU_THRESH = 0.3 +DEFAULT_CLASS_AGNOSTIC_NMS = False +DEFAUlT_MAX_DETECTIONS = 300 +DEFAULT_MAX_CANDIDATES = 3000 diff --git a/inference/core/models/instance_segmentation_base.py b/inference/core/models/instance_segmentation_base.py new file mode 100644 index 0000000000000000000000000000000000000000..bd4bc3bb9a94f8f24c9526ac37ce9b33e12f0f2d --- /dev/null +++ b/inference/core/models/instance_segmentation_base.py @@ -0,0 +1,296 @@ +from typing import Any, List, Tuple, Union + +import numpy as np + +from inference.core.entities.responses.inference import ( + InferenceResponseImage, + InstanceSegmentationInferenceResponse, + InstanceSegmentationPrediction, + Point, +) +from inference.core.exceptions import InvalidMaskDecodeArgument +from inference.core.models.roboflow import OnnxRoboflowInferenceModel +from inference.core.models.types import PreprocessReturnMetadata +from inference.core.models.utils.validate import ( + get_num_classes_from_model_prediction_shape, +) +from inference.core.nms import w_np_non_max_suppression +from inference.core.utils.postprocess import ( + masks2poly, + post_process_bboxes, + post_process_polygons, + process_mask_accurate, + process_mask_fast, + process_mask_tradeoff, +) + +DEFAULT_CONFIDENCE = 0.4 +DEFAULT_IOU_THRESH = 0.3 +DEFAULT_CLASS_AGNOSTIC_NMS = False +DEFAUlT_MAX_DETECTIONS = 300 +DEFAULT_MAX_CANDIDATES = 3000 +DEFAULT_MASK_DECODE_MODE = "accurate" +DEFAULT_TRADEOFF_FACTOR = 0.0 + +PREDICTIONS_TYPE = List[List[List[float]]] + + +class InstanceSegmentationBaseOnnxRoboflowInferenceModel(OnnxRoboflowInferenceModel): + """Roboflow ONNX Instance Segmentation model. + + This class implements an instance segmentation specific inference method + for ONNX models provided by Roboflow. + """ + + task_type = "instance-segmentation" + num_masks = 32 + + def infer( + self, + image: Any, + class_agnostic_nms: bool = False, + confidence: float = DEFAULT_CONFIDENCE, + disable_preproc_auto_orient: bool = False, + disable_preproc_contrast: bool = False, + disable_preproc_grayscale: bool = False, + disable_preproc_static_crop: bool = False, + iou_threshold: float = DEFAULT_IOU_THRESH, + mask_decode_mode: str = DEFAULT_MASK_DECODE_MODE, + max_candidates: int = DEFAULT_MAX_CANDIDATES, + max_detections: int = DEFAUlT_MAX_DETECTIONS, + return_image_dims: bool = False, + tradeoff_factor: float = DEFAULT_TRADEOFF_FACTOR, + **kwargs, + ) -> Union[PREDICTIONS_TYPE, Tuple[PREDICTIONS_TYPE, List[Tuple[int, int]]]]: + """ + Process an image or list of images for instance segmentation. + + Args: + image (Any): An image or a list of images for processing. + class_agnostic_nms (bool, optional): Whether to use class-agnostic non-maximum suppression. Defaults to False. + confidence (float, optional): Confidence threshold for predictions. Defaults to 0.5. + iou_threshold (float, optional): IoU threshold for non-maximum suppression. Defaults to 0.5. + mask_decode_mode (str, optional): Decoding mode for masks. Choices are "accurate", "tradeoff", and "fast". Defaults to "accurate". + max_candidates (int, optional): Maximum number of candidate detections. Defaults to 3000. + max_detections (int, optional): Maximum number of detections after non-maximum suppression. Defaults to 300. + return_image_dims (bool, optional): Whether to return the dimensions of the processed images. Defaults to False. + tradeoff_factor (float, optional): Tradeoff factor used when `mask_decode_mode` is set to "tradeoff". Must be in [0.0, 1.0]. Defaults to 0.5. + disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False. + disable_preproc_contrast (bool, optional): If true, the auto contrast preprocessing step is disabled for this call. Default is False. + disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False. + disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False. + **kwargs: Additional parameters to customize the inference process. + + Returns: + Union[List[List[List[float]]], Tuple[List[List[List[float]]], List[Tuple[int, int]]]]: The list of predictions, with each prediction being a list of lists. Optionally, also returns the dimensions of the processed images. + + Raises: + InvalidMaskDecodeArgument: If an invalid `mask_decode_mode` is provided or if the `tradeoff_factor` is outside the allowed range. + + Notes: + - Processes input images and normalizes them. + - Makes predictions using the ONNX runtime. + - Applies non-maximum suppression to the predictions. + - Decodes the masks according to the specified mode. + """ + return super().infer( + image, + class_agnostic_nms=class_agnostic_nms, + confidence=confidence, + disable_preproc_auto_orient=disable_preproc_auto_orient, + disable_preproc_contrast=disable_preproc_contrast, + disable_preproc_grayscale=disable_preproc_grayscale, + disable_preproc_static_crop=disable_preproc_static_crop, + iou_threshold=iou_threshold, + mask_decode_mode=mask_decode_mode, + max_candidates=max_candidates, + max_detections=max_detections, + return_image_dims=return_image_dims, + tradeoff_factor=tradeoff_factor, + ) + + def postprocess( + self, + predictions: Tuple[np.ndarray, np.ndarray], + preprocess_return_metadata: PreprocessReturnMetadata, + **kwargs, + ) -> Union[ + InstanceSegmentationInferenceResponse, + List[InstanceSegmentationInferenceResponse], + ]: + predictions, protos = predictions + predictions = w_np_non_max_suppression( + predictions, + conf_thresh=kwargs["confidence"], + iou_thresh=kwargs["iou_threshold"], + class_agnostic=kwargs["class_agnostic_nms"], + max_detections=kwargs["max_detections"], + max_candidate_detections=kwargs["max_candidates"], + num_masks=self.num_masks, + ) + infer_shape = (self.img_size_h, self.img_size_w) + predictions = np.array(predictions) + masks = [] + mask_decode_mode = kwargs["mask_decode_mode"] + tradeoff_factor = kwargs["tradeoff_factor"] + img_in_shape = preprocess_return_metadata["im_shape"] + if predictions.shape[1] > 0: + for i, (pred, proto, img_dim) in enumerate( + zip(predictions, protos, preprocess_return_metadata["img_dims"]) + ): + if mask_decode_mode == "accurate": + batch_masks = process_mask_accurate( + proto, pred[:, 7:], pred[:, :4], img_in_shape[2:] + ) + output_mask_shape = img_in_shape[2:] + elif mask_decode_mode == "tradeoff": + if not 0 <= tradeoff_factor <= 1: + raise InvalidMaskDecodeArgument( + f"Invalid tradeoff_factor: {tradeoff_factor}. Must be in [0.0, 1.0]" + ) + batch_masks = process_mask_tradeoff( + proto, + pred[:, 7:], + pred[:, :4], + img_in_shape[2:], + tradeoff_factor, + ) + output_mask_shape = batch_masks.shape[1:] + elif mask_decode_mode == "fast": + batch_masks = process_mask_fast( + proto, pred[:, 7:], pred[:, :4], img_in_shape[2:] + ) + output_mask_shape = batch_masks.shape[1:] + else: + raise InvalidMaskDecodeArgument( + f"Invalid mask_decode_mode: {mask_decode_mode}. Must be one of ['accurate', 'fast', 'tradeoff']" + ) + polys = masks2poly(batch_masks) + pred[:, :4] = post_process_bboxes( + [pred[:, :4]], + infer_shape, + [img_dim], + self.preproc, + resize_method=self.resize_method, + disable_preproc_static_crop=preprocess_return_metadata[ + "disable_preproc_static_crop" + ], + )[0] + polys = post_process_polygons( + img_dim, + polys, + output_mask_shape, + self.preproc, + resize_method=self.resize_method, + ) + masks.append(polys) + else: + masks.extend([[]] * len(predictions)) + return self.make_response( + predictions, masks, preprocess_return_metadata["img_dims"], **kwargs + ) + + def preprocess( + self, image: Any, **kwargs + ) -> Tuple[np.ndarray, PreprocessReturnMetadata]: + img_in, img_dims = self.load_image( + image, + disable_preproc_auto_orient=kwargs.get("disable_preproc_auto_orient"), + disable_preproc_contrast=kwargs.get("disable_preproc_contrast"), + disable_preproc_grayscale=kwargs.get("disable_preproc_grayscale"), + disable_preproc_static_crop=kwargs.get("disable_preproc_static_crop"), + ) + + img_in /= 255.0 + return img_in, PreprocessReturnMetadata( + { + "img_dims": img_dims, + "im_shape": img_in.shape, + "disable_preproc_static_crop": kwargs.get( + "disable_preproc_static_crop" + ), + } + ) + + def make_response( + self, + predictions: List[List[List[float]]], + masks: List[List[List[float]]], + img_dims: List[Tuple[int, int]], + class_filter: List[str] = [], + **kwargs, + ) -> Union[ + InstanceSegmentationInferenceResponse, + List[InstanceSegmentationInferenceResponse], + ]: + """ + Create instance segmentation inference response objects for the provided predictions and masks. + + Args: + predictions (List[List[List[float]]]): List of prediction data, one for each image. + masks (List[List[List[float]]]): List of masks corresponding to the predictions. + img_dims (List[Tuple[int, int]]): List of image dimensions corresponding to the processed images. + class_filter (List[str], optional): List of class names to filter predictions by. Defaults to an empty list (no filtering). + + Returns: + Union[InstanceSegmentationInferenceResponse, List[InstanceSegmentationInferenceResponse]]: A single instance segmentation response or a list of instance segmentation responses based on the number of processed images. + + Notes: + - For each image, constructs an `InstanceSegmentationInferenceResponse` object. + - Each response contains a list of `InstanceSegmentationPrediction` objects. + """ + responses = [ + InstanceSegmentationInferenceResponse( + predictions=[ + InstanceSegmentationPrediction( + # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python) + **{ + "x": (pred[0] + pred[2]) / 2, + "y": (pred[1] + pred[3]) / 2, + "width": pred[2] - pred[0], + "height": pred[3] - pred[1], + "points": [Point(x=point[0], y=point[1]) for point in mask], + "confidence": pred[4], + "class": self.class_names[int(pred[6])], + "class_id": int(pred[6]), + } + ) + for pred, mask in zip(batch_predictions, batch_masks) + if not class_filter + or self.class_names[int(pred[6])] in class_filter + ], + image=InferenceResponseImage( + width=img_dims[ind][1], height=img_dims[ind][0] + ), + ) + for ind, (batch_predictions, batch_masks) in enumerate( + zip(predictions, masks) + ) + ] + return responses + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]: + """Runs inference on the ONNX model. + + Args: + img_in (np.ndarray): The preprocessed image(s) to run inference on. + + Returns: + Tuple[np.ndarray, np.ndarray]: The ONNX model predictions and the ONNX model protos. + + Raises: + NotImplementedError: This method must be implemented by a subclass. + """ + raise NotImplementedError("predict must be implemented by a subclass") + + def validate_model_classes(self) -> None: + output_shape = self.get_model_output_shape() + num_classes = get_num_classes_from_model_prediction_shape( + output_shape[2], masks=self.num_masks + ) + try: + assert num_classes == self.num_classes + except AssertionError: + raise ValueError( + f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})" + ) diff --git a/inference/core/models/keypoints_detection_base.py b/inference/core/models/keypoints_detection_base.py new file mode 100644 index 0000000000000000000000000000000000000000..f4da1939c940c9989a77b2d5559c5a0c8e9da813 --- /dev/null +++ b/inference/core/models/keypoints_detection_base.py @@ -0,0 +1,182 @@ +from typing import List, Optional, Tuple + +import numpy as np + +from inference.core.entities.responses.inference import ( + InferenceResponseImage, + Keypoint, + KeypointsDetectionInferenceResponse, + KeypointsPrediction, +) +from inference.core.exceptions import ModelArtefactError +from inference.core.models.object_detection_base import ( + ObjectDetectionBaseOnnxRoboflowInferenceModel, +) +from inference.core.models.types import PreprocessReturnMetadata +from inference.core.models.utils.keypoints import model_keypoints_to_response +from inference.core.models.utils.validate import ( + get_num_classes_from_model_prediction_shape, +) +from inference.core.nms import w_np_non_max_suppression +from inference.core.utils.postprocess import post_process_bboxes, post_process_keypoints + +DEFAULT_CONFIDENCE = 0.4 +DEFAULT_IOU_THRESH = 0.3 +DEFAULT_CLASS_AGNOSTIC_NMS = False +DEFAUlT_MAX_DETECTIONS = 300 +DEFAULT_MAX_CANDIDATES = 3000 + + +class KeypointsDetectionBaseOnnxRoboflowInferenceModel( + ObjectDetectionBaseOnnxRoboflowInferenceModel +): + """Roboflow ONNX Object detection model. This class implements an object detection specific infer method.""" + + task_type = "keypoint-detection" + + def __init__(self, model_id: str, *args, **kwargs): + super().__init__(model_id, *args, **kwargs) + + def get_infer_bucket_file_list(self) -> list: + """Returns the list of files to be downloaded from the inference bucket for ONNX model. + + Returns: + list: A list of filenames specific to ONNX models. + """ + return ["environment.json", "class_names.txt", "keypoints_metadata.json"] + + def postprocess( + self, + predictions: Tuple[np.ndarray], + preproc_return_metadata: PreprocessReturnMetadata, + class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS, + confidence: float = DEFAULT_CONFIDENCE, + iou_threshold: float = DEFAULT_IOU_THRESH, + max_candidates: int = DEFAULT_MAX_CANDIDATES, + max_detections: int = DEFAUlT_MAX_DETECTIONS, + return_image_dims: bool = False, + **kwargs, + ) -> List[KeypointsDetectionInferenceResponse]: + """Postprocesses the object detection predictions. + + Args: + predictions (np.ndarray): Raw predictions from the model. + img_dims (List[Tuple[int, int]]): Dimensions of the images. + class_agnostic_nms (bool): Whether to apply class-agnostic non-max suppression. Default is False. + confidence (float): Confidence threshold for filtering detections. Default is 0.5. + iou_threshold (float): IoU threshold for non-max suppression. Default is 0.5. + max_candidates (int): Maximum number of candidate detections. Default is 3000. + max_detections (int): Maximum number of final detections. Default is 300. + + Returns: + List[KeypointsDetectionInferenceResponse]: The post-processed predictions. + """ + predictions = predictions[0] + number_of_classes = len(self.get_class_names) + num_masks = predictions.shape[2] - 5 - number_of_classes + predictions = w_np_non_max_suppression( + predictions, + conf_thresh=confidence, + iou_thresh=iou_threshold, + class_agnostic=class_agnostic_nms, + max_detections=max_detections, + max_candidate_detections=max_candidates, + num_masks=num_masks, + ) + + infer_shape = (self.img_size_h, self.img_size_w) + img_dims = preproc_return_metadata["img_dims"] + predictions = post_process_bboxes( + predictions=predictions, + infer_shape=infer_shape, + img_dims=img_dims, + preproc=self.preproc, + resize_method=self.resize_method, + disable_preproc_static_crop=preproc_return_metadata[ + "disable_preproc_static_crop" + ], + ) + predictions = post_process_keypoints( + predictions=predictions, + keypoints_start_index=-num_masks, + infer_shape=infer_shape, + img_dims=img_dims, + preproc=self.preproc, + resize_method=self.resize_method, + disable_preproc_static_crop=preproc_return_metadata[ + "disable_preproc_static_crop" + ], + ) + return self.make_response(predictions, img_dims, **kwargs) + + def make_response( + self, + predictions: List[List[float]], + img_dims: List[Tuple[int, int]], + class_filter: Optional[List[str]] = None, + *args, + **kwargs, + ) -> List[KeypointsDetectionInferenceResponse]: + """Constructs object detection response objects based on predictions. + + Args: + predictions (List[List[float]]): The list of predictions. + img_dims (List[Tuple[int, int]]): Dimensions of the images. + class_filter (Optional[List[str]]): A list of class names to filter, if provided. + + Returns: + List[KeypointsDetectionInferenceResponse]: A list of response objects containing keypoints detection predictions. + """ + if isinstance(img_dims, dict) and "img_dims" in img_dims: + img_dims = img_dims["img_dims"] + keypoint_confidence_threshold = 0.0 + if "request" in kwargs: + keypoint_confidence_threshold = kwargs["request"].keypoint_confidence + responses = [ + KeypointsDetectionInferenceResponse( + predictions=[ + KeypointsPrediction( + # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python) + **{ + "x": (pred[0] + pred[2]) / 2, + "y": (pred[1] + pred[3]) / 2, + "width": pred[2] - pred[0], + "height": pred[3] - pred[1], + "confidence": pred[4], + "class": self.class_names[int(pred[6])], + "class_id": int(pred[6]), + "keypoints": model_keypoints_to_response( + keypoints_metadata=self.keypoints_metadata, + keypoints=pred[7:], + predicted_object_class_id=int( + pred[4 + len(self.get_class_names)] + ), + keypoint_confidence_threshold=keypoint_confidence_threshold, + ), + } + ) + for pred in batch_predictions + if not class_filter + or self.class_names[int(pred[6])] in class_filter + ], + image=InferenceResponseImage( + width=img_dims[ind][1], height=img_dims[ind][0] + ), + ) + for ind, batch_predictions in enumerate(predictions) + ] + return responses + + def keypoints_count(self) -> int: + raise NotImplementedError + + def validate_model_classes(self) -> None: + num_keypoints = self.keypoints_count() + output_shape = self.get_model_output_shape() + num_classes = get_num_classes_from_model_prediction_shape( + len_prediction=output_shape[2], keypoints=num_keypoints + ) + if num_classes != self.num_classes: + raise ValueError( + f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})" + ) diff --git a/inference/core/models/object_detection_base.py b/inference/core/models/object_detection_base.py new file mode 100644 index 0000000000000000000000000000000000000000..fc4734b7e654103f356a823e25f4eebf76949a00 --- /dev/null +++ b/inference/core/models/object_detection_base.py @@ -0,0 +1,287 @@ +from typing import Any, List, Optional, Tuple, Union + +import numpy as np + +from inference.core.entities.responses.inference import ( + InferenceResponseImage, + ObjectDetectionInferenceResponse, + ObjectDetectionPrediction, +) +from inference.core.env import FIX_BATCH_SIZE, MAX_BATCH_SIZE +from inference.core.logger import logger +from inference.core.models.defaults import ( + DEFAULT_CLASS_AGNOSTIC_NMS, + DEFAULT_CONFIDENCE, + DEFAULT_IOU_THRESH, + DEFAULT_MAX_CANDIDATES, + DEFAUlT_MAX_DETECTIONS, +) +from inference.core.models.roboflow import OnnxRoboflowInferenceModel +from inference.core.models.types import PreprocessReturnMetadata +from inference.core.models.utils.validate import ( + get_num_classes_from_model_prediction_shape, +) +from inference.core.nms import w_np_non_max_suppression +from inference.core.utils.postprocess import post_process_bboxes + + +class ObjectDetectionBaseOnnxRoboflowInferenceModel(OnnxRoboflowInferenceModel): + """Roboflow ONNX Object detection model. This class implements an object detection specific infer method.""" + + task_type = "object-detection" + box_format = "xywh" + + def infer( + self, + image: Any, + class_agnostic_nms: bool = DEFAULT_CLASS_AGNOSTIC_NMS, + confidence: float = DEFAULT_CONFIDENCE, + disable_preproc_auto_orient: bool = False, + disable_preproc_contrast: bool = False, + disable_preproc_grayscale: bool = False, + disable_preproc_static_crop: bool = False, + iou_threshold: float = DEFAULT_IOU_THRESH, + fix_batch_size: bool = False, + max_candidates: int = DEFAULT_MAX_CANDIDATES, + max_detections: int = DEFAUlT_MAX_DETECTIONS, + return_image_dims: bool = False, + **kwargs, + ) -> Any: + """ + Runs object detection inference on one or multiple images and returns the detections. + + Args: + image (Any): The input image or a list of images to process. + class_agnostic_nms (bool, optional): Whether to use class-agnostic non-maximum suppression. Defaults to False. + confidence (float, optional): Confidence threshold for predictions. Defaults to 0.5. + iou_threshold (float, optional): IoU threshold for non-maximum suppression. Defaults to 0.5. + fix_batch_size (bool, optional): If True, fix the batch size for predictions. Useful when the model requires a fixed batch size. Defaults to False. + max_candidates (int, optional): Maximum number of candidate detections. Defaults to 3000. + max_detections (int, optional): Maximum number of detections after non-maximum suppression. Defaults to 300. + return_image_dims (bool, optional): Whether to return the dimensions of the processed images along with the predictions. Defaults to False. + disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False. + disable_preproc_contrast (bool, optional): If true, the auto contrast preprocessing step is disabled for this call. Default is False. + disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False. + disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + Union[List[ObjectDetectionInferenceResponse], ObjectDetectionInferenceResponse]: One or multiple object detection inference responses based on the number of processed images. Each response contains a list of predictions. If `return_image_dims` is True, it will return a tuple with predictions and image dimensions. + + Raises: + ValueError: If batching is not enabled for the model and more than one image is passed for processing. + """ + return super().infer( + image, + class_agnostic_nms=class_agnostic_nms, + confidence=confidence, + disable_preproc_auto_orient=disable_preproc_auto_orient, + disable_preproc_contrast=disable_preproc_contrast, + disable_preproc_grayscale=disable_preproc_grayscale, + disable_preproc_static_crop=disable_preproc_static_crop, + iou_threshold=iou_threshold, + fix_batch_size=fix_batch_size, + max_candidates=max_candidates, + max_detections=max_detections, + return_image_dims=return_image_dims, + **kwargs, + ) + + def make_response( + self, + predictions: List[List[float]], + img_dims: List[Tuple[int, int]], + class_filter: Optional[List[str]] = None, + *args, + **kwargs, + ) -> List[ObjectDetectionInferenceResponse]: + """Constructs object detection response objects based on predictions. + + Args: + predictions (List[List[float]]): The list of predictions. + img_dims (List[Tuple[int, int]]): Dimensions of the images. + class_filter (Optional[List[str]]): A list of class names to filter, if provided. + + Returns: + List[ObjectDetectionInferenceResponse]: A list of response objects containing object detection predictions. + """ + + if isinstance(img_dims, dict) and "img_dims" in img_dims: + img_dims = img_dims["img_dims"] + + predictions = predictions[ + : len(img_dims) + ] # If the batch size was fixed we have empty preds at the end + responses = [ + ObjectDetectionInferenceResponse( + predictions=[ + ObjectDetectionPrediction( + # Passing args as a dictionary here since one of the args is 'class' (a protected term in Python) + **{ + "x": (pred[0] + pred[2]) / 2, + "y": (pred[1] + pred[3]) / 2, + "width": pred[2] - pred[0], + "height": pred[3] - pred[1], + "confidence": pred[4], + "class": self.class_names[int(pred[6])], + "class_id": int(pred[6]), + } + ) + for pred in batch_predictions + if not class_filter + or self.class_names[int(pred[6])] in class_filter + ], + image=InferenceResponseImage( + width=img_dims[ind][1], height=img_dims[ind][0] + ), + ) + for ind, batch_predictions in enumerate(predictions) + ] + return responses + + def postprocess( + self, + predictions: Tuple[np.ndarray, ...], + preproc_return_metadata: PreprocessReturnMetadata, + class_agnostic_nms=DEFAULT_CLASS_AGNOSTIC_NMS, + confidence: float = DEFAULT_CONFIDENCE, + iou_threshold: float = DEFAULT_IOU_THRESH, + max_candidates: int = DEFAULT_MAX_CANDIDATES, + max_detections: int = DEFAUlT_MAX_DETECTIONS, + return_image_dims: bool = False, + **kwargs, + ) -> List[ObjectDetectionInferenceResponse]: + """Postprocesses the object detection predictions. + + Args: + predictions (np.ndarray): Raw predictions from the model. + img_dims (List[Tuple[int, int]]): Dimensions of the images. + class_agnostic_nms (bool): Whether to apply class-agnostic non-max suppression. Default is False. + confidence (float): Confidence threshold for filtering detections. Default is 0.5. + iou_threshold (float): IoU threshold for non-max suppression. Default is 0.5. + max_candidates (int): Maximum number of candidate detections. Default is 3000. + max_detections (int): Maximum number of final detections. Default is 300. + + Returns: + List[ObjectDetectionInferenceResponse]: The post-processed predictions. + """ + predictions = predictions[0] + + predictions = w_np_non_max_suppression( + predictions, + conf_thresh=confidence, + iou_thresh=iou_threshold, + class_agnostic=class_agnostic_nms, + max_detections=max_detections, + max_candidate_detections=max_candidates, + box_format=self.box_format, + ) + + infer_shape = (self.img_size_h, self.img_size_w) + img_dims = preproc_return_metadata["img_dims"] + predictions = post_process_bboxes( + predictions, + infer_shape, + img_dims, + self.preproc, + resize_method=self.resize_method, + disable_preproc_static_crop=preproc_return_metadata[ + "disable_preproc_static_crop" + ], + ) + return self.make_response(predictions, img_dims, **kwargs) + + def preprocess( + self, + image: Any, + disable_preproc_auto_orient: bool = False, + disable_preproc_contrast: bool = False, + disable_preproc_grayscale: bool = False, + disable_preproc_static_crop: bool = False, + fix_batch_size: bool = False, + **kwargs, + ) -> Tuple[np.ndarray, PreprocessReturnMetadata]: + """Preprocesses an object detection inference request. + + Args: + request (ObjectDetectionInferenceRequest): The request object containing images. + + Returns: + Tuple[np.ndarray, List[Tuple[int, int]]]: Preprocessed image inputs and corresponding dimensions. + """ + img_in, img_dims = self.load_image( + image, + disable_preproc_auto_orient=disable_preproc_auto_orient, + disable_preproc_contrast=disable_preproc_contrast, + disable_preproc_grayscale=disable_preproc_grayscale, + disable_preproc_static_crop=disable_preproc_static_crop, + ) + + img_in /= 255.0 + + if self.batching_enabled: + batch_padding = 0 + if FIX_BATCH_SIZE or fix_batch_size: + if MAX_BATCH_SIZE == float("inf"): + logger.warn( + "Requested fix_batch_size but MAX_BATCH_SIZE is not set. Using dynamic batching." + ) + batch_padding = 0 + else: + batch_padding = MAX_BATCH_SIZE - img_in.shape[0] + if batch_padding < 0: + raise ValueError( + f"Requested fix_batch_size but passed in {img_in.shape[0]} images " + f"when the model's batch size is {MAX_BATCH_SIZE}\n" + f"Consider turning off fix_batch_size, changing `MAX_BATCH_SIZE` in" + f"your inference server config, or passing at most {MAX_BATCH_SIZE} images at a time" + ) + width_remainder = img_in.shape[2] % 32 + height_remainder = img_in.shape[3] % 32 + if width_remainder > 0: + width_padding = 32 - (img_in.shape[2] % 32) + else: + width_padding = 0 + if height_remainder > 0: + height_padding = 32 - (img_in.shape[3] % 32) + else: + height_padding = 0 + img_in = np.pad( + img_in, + ((0, batch_padding), (0, 0), (0, width_padding), (0, height_padding)), + "constant", + ) + + return img_in, PreprocessReturnMetadata( + { + "img_dims": img_dims, + "disable_preproc_static_crop": disable_preproc_static_crop, + } + ) + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]: + """Runs inference on the ONNX model. + + Args: + img_in (np.ndarray): The preprocessed image(s) to run inference on. + + Returns: + Tuple[np.ndarray]: The ONNX model predictions. + + Raises: + NotImplementedError: This method must be implemented by a subclass. + """ + raise NotImplementedError("predict must be implemented by a subclass") + + def validate_model_classes(self) -> None: + output_shape = self.get_model_output_shape() + num_classes = get_num_classes_from_model_prediction_shape( + output_shape[2], masks=0 + ) + try: + assert num_classes == self.num_classes + except AssertionError: + raise ValueError( + f"Number of classes in model ({num_classes}) does not match the number of classes in the environment ({self.num_classes})" + ) diff --git a/inference/core/models/roboflow.py b/inference/core/models/roboflow.py new file mode 100644 index 0000000000000000000000000000000000000000..95e931c8e6c041ed2d976fa476cf66df8d4d0b68 --- /dev/null +++ b/inference/core/models/roboflow.py @@ -0,0 +1,858 @@ +import itertools +import json +import os +from collections import OrderedDict +from concurrent.futures import ThreadPoolExecutor +from functools import partial +from time import perf_counter +from typing import Any, Dict, List, Optional, Tuple, Union + +import cv2 +import numpy as np +import onnxruntime +from PIL import Image + +from inference.core.cache import cache +from inference.core.cache.model_artifacts import ( + are_all_files_cached, + clear_cache, + get_cache_dir, + get_cache_file_path, + initialise_cache, + load_json_from_cache, + load_text_file_from_cache, + save_bytes_in_cache, + save_json_in_cache, + save_text_lines_in_cache, +) +from inference.core.devices.utils import GLOBAL_DEVICE_ID +from inference.core.entities.requests.inference import ( + InferenceRequest, + InferenceRequestImage, +) +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.env import ( + API_KEY, + API_KEY_ENV_NAMES, + AWS_ACCESS_KEY_ID, + AWS_SECRET_ACCESS_KEY, + CORE_MODEL_BUCKET, + DISABLE_PREPROC_AUTO_ORIENT, + INFER_BUCKET, + LAMBDA, + MAX_BATCH_SIZE, + MODEL_CACHE_DIR, + ONNXRUNTIME_EXECUTION_PROVIDERS, + REQUIRED_ONNX_PROVIDERS, + TENSORRT_CACHE_PATH, +) +from inference.core.exceptions import ( + MissingApiKeyError, + ModelArtefactError, + OnnxProviderNotAvailable, +) +from inference.core.logger import logger +from inference.core.models.base import Model +from inference.core.models.utils.batching import ( + calculate_input_elements, + create_batches, +) +from inference.core.roboflow_api import ( + ModelEndpointType, + get_from_url, + get_roboflow_model_data, +) +from inference.core.utils.image_utils import load_image +from inference.core.utils.onnx import get_onnxruntime_execution_providers +from inference.core.utils.preprocess import letterbox_image, prepare +from inference.core.utils.visualisation import draw_detection_predictions +from inference.models.aliases import resolve_roboflow_model_alias + +NUM_S3_RETRY = 5 +SLEEP_SECONDS_BETWEEN_RETRIES = 3 +MODEL_METADATA_CACHE_EXPIRATION_TIMEOUT = 3600 # 1 hour + +S3_CLIENT = None +if AWS_ACCESS_KEY_ID and AWS_ACCESS_KEY_ID: + try: + import boto3 + from botocore.config import Config + + from inference.core.utils.s3 import download_s3_files_to_directory + + config = Config(retries={"max_attempts": NUM_S3_RETRY, "mode": "standard"}) + S3_CLIENT = boto3.client("s3", config=config) + except: + logger.debug("Error loading boto3") + pass + +DEFAULT_COLOR_PALETTE = [ + "#4892EA", + "#00EEC3", + "#FE4EF0", + "#F4004E", + "#FA7200", + "#EEEE17", + "#90FF00", + "#78C1D2", + "#8C29FF", +] + + +class RoboflowInferenceModel(Model): + """Base Roboflow inference model.""" + + def __init__( + self, + model_id: str, + cache_dir_root=MODEL_CACHE_DIR, + api_key=None, + load_weights=True, + ): + """ + Initialize the RoboflowInferenceModel object. + + Args: + model_id (str): The unique identifier for the model. + cache_dir_root (str, optional): The root directory for the cache. Defaults to MODEL_CACHE_DIR. + api_key (str, optional): API key for authentication. Defaults to None. + """ + super().__init__() + self.load_weights = load_weights + self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0} + self.api_key = api_key if api_key else API_KEY + model_id = resolve_roboflow_model_alias(model_id=model_id) + self.dataset_id, self.version_id = model_id.split("/") + self.endpoint = model_id + self.device_id = GLOBAL_DEVICE_ID + self.cache_dir = os.path.join(cache_dir_root, self.endpoint) + self.keypoints_metadata: Optional[dict] = None + initialise_cache(model_id=self.endpoint) + + def cache_file(self, f: str) -> str: + """Get the cache file path for a given file. + + Args: + f (str): Filename. + + Returns: + str: Full path to the cached file. + """ + return get_cache_file_path(file=f, model_id=self.endpoint) + + def clear_cache(self) -> None: + """Clear the cache directory.""" + clear_cache(model_id=self.endpoint) + + def draw_predictions( + self, + inference_request: InferenceRequest, + inference_response: InferenceResponse, + ) -> bytes: + """Draw predictions from an inference response onto the original image provided by an inference request + + Args: + inference_request (ObjectDetectionInferenceRequest): The inference request containing the image on which to draw predictions + inference_response (ObjectDetectionInferenceResponse): The inference response containing predictions to be drawn + + Returns: + str: A base64 encoded image string + """ + return draw_detection_predictions( + inference_request=inference_request, + inference_response=inference_response, + colors=self.colors, + ) + + @property + def get_class_names(self): + return self.class_names + + def get_device_id(self) -> str: + """ + Get the device identifier on which the model is deployed. + + Returns: + str: Device identifier. + """ + return self.device_id + + def get_infer_bucket_file_list(self) -> List[str]: + """Get a list of inference bucket files. + + Raises: + NotImplementedError: If the method is not implemented. + + Returns: + List[str]: A list of inference bucket files. + """ + raise NotImplementedError( + self.__class__.__name__ + ".get_infer_bucket_file_list" + ) + + @property + def cache_key(self): + return f"metadata:{self.endpoint}" + + @staticmethod + def model_metadata_from_memcache_endpoint(endpoint): + model_metadata = cache.get(f"metadata:{endpoint}") + return model_metadata + + def model_metadata_from_memcache(self): + model_metadata = cache.get(self.cache_key) + return model_metadata + + def write_model_metadata_to_memcache(self, metadata): + cache.set( + self.cache_key, metadata, expire=MODEL_METADATA_CACHE_EXPIRATION_TIMEOUT + ) + + @property + def has_model_metadata(self): + return self.model_metadata_from_memcache() is not None + + def get_model_artifacts(self) -> None: + """Fetch or load the model artifacts. + + Downloads the model artifacts from S3 or the Roboflow API if they are not already cached. + """ + self.cache_model_artefacts() + self.load_model_artifacts_from_cache() + + def cache_model_artefacts(self) -> None: + infer_bucket_files = self.get_all_required_infer_bucket_file() + if are_all_files_cached(files=infer_bucket_files, model_id=self.endpoint): + return None + if is_model_artefacts_bucket_available(): + self.download_model_artefacts_from_s3() + return None + self.download_model_artifacts_from_roboflow_api() + + def get_all_required_infer_bucket_file(self) -> List[str]: + infer_bucket_files = self.get_infer_bucket_file_list() + infer_bucket_files.append(self.weights_file) + logger.debug(f"List of files required to load model: {infer_bucket_files}") + return [f for f in infer_bucket_files if f is not None] + + def download_model_artefacts_from_s3(self) -> None: + try: + logger.debug("Downloading model artifacts from S3") + infer_bucket_files = self.get_all_required_infer_bucket_file() + cache_directory = get_cache_dir() + s3_keys = [f"{self.endpoint}/{file}" for file in infer_bucket_files] + download_s3_files_to_directory( + bucket=self.model_artifact_bucket, + keys=s3_keys, + target_dir=cache_directory, + s3_client=S3_CLIENT, + ) + except Exception as error: + raise ModelArtefactError( + f"Could not obtain model artefacts from S3 with keys {s3_keys}. Cause: {error}" + ) from error + + @property + def model_artifact_bucket(self): + return INFER_BUCKET + + def download_model_artifacts_from_roboflow_api(self) -> None: + logger.debug("Downloading model artifacts from Roboflow API") + api_data = get_roboflow_model_data( + api_key=self.api_key, + model_id=self.endpoint, + endpoint_type=ModelEndpointType.ORT, + device_id=self.device_id, + ) + if "ort" not in api_data.keys(): + raise ModelArtefactError( + "Could not find `ort` key in roboflow API model description response." + ) + api_data = api_data["ort"] + if "classes" in api_data: + save_text_lines_in_cache( + content=api_data["classes"], + file="class_names.txt", + model_id=self.endpoint, + ) + if "model" not in api_data: + raise ModelArtefactError( + "Could not find `model` key in roboflow API model description response." + ) + if "environment" not in api_data: + raise ModelArtefactError( + "Could not find `environment` key in roboflow API model description response." + ) + environment = get_from_url(api_data["environment"]) + model_weights_response = get_from_url(api_data["model"], json_response=False) + save_bytes_in_cache( + content=model_weights_response.content, + file=self.weights_file, + model_id=self.endpoint, + ) + if "colors" in api_data: + environment["COLORS"] = api_data["colors"] + save_json_in_cache( + content=environment, + file="environment.json", + model_id=self.endpoint, + ) + if "keypoints_metadata" in api_data: + # TODO: make sure backend provides that + save_json_in_cache( + content=api_data["keypoints_metadata"], + file="keypoints_metadata.json", + model_id=self.endpoint, + ) + + def load_model_artifacts_from_cache(self) -> None: + logger.debug("Model artifacts already downloaded, loading model from cache") + infer_bucket_files = self.get_all_required_infer_bucket_file() + if "environment.json" in infer_bucket_files: + self.environment = load_json_from_cache( + file="environment.json", + model_id=self.endpoint, + object_pairs_hook=OrderedDict, + ) + if "class_names.txt" in infer_bucket_files: + self.class_names = load_text_file_from_cache( + file="class_names.txt", + model_id=self.endpoint, + split_lines=True, + strip_white_chars=True, + ) + else: + self.class_names = get_class_names_from_environment_file( + environment=self.environment + ) + self.colors = get_color_mapping_from_environment( + environment=self.environment, + class_names=self.class_names, + ) + if "keypoints_metadata.json" in infer_bucket_files: + self.keypoints_metadata = parse_keypoints_metadata( + load_json_from_cache( + file="keypoints_metadata.json", + model_id=self.endpoint, + object_pairs_hook=OrderedDict, + ) + ) + self.num_classes = len(self.class_names) + if "PREPROCESSING" not in self.environment: + raise ModelArtefactError( + "Could not find `PREPROCESSING` key in environment file." + ) + if issubclass(type(self.environment["PREPROCESSING"]), dict): + self.preproc = self.environment["PREPROCESSING"] + else: + self.preproc = json.loads(self.environment["PREPROCESSING"]) + if self.preproc.get("resize"): + self.resize_method = self.preproc["resize"].get("format", "Stretch to") + if self.resize_method not in [ + "Stretch to", + "Fit (black edges) in", + "Fit (white edges) in", + ]: + self.resize_method = "Stretch to" + else: + self.resize_method = "Stretch to" + logger.debug(f"Resize method is '{self.resize_method}'") + self.multiclass = self.environment.get("MULTICLASS", False) + + def initialize_model(self) -> None: + """Initialize the model. + + Raises: + NotImplementedError: If the method is not implemented. + """ + raise NotImplementedError(self.__class__.__name__ + ".initialize_model") + + def preproc_image( + self, + image: Union[Any, InferenceRequestImage], + disable_preproc_auto_orient: bool = False, + disable_preproc_contrast: bool = False, + disable_preproc_grayscale: bool = False, + disable_preproc_static_crop: bool = False, + ) -> Tuple[np.ndarray, Tuple[int, int]]: + """ + Preprocesses an inference request image by loading it, then applying any pre-processing specified by the Roboflow platform, then scaling it to the inference input dimensions. + + Args: + image (Union[Any, InferenceRequestImage]): An object containing information necessary to load the image for inference. + disable_preproc_auto_orient (bool, optional): If true, the auto orient preprocessing step is disabled for this call. Default is False. + disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False. + disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False. + disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False. + + Returns: + Tuple[np.ndarray, Tuple[int, int]]: A tuple containing a numpy array of the preprocessed image pixel data and a tuple of the images original size. + """ + np_image, is_bgr = load_image( + image, + disable_preproc_auto_orient=disable_preproc_auto_orient + or "auto-orient" not in self.preproc.keys() + or DISABLE_PREPROC_AUTO_ORIENT, + ) + preprocessed_image, img_dims = self.preprocess_image( + np_image, + disable_preproc_contrast=disable_preproc_contrast, + disable_preproc_grayscale=disable_preproc_grayscale, + disable_preproc_static_crop=disable_preproc_static_crop, + ) + + if self.resize_method == "Stretch to": + resized = cv2.resize( + preprocessed_image, (self.img_size_w, self.img_size_h), cv2.INTER_CUBIC + ) + elif self.resize_method == "Fit (black edges) in": + resized = letterbox_image( + preprocessed_image, (self.img_size_w, self.img_size_h) + ) + elif self.resize_method == "Fit (white edges) in": + resized = letterbox_image( + preprocessed_image, + (self.img_size_w, self.img_size_h), + color=(255, 255, 255), + ) + + if is_bgr: + resized = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB) + img_in = np.transpose(resized, (2, 0, 1)) + img_in = img_in.astype(np.float32) + img_in = np.expand_dims(img_in, axis=0) + + return img_in, img_dims + + def preprocess_image( + self, + image: np.ndarray, + disable_preproc_contrast: bool = False, + disable_preproc_grayscale: bool = False, + disable_preproc_static_crop: bool = False, + ) -> Tuple[np.ndarray, Tuple[int, int]]: + """ + Preprocesses the given image using specified preprocessing steps. + + Args: + image (Image.Image): The PIL image to preprocess. + disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False. + disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False. + disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False. + + Returns: + Image.Image: The preprocessed PIL image. + """ + return prepare( + image, + self.preproc, + disable_preproc_contrast=disable_preproc_contrast, + disable_preproc_grayscale=disable_preproc_grayscale, + disable_preproc_static_crop=disable_preproc_static_crop, + ) + + @property + def weights_file(self) -> str: + """Abstract property representing the file containing the model weights. + + Raises: + NotImplementedError: This property must be implemented in subclasses. + + Returns: + str: The file path to the weights file. + """ + raise NotImplementedError(self.__class__.__name__ + ".weights_file") + + +class RoboflowCoreModel(RoboflowInferenceModel): + """Base Roboflow inference model (Inherits from CvModel since all Roboflow models are CV models currently).""" + + def __init__( + self, + model_id: str, + api_key=None, + ): + """Initializes the RoboflowCoreModel instance. + + Args: + model_id (str): The identifier for the specific model. + api_key ([type], optional): The API key for authentication. Defaults to None. + """ + super().__init__(model_id, api_key=api_key) + self.download_weights() + + def download_weights(self) -> None: + """Downloads the model weights from the configured source. + + This method includes handling for AWS access keys and error handling. + """ + infer_bucket_files = self.get_infer_bucket_file_list() + if are_all_files_cached(files=infer_bucket_files, model_id=self.endpoint): + logger.debug("Model artifacts already downloaded, loading from cache") + return None + if is_model_artefacts_bucket_available(): + self.download_model_artefacts_from_s3() + return None + self.download_model_from_roboflow_api() + + def download_model_from_roboflow_api(self) -> None: + api_data = get_roboflow_model_data( + api_key=self.api_key, + model_id=self.endpoint, + endpoint_type=ModelEndpointType.CORE_MODEL, + device_id=self.device_id, + ) + if "weights" not in api_data: + raise ModelArtefactError( + f"`weights` key not available in Roboflow API response while downloading model weights." + ) + for weights_url_key in api_data["weights"]: + weights_url = api_data["weights"][weights_url_key] + t1 = perf_counter() + model_weights_response = get_from_url(weights_url, json_response=False) + filename = weights_url.split("?")[0].split("/")[-1] + save_bytes_in_cache( + content=model_weights_response.content, + file=filename, + model_id=self.endpoint, + ) + if perf_counter() - t1 > 120: + logger.debug( + "Weights download took longer than 120 seconds, refreshing API request" + ) + api_data = get_roboflow_model_data( + api_key=self.api_key, + model_id=self.endpoint, + endpoint_type=ModelEndpointType.CORE_MODEL, + device_id=self.device_id, + ) + + def get_device_id(self) -> str: + """Returns the device ID associated with this model. + + Returns: + str: The device ID. + """ + return self.device_id + + def get_infer_bucket_file_list(self) -> List[str]: + """Abstract method to get the list of files to be downloaded from the inference bucket. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + + Returns: + List[str]: A list of filenames. + """ + raise NotImplementedError( + "get_infer_bucket_file_list not implemented for OnnxRoboflowCoreModel" + ) + + def preprocess_image(self, image: Image.Image) -> Image.Image: + """Abstract method to preprocess an image. + + Raises: + NotImplementedError: This method must be implemented in subclasses. + + Returns: + Image.Image: The preprocessed PIL image. + """ + raise NotImplementedError(self.__class__.__name__ + ".preprocess_image") + + @property + def weights_file(self) -> str: + """Abstract property representing the file containing the model weights. For core models, all model artifacts are handled through get_infer_bucket_file_list method.""" + return None + + @property + def model_artifact_bucket(self): + return CORE_MODEL_BUCKET + + +class OnnxRoboflowInferenceModel(RoboflowInferenceModel): + """Roboflow Inference Model that operates using an ONNX model file.""" + + def __init__( + self, + model_id: str, + onnxruntime_execution_providers: List[ + str + ] = get_onnxruntime_execution_providers(ONNXRUNTIME_EXECUTION_PROVIDERS), + *args, + **kwargs, + ): + """Initializes the OnnxRoboflowInferenceModel instance. + + Args: + model_id (str): The identifier for the specific ONNX model. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + super().__init__(model_id, *args, **kwargs) + if self.load_weights or not self.has_model_metadata: + self.onnxruntime_execution_providers = onnxruntime_execution_providers + for ep in self.onnxruntime_execution_providers: + if ep == "TensorrtExecutionProvider": + ep = ( + "TensorrtExecutionProvider", + { + "trt_engine_cache_enable": True, + "trt_engine_cache_path": os.path.join( + TENSORRT_CACHE_PATH, self.endpoint + ), + "trt_fp16_enable": True, + }, + ) + self.initialize_model() + self.image_loader_threadpool = ThreadPoolExecutor(max_workers=None) + try: + self.validate_model() + except ModelArtefactError as e: + logger.error(f"Unable to validate model artifacts, clearing cache: {e}") + self.clear_cache() + raise ModelArtefactError from e + + def infer(self, image: Any, **kwargs) -> Any: + input_elements = calculate_input_elements(input_value=image) + max_batch_size = MAX_BATCH_SIZE if self.batching_enabled else self.batch_size + if (input_elements == 1) or (max_batch_size == float("inf")): + return super().infer(image, **kwargs) + logger.debug( + f"Inference will be executed in batches, as there is {input_elements} input elements and " + f"maximum batch size for a model is set to: {max_batch_size}" + ) + inference_results = [] + for batch_input in create_batches(sequence=image, batch_size=max_batch_size): + batch_inference_results = super().infer(batch_input, **kwargs) + inference_results.append(batch_inference_results) + return self.merge_inference_results(inference_results=inference_results) + + def merge_inference_results(self, inference_results: List[Any]) -> Any: + return list(itertools.chain(*inference_results)) + + def validate_model(self) -> None: + if not self.load_weights: + return + try: + assert self.onnx_session is not None + except AssertionError as e: + raise ModelArtefactError( + "ONNX session not initialized. Check that the model weights are available." + ) from e + try: + self.run_test_inference() + except Exception as e: + raise ModelArtefactError(f"Unable to run test inference. Cause: {e}") from e + try: + self.validate_model_classes() + except Exception as e: + raise ModelArtefactError( + f"Unable to validate model classes. Cause: {e}" + ) from e + + def run_test_inference(self) -> None: + test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8) + return self.infer(test_image) + + def get_model_output_shape(self) -> Tuple[int, int, int]: + test_image = (np.random.rand(1024, 1024, 3) * 255).astype(np.uint8) + test_image, _ = self.preprocess(test_image) + output = self.predict(test_image)[0] + return output.shape + + def validate_model_classes(self) -> None: + pass + + def get_infer_bucket_file_list(self) -> list: + """Returns the list of files to be downloaded from the inference bucket for ONNX model. + + Returns: + list: A list of filenames specific to ONNX models. + """ + return ["environment.json", "class_names.txt"] + + def initialize_model(self) -> None: + """Initializes the ONNX model, setting up the inference session and other necessary properties.""" + self.get_model_artifacts() + logger.debug("Creating inference session") + if self.load_weights or not self.has_model_metadata: + t1_session = perf_counter() + # Create an ONNX Runtime Session with a list of execution providers in priority order. ORT attempts to load providers until one is successful. This keeps the code across devices identical. + providers = self.onnxruntime_execution_providers + if not self.load_weights: + providers = ["CPUExecutionProvider"] + try: + self.onnx_session = onnxruntime.InferenceSession( + self.cache_file(self.weights_file), + providers=providers, + ) + except Exception as e: + self.clear_cache() + raise ModelArtefactError( + f"Unable to load ONNX session. Cause: {e}" + ) from e + logger.debug(f"Session created in {perf_counter() - t1_session} seconds") + + if REQUIRED_ONNX_PROVIDERS: + available_providers = onnxruntime.get_available_providers() + for provider in REQUIRED_ONNX_PROVIDERS: + if provider not in available_providers: + raise OnnxProviderNotAvailable( + f"Required ONNX Execution Provider {provider} is not availble. Check that you are using the correct docker image on a supported device." + ) + + inputs = self.onnx_session.get_inputs()[0] + input_shape = inputs.shape + self.batch_size = input_shape[0] + self.img_size_h = input_shape[2] + self.img_size_w = input_shape[3] + self.input_name = inputs.name + if isinstance(self.img_size_h, str) or isinstance(self.img_size_w, str): + if "resize" in self.preproc: + self.img_size_h = int(self.preproc["resize"]["height"]) + self.img_size_w = int(self.preproc["resize"]["width"]) + else: + self.img_size_h = 640 + self.img_size_w = 640 + + if isinstance(self.batch_size, str): + self.batching_enabled = True + logger.debug( + f"Model {self.endpoint} is loaded with dynamic batching enabled" + ) + else: + self.batching_enabled = False + logger.debug( + f"Model {self.endpoint} is loaded with dynamic batching disabled" + ) + + model_metadata = { + "batch_size": self.batch_size, + "img_size_h": self.img_size_h, + "img_size_w": self.img_size_w, + } + logger.debug(f"Writing model metadata to memcache") + self.write_model_metadata_to_memcache(model_metadata) + if not self.load_weights: # had to load weights to get metadata + del self.onnx_session + else: + if not self.has_model_metadata: + raise ValueError( + "This should be unreachable, should get weights if we don't have model metadata" + ) + logger.debug(f"Loading model metadata from memcache") + metadata = self.model_metadata_from_memcache() + self.batch_size = metadata["batch_size"] + self.img_size_h = metadata["img_size_h"] + self.img_size_w = metadata["img_size_w"] + if isinstance(self.batch_size, str): + self.batching_enabled = True + logger.debug( + f"Model {self.endpoint} is loaded with dynamic batching enabled" + ) + else: + self.batching_enabled = False + logger.debug( + f"Model {self.endpoint} is loaded with dynamic batching disabled" + ) + + def load_image( + self, + image: Any, + disable_preproc_auto_orient: bool = False, + disable_preproc_contrast: bool = False, + disable_preproc_grayscale: bool = False, + disable_preproc_static_crop: bool = False, + ) -> Tuple[np.ndarray, Tuple[int, int]]: + if isinstance(image, list): + preproc_image = partial( + self.preproc_image, + disable_preproc_auto_orient=disable_preproc_auto_orient, + disable_preproc_contrast=disable_preproc_contrast, + disable_preproc_grayscale=disable_preproc_grayscale, + disable_preproc_static_crop=disable_preproc_static_crop, + ) + imgs_with_dims = self.image_loader_threadpool.map(preproc_image, image) + imgs, img_dims = zip(*imgs_with_dims) + img_in = np.concatenate(imgs, axis=0) + else: + img_in, img_dims = self.preproc_image( + image, + disable_preproc_auto_orient=disable_preproc_auto_orient, + disable_preproc_contrast=disable_preproc_contrast, + disable_preproc_grayscale=disable_preproc_grayscale, + disable_preproc_static_crop=disable_preproc_static_crop, + ) + img_dims = [img_dims] + return img_in, img_dims + + @property + def weights_file(self) -> str: + """Returns the file containing the ONNX model weights. + + Returns: + str: The file path to the weights file. + """ + return "weights.onnx" + + +class OnnxRoboflowCoreModel(RoboflowCoreModel): + """Roboflow Inference Model that operates using an ONNX model file.""" + + pass + + +def get_class_names_from_environment_file(environment: Optional[dict]) -> List[str]: + if environment is None: + raise ModelArtefactError( + f"Missing environment while attempting to get model class names." + ) + if class_mapping_not_available_in_environment(environment=environment): + raise ModelArtefactError( + f"Missing `CLASS_MAP` in environment or `CLASS_MAP` is not dict." + ) + class_names = [] + for i in range(len(environment["CLASS_MAP"].keys())): + class_names.append(environment["CLASS_MAP"][str(i)]) + return class_names + + +def class_mapping_not_available_in_environment(environment: dict) -> bool: + return "CLASS_MAP" not in environment or not issubclass( + type(environment["CLASS_MAP"]), dict + ) + + +def get_color_mapping_from_environment( + environment: Optional[dict], class_names: List[str] +) -> Dict[str, str]: + if color_mapping_available_in_environment(environment=environment): + return environment["COLORS"] + return { + class_name: DEFAULT_COLOR_PALETTE[i % len(DEFAULT_COLOR_PALETTE)] + for i, class_name in enumerate(class_names) + } + + +def color_mapping_available_in_environment(environment: Optional[dict]) -> bool: + return ( + environment is not None + and "COLORS" in environment + and issubclass(type(environment["COLORS"]), dict) + ) + + +def is_model_artefacts_bucket_available() -> bool: + return ( + AWS_ACCESS_KEY_ID is not None + and AWS_SECRET_ACCESS_KEY is not None + and LAMBDA + and S3_CLIENT is not None + ) + + +def parse_keypoints_metadata(metadata: list) -> dict: + return { + e["object_class_id"]: {int(key): value for key, value in e["keypoints"].items()} + for e in metadata + } diff --git a/inference/core/models/stubs.py b/inference/core/models/stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..9c31f72d122d8ebc3b087c79db5ac2461458c0f3 --- /dev/null +++ b/inference/core/models/stubs.py @@ -0,0 +1,135 @@ +from abc import abstractmethod +from time import perf_counter +from typing import Any, List, Tuple, Union + +import numpy as np + +from inference.core.cache.model_artifacts import clear_cache, initialise_cache +from inference.core.entities.requests.inference import InferenceRequest +from inference.core.entities.responses.inference import InferenceResponse, StubResponse +from inference.core.models.base import Model +from inference.core.models.types import PreprocessReturnMetadata +from inference.core.utils.image_utils import np_image_to_base64 + + +class ModelStub(Model): + def __init__(self, model_id: str, api_key: str): + super().__init__() + self.model_id = model_id + self.api_key = api_key + self.dataset_id, self.version_id = model_id.split("/") + self.metrics = {"num_inferences": 0, "avg_inference_time": 0.0} + initialise_cache(model_id=model_id) + + def infer_from_request( + self, request: InferenceRequest + ) -> Union[InferenceResponse, List[InferenceResponse]]: + t1 = perf_counter() + stub_prediction = self.infer(**request.dict()) + response = self.make_response(request=request, prediction=stub_prediction) + response.time = perf_counter() - t1 + return response + + def infer(self, *args, **kwargs) -> Any: + _ = self.preprocess() + dummy_prediction = self.predict() + return self.postprocess(dummy_prediction) + + def preprocess( + self, *args, **kwargs + ) -> Tuple[np.ndarray, PreprocessReturnMetadata]: + return np.zeros((128, 128, 3), dtype=np.uint8), {} # type: ignore + + def predict(self, *args, **kwargs) -> Tuple[np.ndarray, ...]: + return (np.zeros((1, 8)),) + + def postprocess(self, predictions: Tuple[np.ndarray, ...], *args, **kwargs) -> Any: + return { + "is_stub": True, + "model_id": self.model_id, + } + + def clear_cache(self) -> None: + clear_cache(model_id=self.model_id) + + @abstractmethod + def make_response( + self, request: InferenceRequest, prediction: dict, **kwargs + ) -> Union[InferenceResponse, List[InferenceResponse]]: + pass + + +class ClassificationModelStub(ModelStub): + task_type = "classification" + + def make_response( + self, request: InferenceRequest, prediction: dict, **kwargs + ) -> Union[InferenceResponse, List[InferenceResponse]]: + stub_visualisation = None + if getattr(request, "visualize_predictions", False): + stub_visualisation = np_image_to_base64( + np.zeros((128, 128, 3), dtype=np.uint8) + ) + return StubResponse( + is_stub=prediction["is_stub"], + model_id=prediction["model_id"], + task_type=self.task_type, + visualization=stub_visualisation, + ) + + +class ObjectDetectionModelStub(ModelStub): + task_type = "object-detection" + + def make_response( + self, request: InferenceRequest, prediction: dict, **kwargs + ) -> Union[InferenceResponse, List[InferenceResponse]]: + stub_visualisation = None + if getattr(request, "visualize_predictions", False): + stub_visualisation = np_image_to_base64( + np.zeros((128, 128, 3), dtype=np.uint8) + ) + return StubResponse( + is_stub=prediction["is_stub"], + model_id=prediction["model_id"], + task_type=self.task_type, + visualization=stub_visualisation, + ) + + +class InstanceSegmentationModelStub(ModelStub): + task_type = "instance-segmentation" + + def make_response( + self, request: InferenceRequest, prediction: dict, **kwargs + ) -> Union[InferenceResponse, List[InferenceResponse]]: + stub_visualisation = None + if getattr(request, "visualize_predictions", False): + stub_visualisation = np_image_to_base64( + np.zeros((128, 128, 3), dtype=np.uint8) + ) + return StubResponse( + is_stub=prediction["is_stub"], + model_id=prediction["model_id"], + task_type=self.task_type, + visualization=stub_visualisation, + ) + + +class KeypointsDetectionModelStub(ModelStub): + task_type = "keypoint-detection" + + def make_response( + self, request: InferenceRequest, prediction: dict, **kwargs + ) -> Union[InferenceResponse, List[InferenceResponse]]: + stub_visualisation = None + if getattr(request, "visualize_predictions", False): + stub_visualisation = np_image_to_base64( + np.zeros((128, 128, 3), dtype=np.uint8) + ) + return StubResponse( + is_stub=prediction["is_stub"], + model_id=prediction["model_id"], + task_type=self.task_type, + visualization=stub_visualisation, + ) diff --git a/inference/core/models/types.py b/inference/core/models/types.py new file mode 100644 index 0000000000000000000000000000000000000000..d32e17f4a782fe7d651bbe890cc14340a58594be --- /dev/null +++ b/inference/core/models/types.py @@ -0,0 +1,3 @@ +from typing import Dict, NewType + +PreprocessReturnMetadata = NewType("PreprocessReturnMetadata", Dict) diff --git a/inference/core/models/utils/__init__.py b/inference/core/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/models/utils/__pycache__/__init__.cpython-310.pyc b/inference/core/models/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c89f82467ffe656ae8703d2dc56f37984ea0abf2 Binary files /dev/null and b/inference/core/models/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/models/utils/__pycache__/batching.cpython-310.pyc b/inference/core/models/utils/__pycache__/batching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bcc1fddbc5931ce5ca928b90517659e6529e5e08 Binary files /dev/null and b/inference/core/models/utils/__pycache__/batching.cpython-310.pyc differ diff --git a/inference/core/models/utils/__pycache__/keypoints.cpython-310.pyc b/inference/core/models/utils/__pycache__/keypoints.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76d4d897f5b33d1793ff6cbc727a87f27caa88b7 Binary files /dev/null and b/inference/core/models/utils/__pycache__/keypoints.cpython-310.pyc differ diff --git a/inference/core/models/utils/__pycache__/validate.cpython-310.pyc b/inference/core/models/utils/__pycache__/validate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a716103280d2a44343278786fc1c5ee74baa82d Binary files /dev/null and b/inference/core/models/utils/__pycache__/validate.cpython-310.pyc differ diff --git a/inference/core/models/utils/batching.py b/inference/core/models/utils/batching.py new file mode 100644 index 0000000000000000000000000000000000000000..dace876ddfe7e9cc09a047586d75a03a484816e9 --- /dev/null +++ b/inference/core/models/utils/batching.py @@ -0,0 +1,21 @@ +from typing import Generator, Iterable, List, TypeVar, Union + +B = TypeVar("B") + + +def calculate_input_elements(input_value: Union[B, List[B]]) -> int: + return len(input_value) if issubclass(type(input_value), list) else 1 + + +def create_batches( + sequence: Iterable[B], batch_size: int +) -> Generator[List[B], None, None]: + batch_size = max(batch_size, 1) + current_batch = [] + for element in sequence: + if len(current_batch) == batch_size: + yield current_batch + current_batch = [] + current_batch.append(element) + if len(current_batch) > 0: + yield current_batch diff --git a/inference/core/models/utils/keypoints.py b/inference/core/models/utils/keypoints.py new file mode 100644 index 0000000000000000000000000000000000000000..ef7eb822a67b7cf540cd69a3ef4967070a45aafb --- /dev/null +++ b/inference/core/models/utils/keypoints.py @@ -0,0 +1,41 @@ +from typing import List + +from inference.core.entities.responses.inference import Keypoint +from inference.core.exceptions import ModelArtefactError + + +def superset_keypoints_count(keypoints_metadata={}) -> int: + """Returns the number of keypoints in the superset.""" + max_keypoints = 0 + for keypoints in keypoints_metadata.values(): + if len(keypoints) > max_keypoints: + max_keypoints = len(keypoints) + return max_keypoints + + +def model_keypoints_to_response( + keypoints_metadata: dict, + keypoints: List[float], + predicted_object_class_id: int, + keypoint_confidence_threshold: float, +) -> List[Keypoint]: + if keypoints_metadata is None: + raise ModelArtefactError("Keypoints metadata not available.") + keypoint_id2name = keypoints_metadata[predicted_object_class_id] + results = [] + for keypoint_id in range(len(keypoints) // 3): + if keypoint_id >= len(keypoint_id2name): + # Ultralytics only supports single class keypoint detection, so points might be padded with zeros + break + confidence = keypoints[3 * keypoint_id + 2] + if confidence < keypoint_confidence_threshold: + continue + keypoint = Keypoint( + x=keypoints[3 * keypoint_id], + y=keypoints[3 * keypoint_id + 1], + confidence=confidence, + class_id=keypoint_id, + class_name=keypoint_id2name[keypoint_id], + ) + results.append(keypoint) + return results diff --git a/inference/core/models/utils/validate.py b/inference/core/models/utils/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..d373decd00ca111ed80b35642a4773077710871a --- /dev/null +++ b/inference/core/models/utils/validate.py @@ -0,0 +1,3 @@ +def get_num_classes_from_model_prediction_shape(len_prediction, masks=0, keypoints=0): + num_classes = len_prediction - 5 - masks - (keypoints * 3) + return num_classes diff --git a/inference/core/nms.py b/inference/core/nms.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe8756a51b085b9f9223dd5e6eb7dd4c04a4aa0 --- /dev/null +++ b/inference/core/nms.py @@ -0,0 +1,157 @@ +from typing import Optional + +import numpy as np + + +def w_np_non_max_suppression( + prediction, + conf_thresh: float = 0.25, + iou_thresh: float = 0.45, + class_agnostic: bool = False, + max_detections: int = 300, + max_candidate_detections: int = 3000, + timeout_seconds: Optional[int] = None, + num_masks: int = 0, + box_format: str = "xywh", +): + """Applies non-maximum suppression to predictions. + + Args: + prediction (np.ndarray): Array of predictions. Format for single prediction is + [bbox x 4, max_class_confidence, (confidence) x num_of_classes, additional_element x num_masks] + conf_thresh (float, optional): Confidence threshold. Defaults to 0.25. + iou_thresh (float, optional): IOU threshold. Defaults to 0.45. + class_agnostic (bool, optional): Whether to ignore class labels. Defaults to False. + max_detections (int, optional): Maximum number of detections. Defaults to 300. + max_candidate_detections (int, optional): Maximum number of candidate detections. Defaults to 3000. + timeout_seconds (Optional[int], optional): Timeout in seconds. Defaults to None. + num_masks (int, optional): Number of masks. Defaults to 0. + box_format (str, optional): Format of bounding boxes. Either 'xywh' or 'xyxy'. Defaults to 'xywh'. + + Returns: + list: List of filtered predictions after non-maximum suppression. Format of a single result is: + [bbox x 4, max_class_confidence, max_class_confidence, id_of_class_with_max_confidence, + additional_element x num_masks] + """ + num_classes = prediction.shape[2] - 5 - num_masks + + np_box_corner = np.zeros(prediction.shape) + if box_format == "xywh": + np_box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 + np_box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 + np_box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 + np_box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 + prediction[:, :, :4] = np_box_corner[:, :, :4] + elif box_format == "xyxy": + pass + else: + raise ValueError( + "box_format must be either 'xywh' or 'xyxy', got {}".format(box_format) + ) + + batch_predictions = [] + for np_image_i, np_image_pred in enumerate(prediction): + filtered_predictions = [] + np_conf_mask = (np_image_pred[:, 4] >= conf_thresh).squeeze() + + np_image_pred = np_image_pred[np_conf_mask] + if np_image_pred.shape[0] == 0: + batch_predictions.append(filtered_predictions) + continue + np_class_conf = np.max(np_image_pred[:, 5 : num_classes + 5], 1) + np_class_pred = np.argmax(np_image_pred[:, 5 : num_classes + 5], 1) + np_class_conf = np.expand_dims(np_class_conf, axis=1) + np_class_pred = np.expand_dims(np_class_pred, axis=1) + np_mask_pred = np_image_pred[:, 5 + num_classes :] + np_detections = np.append( + np.append( + np.append(np_image_pred[:, :5], np_class_conf, axis=1), + np_class_pred, + axis=1, + ), + np_mask_pred, + axis=1, + ) + + np_unique_labels = np.unique(np_detections[:, 6]) + + if class_agnostic: + np_detections_class = sorted( + np_detections, key=lambda row: row[4], reverse=True + ) + filtered_predictions.extend( + non_max_suppression_fast(np.array(np_detections_class), iou_thresh) + ) + else: + for c in np_unique_labels: + np_detections_class = np_detections[np_detections[:, 6] == c] + np_detections_class = sorted( + np_detections_class, key=lambda row: row[4], reverse=True + ) + filtered_predictions.extend( + non_max_suppression_fast(np.array(np_detections_class), iou_thresh) + ) + filtered_predictions = sorted( + filtered_predictions, key=lambda row: row[4], reverse=True + ) + batch_predictions.append(filtered_predictions[:max_detections]) + return batch_predictions + + +# Malisiewicz et al. +def non_max_suppression_fast(boxes, overlapThresh): + """Applies non-maximum suppression to bounding boxes. + + Args: + boxes (np.ndarray): Array of bounding boxes with confidence scores. + overlapThresh (float): Overlap threshold for suppression. + + Returns: + list: List of bounding boxes after non-maximum suppression. + """ + # if there are no boxes, return an empty list + if len(boxes) == 0: + return [] + # if the bounding boxes integers, convert them to floats -- + # this is important since we'll be doing a bunch of divisions + if boxes.dtype.kind == "i": + boxes = boxes.astype("float") + # initialize the list of picked indexes + pick = [] + # grab the coordinates of the bounding boxes + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + conf = boxes[:, 4] + # compute the area of the bounding boxes and sort the bounding + # boxes by the bottom-right y-coordinate of the bounding box + area = (x2 - x1 + 1) * (y2 - y1 + 1) + idxs = np.argsort(conf) + # keep looping while some indexes still remain in the indexes + # list + while len(idxs) > 0: + # grab the last index in the indexes list and add the + # index value to the list of picked indexes + last = len(idxs) - 1 + i = idxs[last] + pick.append(i) + # find the largest (x, y) coordinates for the start of + # the bounding box and the smallest (x, y) coordinates + # for the end of the bounding box + xx1 = np.maximum(x1[i], x1[idxs[:last]]) + yy1 = np.maximum(y1[i], y1[idxs[:last]]) + xx2 = np.minimum(x2[i], x2[idxs[:last]]) + yy2 = np.minimum(y2[i], y2[idxs[:last]]) + # compute the width and height of the bounding box + w = np.maximum(0, xx2 - xx1 + 1) + h = np.maximum(0, yy2 - yy1 + 1) + # compute the ratio of overlap + overlap = (w * h) / area[idxs[:last]] + # delete all indexes from the index list that have + idxs = np.delete( + idxs, np.concatenate(([last], np.where(overlap > overlapThresh)[0])) + ) + # return only the bounding boxes that were picked using the + # integer data type + return boxes[pick].astype("float") diff --git a/inference/core/registries/__init__.py b/inference/core/registries/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/inference/core/registries/__init__.py @@ -0,0 +1 @@ + diff --git a/inference/core/registries/__pycache__/__init__.cpython-310.pyc b/inference/core/registries/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07623caff8fdb2756a870c809d4b638bf1bdbdf7 Binary files /dev/null and b/inference/core/registries/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/registries/__pycache__/base.cpython-310.pyc b/inference/core/registries/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77e05b8246f650db652b5d012a3c044a0fce93fe Binary files /dev/null and b/inference/core/registries/__pycache__/base.cpython-310.pyc differ diff --git a/inference/core/registries/__pycache__/roboflow.cpython-310.pyc b/inference/core/registries/__pycache__/roboflow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e095aff360b4b713333e6afb631c06d6de949c6 Binary files /dev/null and b/inference/core/registries/__pycache__/roboflow.cpython-310.pyc differ diff --git a/inference/core/registries/base.py b/inference/core/registries/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4a4b56a3ee398ca0a3ca3568f969cc6aa211f150 --- /dev/null +++ b/inference/core/registries/base.py @@ -0,0 +1,37 @@ +from inference.core.exceptions import ModelNotRecognisedError +from inference.core.models.base import Model + + +class ModelRegistry: + """An object which is able to return model classes based on given model IDs and model types. + + Attributes: + registry_dict (dict): A dictionary mapping model types to model classes. + """ + + def __init__(self, registry_dict) -> None: + """Initializes the ModelRegistry with the given dictionary of registered models. + + Args: + registry_dict (dict): A dictionary mapping model types to model classes. + """ + self.registry_dict = registry_dict + + def get_model(self, model_type: str, model_id: str) -> Model: + """Returns the model class based on the given model type. + + Args: + model_type (str): The type of the model to be retrieved. + model_id (str): The ID of the model to be retrieved (unused in the current implementation). + + Returns: + Model: The model class corresponding to the given model type. + + Raises: + ModelNotRecognisedError: If the model_type is not found in the registry_dict. + """ + if model_type not in self.registry_dict: + raise ModelNotRecognisedError( + f"Could not find model of type: {model_type} in configured registry." + ) + return self.registry_dict[model_type] diff --git a/inference/core/registries/roboflow.py b/inference/core/registries/roboflow.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe627413d9aac80f81c51e468ee4c2963296c70 --- /dev/null +++ b/inference/core/registries/roboflow.py @@ -0,0 +1,236 @@ +import os +from typing import Optional, Tuple, Union + +from inference.core.cache import cache +from inference.core.devices.utils import GLOBAL_DEVICE_ID +from inference.core.entities.types import DatasetID, ModelType, TaskType, VersionID +from inference.core.env import LAMBDA, MODEL_CACHE_DIR +from inference.core.exceptions import ( + MissingApiKeyError, + ModelArtefactError, + ModelNotRecognisedError, +) +from inference.core.logger import logger +from inference.core.models.base import Model +from inference.core.registries.base import ModelRegistry +from inference.core.roboflow_api import ( + MODEL_TYPE_DEFAULTS, + MODEL_TYPE_KEY, + PROJECT_TASK_TYPE_KEY, + ModelEndpointType, + get_roboflow_dataset_type, + get_roboflow_model_data, + get_roboflow_workspace, +) +from inference.core.utils.file_system import dump_json, read_json +from inference.core.utils.roboflow import get_model_id_chunks +from inference.models.aliases import resolve_roboflow_model_alias + +GENERIC_MODELS = { + "clip": ("embed", "clip"), + "sam": ("embed", "sam"), + "gaze": ("gaze", "l2cs"), + "doctr": ("ocr", "doctr"), + "grounding_dino": ("object-detection", "grounding-dino"), + "cogvlm": ("llm", "cogvlm"), + "yolo_world": ("object-detection", "yolo-world"), +} + +STUB_VERSION_ID = "0" +CACHE_METADATA_LOCK_TIMEOUT = 1.0 + + +class RoboflowModelRegistry(ModelRegistry): + """A Roboflow-specific model registry which gets the model type using the model id, + then returns a model class based on the model type. + """ + + def get_model(self, model_id: str, api_key: str) -> Model: + """Returns the model class based on the given model id and API key. + + Args: + model_id (str): The ID of the model to be retrieved. + api_key (str): The API key used to authenticate. + + Returns: + Model: The model class corresponding to the given model ID and type. + + Raises: + ModelNotRecognisedError: If the model type is not supported or found. + """ + model_type = get_model_type(model_id, api_key) + if model_type not in self.registry_dict: + raise ModelNotRecognisedError(f"Model type not supported: {model_type}") + return self.registry_dict[model_type] + + +def get_model_type( + model_id: str, + api_key: Optional[str] = None, +) -> Tuple[TaskType, ModelType]: + """Retrieves the model type based on the given model ID and API key. + + Args: + model_id (str): The ID of the model. + api_key (str): The API key used to authenticate. + + Returns: + tuple: The project task type and the model type. + + Raises: + WorkspaceLoadError: If the workspace could not be loaded or if the API key is invalid. + DatasetLoadError: If the dataset could not be loaded due to invalid ID, workspace ID or version ID. + MissingDefaultModelError: If default model is not configured and API does not provide this info + MalformedRoboflowAPIResponseError: Roboflow API responds in invalid format. + """ + model_id = resolve_roboflow_model_alias(model_id=model_id) + dataset_id, version_id = get_model_id_chunks(model_id=model_id) + if dataset_id in GENERIC_MODELS: + logger.debug(f"Loading generic model: {dataset_id}.") + return GENERIC_MODELS[dataset_id] + cached_metadata = get_model_metadata_from_cache( + dataset_id=dataset_id, version_id=version_id + ) + if cached_metadata is not None: + return cached_metadata[0], cached_metadata[1] + if version_id == STUB_VERSION_ID: + if api_key is None: + raise MissingApiKeyError( + "Stub model version provided but no API key was provided. API key is required to load stub models." + ) + workspace_id = get_roboflow_workspace(api_key=api_key) + project_task_type = get_roboflow_dataset_type( + api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id + ) + model_type = "stub" + save_model_metadata_in_cache( + dataset_id=dataset_id, + version_id=version_id, + project_task_type=project_task_type, + model_type=model_type, + ) + return project_task_type, model_type + api_data = get_roboflow_model_data( + api_key=api_key, + model_id=model_id, + endpoint_type=ModelEndpointType.ORT, + device_id=GLOBAL_DEVICE_ID, + ).get("ort") + if api_data is None: + raise ModelArtefactError("Error loading model artifacts from Roboflow API.") + # some older projects do not have type field - hence defaulting + project_task_type = api_data.get("type", "object-detection") + model_type = api_data.get("modelType") + if model_type is None or model_type == "ort": + # some very old model versions do not have modelType reported - and API respond in a generic way - + # then we shall attempt using default model for given task type + model_type = MODEL_TYPE_DEFAULTS.get(project_task_type) + if model_type is None or project_task_type is None: + raise ModelArtefactError("Error loading model artifacts from Roboflow API.") + save_model_metadata_in_cache( + dataset_id=dataset_id, + version_id=version_id, + project_task_type=project_task_type, + model_type=model_type, + ) + + return project_task_type, model_type + + +def get_model_metadata_from_cache( + dataset_id: str, version_id: str +) -> Optional[Tuple[TaskType, ModelType]]: + if LAMBDA: + return _get_model_metadata_from_cache( + dataset_id=dataset_id, version_id=version_id + ) + with cache.lock( + f"lock:metadata:{dataset_id}:{version_id}", expire=CACHE_METADATA_LOCK_TIMEOUT + ): + return _get_model_metadata_from_cache( + dataset_id=dataset_id, version_id=version_id + ) + + +def _get_model_metadata_from_cache( + dataset_id: str, version_id: str +) -> Optional[Tuple[TaskType, ModelType]]: + model_type_cache_path = construct_model_type_cache_path( + dataset_id=dataset_id, version_id=version_id + ) + if not os.path.isfile(model_type_cache_path): + return None + try: + model_metadata = read_json(path=model_type_cache_path) + if model_metadata_content_is_invalid(content=model_metadata): + return None + return model_metadata[PROJECT_TASK_TYPE_KEY], model_metadata[MODEL_TYPE_KEY] + except ValueError as e: + logger.warning( + f"Could not load model description from cache under path: {model_type_cache_path} - decoding issue: {e}." + ) + return None + + +def model_metadata_content_is_invalid(content: Optional[Union[list, dict]]) -> bool: + if content is None: + logger.warning("Empty model metadata file encountered in cache.") + return True + if not issubclass(type(content), dict): + logger.warning("Malformed file encountered in cache.") + return True + if PROJECT_TASK_TYPE_KEY not in content or MODEL_TYPE_KEY not in content: + logger.warning( + f"Could not find one of required keys {PROJECT_TASK_TYPE_KEY} or {MODEL_TYPE_KEY} in cache." + ) + return True + return False + + +def save_model_metadata_in_cache( + dataset_id: DatasetID, + version_id: VersionID, + project_task_type: TaskType, + model_type: ModelType, +) -> None: + if LAMBDA: + _save_model_metadata_in_cache( + dataset_id=dataset_id, + version_id=version_id, + project_task_type=project_task_type, + model_type=model_type, + ) + return None + with cache.lock( + f"lock:metadata:{dataset_id}:{version_id}", expire=CACHE_METADATA_LOCK_TIMEOUT + ): + _save_model_metadata_in_cache( + dataset_id=dataset_id, + version_id=version_id, + project_task_type=project_task_type, + model_type=model_type, + ) + return None + + +def _save_model_metadata_in_cache( + dataset_id: DatasetID, + version_id: VersionID, + project_task_type: TaskType, + model_type: ModelType, +) -> None: + model_type_cache_path = construct_model_type_cache_path( + dataset_id=dataset_id, version_id=version_id + ) + metadata = { + PROJECT_TASK_TYPE_KEY: project_task_type, + MODEL_TYPE_KEY: model_type, + } + dump_json( + path=model_type_cache_path, content=metadata, allow_override=True, indent=4 + ) + + +def construct_model_type_cache_path(dataset_id: str, version_id: str) -> str: + cache_dir = os.path.join(MODEL_CACHE_DIR, dataset_id, version_id) + return os.path.join(cache_dir, "model_type.json") diff --git a/inference/core/roboflow_api.py b/inference/core/roboflow_api.py new file mode 100644 index 0000000000000000000000000000000000000000..54abbf5c77b8f845d1c47b4e01e23c87d1926f7e --- /dev/null +++ b/inference/core/roboflow_api.py @@ -0,0 +1,368 @@ +import json +import urllib.parse +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import requests +from requests import Response +from requests_toolbelt import MultipartEncoder + +from inference.core import logger +from inference.core.cache import cache +from inference.core.entities.types import ( + DatasetID, + ModelType, + TaskType, + VersionID, + WorkspaceID, +) +from inference.core.env import API_BASE_URL +from inference.core.exceptions import ( + MalformedRoboflowAPIResponseError, + MalformedWorkflowResponseError, + MissingDefaultModelError, + RoboflowAPIConnectionError, + RoboflowAPIIAlreadyAnnotatedError, + RoboflowAPIIAnnotationRejectionError, + RoboflowAPIImageUploadRejectionError, + RoboflowAPINotAuthorizedError, + RoboflowAPINotNotFoundError, + RoboflowAPIUnsuccessfulRequestError, + WorkspaceLoadError, +) +from inference.core.utils.requests import api_key_safe_raise_for_status +from inference.core.utils.url_utils import wrap_url + +MODEL_TYPE_DEFAULTS = { + "object-detection": "yolov5v2s", + "instance-segmentation": "yolact", + "classification": "vit", + "keypoint-detection": "yolov8n", +} +PROJECT_TASK_TYPE_KEY = "project_task_type" +MODEL_TYPE_KEY = "model_type" + +NOT_FOUND_ERROR_MESSAGE = ( + "Could not find requested Roboflow resource. Check that the provided dataset and " + "version are correct, and check that the provided Roboflow API key has the correct permissions." +) + + +def raise_from_lambda( + inner_error: Exception, exception_type: Type[Exception], message: str +) -> None: + raise exception_type(message) from inner_error + + +DEFAULT_ERROR_HANDLERS = { + 401: lambda e: raise_from_lambda( + e, + RoboflowAPINotAuthorizedError, + "Unauthorized access to roboflow API - check API key. Visit " + "https://docs.roboflow.com/api-reference/authentication#retrieve-an-api-key to learn how to retrieve one.", + ), + 404: lambda e: raise_from_lambda( + e, RoboflowAPINotNotFoundError, NOT_FOUND_ERROR_MESSAGE + ), +} + + +def wrap_roboflow_api_errors( + http_errors_handlers: Optional[ + Dict[int, Callable[[Union[requests.exceptions.HTTPError]], None]] + ] = None, +) -> callable: + def decorator(function: callable) -> callable: + def wrapper(*args, **kwargs) -> Any: + try: + return function(*args, **kwargs) + except (requests.exceptions.ConnectionError, ConnectionError) as error: + raise RoboflowAPIConnectionError( + "Could not connect to Roboflow API." + ) from error + except requests.exceptions.HTTPError as error: + user_handler_override = ( + http_errors_handlers if http_errors_handlers is not None else {} + ) + status_code = error.response.status_code + default_handler = DEFAULT_ERROR_HANDLERS.get(status_code) + error_handler = user_handler_override.get(status_code, default_handler) + if error_handler is not None: + error_handler(error) + raise RoboflowAPIUnsuccessfulRequestError( + f"Unsuccessful request to Roboflow API with response code: {status_code}" + ) from error + except requests.exceptions.InvalidJSONError as error: + raise MalformedRoboflowAPIResponseError( + "Could not decode JSON response from Roboflow API." + ) from error + + return wrapper + + return decorator + + +@wrap_roboflow_api_errors() +def get_roboflow_workspace(api_key: str) -> WorkspaceID: + api_url = _add_params_to_url( + url=f"{API_BASE_URL}/", + params=[("api_key", api_key), ("nocache", "true")], + ) + api_key_info = _get_from_url(url=api_url) + workspace_id = api_key_info.get("workspace") + if workspace_id is None: + raise WorkspaceLoadError(f"Empty workspace encountered, check your API key.") + return workspace_id + + +@wrap_roboflow_api_errors() +def get_roboflow_dataset_type( + api_key: str, workspace_id: WorkspaceID, dataset_id: DatasetID +) -> TaskType: + api_url = _add_params_to_url( + url=f"{API_BASE_URL}/{workspace_id}/{dataset_id}", + params=[("api_key", api_key), ("nocache", "true")], + ) + dataset_info = _get_from_url(url=api_url) + project_task_type = dataset_info.get("project", {}) + if "type" not in project_task_type: + logger.warning( + f"Project task type not defined for workspace={workspace_id} and dataset={dataset_id}, defaulting " + f"to object-detection." + ) + return project_task_type.get("type", "object-detection") + + +@wrap_roboflow_api_errors( + http_errors_handlers={ + 500: lambda e: raise_from_lambda( + e, RoboflowAPINotNotFoundError, NOT_FOUND_ERROR_MESSAGE + ) + # this is temporary solution, empirically checked that backend API responds HTTP 500 on incorrect version. + # TO BE FIXED at backend, otherwise this error handling may overshadow existing backend problems. + } +) +def get_roboflow_model_type( + api_key: str, + workspace_id: WorkspaceID, + dataset_id: DatasetID, + version_id: VersionID, + project_task_type: ModelType, +) -> ModelType: + api_url = _add_params_to_url( + url=f"{API_BASE_URL}/{workspace_id}/{dataset_id}/{version_id}", + params=[("api_key", api_key), ("nocache", "true")], + ) + version_info = _get_from_url(url=api_url) + model_type = version_info["version"] + if "modelType" not in model_type: + if project_task_type not in MODEL_TYPE_DEFAULTS: + raise MissingDefaultModelError( + f"Could not set default model for {project_task_type}" + ) + logger.warning( + f"Model type not defined - using default for {project_task_type} task." + ) + return model_type.get("modelType", MODEL_TYPE_DEFAULTS[project_task_type]) + + +class ModelEndpointType(Enum): + ORT = "ort" + CORE_MODEL = "core_model" + + +@wrap_roboflow_api_errors() +def get_roboflow_model_data( + api_key: str, + model_id: str, + endpoint_type: ModelEndpointType, + device_id: str, +) -> dict: + api_data_cache_key = f"roboflow_api_data:{endpoint_type.value}:{model_id}" + api_data = cache.get(api_data_cache_key) + if api_data is not None: + logger.debug(f"Loaded model data from cache with key: {api_data_cache_key}.") + return api_data + else: + params = [ + ("nocache", "true"), + ("device", device_id), + ("dynamic", "true"), + ] + if api_key is not None: + params.append(("api_key", api_key)) + api_url = _add_params_to_url( + url=f"{API_BASE_URL}/{endpoint_type.value}/{model_id}", + params=params, + ) + api_data = _get_from_url(url=api_url) + cache.set( + api_data_cache_key, + api_data, + expire=10, + ) + logger.debug( + f"Loaded model data from Roboflow API and saved to cache with key: {api_data_cache_key}." + ) + return api_data + + +@wrap_roboflow_api_errors() +def get_roboflow_active_learning_configuration( + api_key: str, + workspace_id: WorkspaceID, + dataset_id: DatasetID, +) -> dict: + api_url = _add_params_to_url( + url=f"{API_BASE_URL}/{workspace_id}/{dataset_id}/active_learning", + params=[("api_key", api_key)], + ) + return _get_from_url(url=api_url) + + +@wrap_roboflow_api_errors() +def register_image_at_roboflow( + api_key: str, + dataset_id: DatasetID, + local_image_id: str, + image_bytes: bytes, + batch_name: str, + tags: Optional[List[str]] = None, +) -> dict: + url = f"{API_BASE_URL}/dataset/{dataset_id}/upload" + params = [ + ("api_key", api_key), + ("batch", batch_name), + ] + tags = tags if tags is not None else [] + for tag in tags: + params.append(("tag", tag)) + wrapped_url = wrap_url(_add_params_to_url(url=url, params=params)) + m = MultipartEncoder( + fields={ + "name": f"{local_image_id}.jpg", + "file": ("imageToUpload", image_bytes, "image/jpeg"), + } + ) + response = requests.post( + url=wrapped_url, + data=m, + headers={"Content-Type": m.content_type}, + ) + api_key_safe_raise_for_status(response=response) + parsed_response = response.json() + if not parsed_response.get("duplicate") and not parsed_response.get("success"): + raise RoboflowAPIImageUploadRejectionError( + f"Server rejected image: {parsed_response}" + ) + return parsed_response + + +@wrap_roboflow_api_errors( + http_errors_handlers={ + 409: lambda e: raise_from_lambda( + e, + RoboflowAPIIAlreadyAnnotatedError, + "Given datapoint already has annotation.", + ) + } +) +def annotate_image_at_roboflow( + api_key: str, + dataset_id: DatasetID, + local_image_id: str, + roboflow_image_id: str, + annotation_content: str, + annotation_file_type: str, + is_prediction: bool = True, +) -> dict: + url = f"{API_BASE_URL}/dataset/{dataset_id}/annotate/{roboflow_image_id}" + params = [ + ("api_key", api_key), + ("name", f"{local_image_id}.{annotation_file_type}"), + ("prediction", str(is_prediction).lower()), + ] + wrapped_url = wrap_url(_add_params_to_url(url=url, params=params)) + response = requests.post( + wrapped_url, + data=annotation_content, + headers={"Content-Type": "text/plain"}, + ) + api_key_safe_raise_for_status(response=response) + parsed_response = response.json() + if "error" in parsed_response or not parsed_response.get("success"): + raise RoboflowAPIIAnnotationRejectionError( + f"Failed to save annotation for {roboflow_image_id}. API response: {parsed_response}" + ) + return parsed_response + + +@wrap_roboflow_api_errors() +def get_roboflow_labeling_batches( + api_key: str, workspace_id: WorkspaceID, dataset_id: str +) -> dict: + api_url = _add_params_to_url( + url=f"{API_BASE_URL}/{workspace_id}/{dataset_id}/batches", + params=[("api_key", api_key)], + ) + return _get_from_url(url=api_url) + + +@wrap_roboflow_api_errors() +def get_roboflow_labeling_jobs( + api_key: str, workspace_id: WorkspaceID, dataset_id: str +) -> dict: + api_url = _add_params_to_url( + url=f"{API_BASE_URL}/{workspace_id}/{dataset_id}/jobs", + params=[("api_key", api_key)], + ) + return _get_from_url(url=api_url) + + +@wrap_roboflow_api_errors() +def get_workflow_specification( + api_key: str, + workspace_id: WorkspaceID, + workflow_name: str, +) -> dict: + api_url = _add_params_to_url( + url=f"{API_BASE_URL}/{workspace_id}/workflows/{workflow_name}", + params=[("api_key", api_key)], + ) + response = _get_from_url(url=api_url) + if "workflow" not in response or "config" not in response["workflow"]: + raise MalformedWorkflowResponseError( + f"Could not found workflow specification in API response" + ) + try: + return json.loads(response["workflow"]["config"]) + except (ValueError, TypeError) as error: + raise MalformedWorkflowResponseError( + "Could not decode workflow specification in Roboflow API response" + ) from error + + +@wrap_roboflow_api_errors() +def get_from_url( + url: str, + json_response: bool = True, +) -> Union[Response, dict]: + return _get_from_url(url=url, json_response=json_response) + + +def _get_from_url(url: str, json_response: bool = True) -> Union[Response, dict]: + response = requests.get(wrap_url(url)) + api_key_safe_raise_for_status(response=response) + if json_response: + return response.json() + return response + + +def _add_params_to_url(url: str, params: List[Tuple[str, str]]) -> str: + if len(params) == 0: + return url + params_chunks = [ + f"{name}={urllib.parse.quote_plus(value)}" for name, value in params + ] + parameters_string = "&".join(params_chunks) + return f"{url}?{parameters_string}" diff --git a/inference/core/usage.py b/inference/core/usage.py new file mode 100644 index 0000000000000000000000000000000000000000..b9607923b736e83e37bc1c4d98d9c7030b2b8187 --- /dev/null +++ b/inference/core/usage.py @@ -0,0 +1,63 @@ +import json + +import elasticache_auto_discovery +from pymemcache.client.hash import HashClient + +from inference.core.env import ELASTICACHE_ENDPOINT +from inference.core.logger import logger + +nodes = elasticache_auto_discovery.discover(ELASTICACHE_ENDPOINT) + +# set up memcache +nodes = map(lambda x: (x[1], int(x[2])), nodes) +memcache_client = HashClient(nodes) + + +def trackUsage(endpoint, actor, n=1): + """Tracks the usage of an endpoint by an actor. + + This function increments the usage count for a given endpoint by an actor. + It also handles initialization if the count does not exist. + + Args: + endpoint (str): The endpoint being accessed. + actor (str): The actor accessing the endpoint. + n (int, optional): The number of times the endpoint was accessed. Defaults to 1. + + Returns: + None: This function does not return anything but updates the memcache client. + """ + # count an inference + try: + job = endpoint + "endpoint:::actor" + actor + current_infers = memcache_client.incr(job, n) + if current_infers is None: # not yet set; initialize at 1 + memcache_client.set(job, n) + current_infers = n + + # store key + job_keys = memcache_client.get("JOB_KEYS") + if job_keys is None: + memcache_client.add("JOB_KEYS", json.dumps([job])) + else: + decoded = json.loads(job_keys) + decoded.append(job) + decoded = list(set(decoded)) + memcache_client.set("JOB_KEYS", json.dumps(decoded)) + + actor_keys = memcache_client.get("ACTOR_KEYS") + if actor_keys is None: + ak = {} + ak[actor] = n + memcache_client.add("ACTOR_KEYS", json.dumps(ak)) + else: + decoded = json.loads(actor_keys) + if actor in actor_keys: + actor_keys[actor] += n + else: + actor_keys[actor] = n + memcache_client.set("ACTOR_KEYS", json.dumps(actor_keys)) + + except Exception as e: + logger.debug("WARNING: there was an error in counting this inference") + logger.debug(e) diff --git a/inference/core/utils/__init__.py b/inference/core/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/core/utils/__pycache__/__init__.cpython-310.pyc b/inference/core/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bb8ca00cd6601734cbb8fdcda7225f8823ed25b Binary files /dev/null and b/inference/core/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/environment.cpython-310.pyc b/inference/core/utils/__pycache__/environment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f10fc654daa7260ee669dc0d34ad8bb12792a06 Binary files /dev/null and b/inference/core/utils/__pycache__/environment.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/file_system.cpython-310.pyc b/inference/core/utils/__pycache__/file_system.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94af40c42469ca705906f42e20cc42204f06b2d7 Binary files /dev/null and b/inference/core/utils/__pycache__/file_system.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/hash.cpython-310.pyc b/inference/core/utils/__pycache__/hash.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6b41cd3bc1f2e5f896c32ac1d6ea8e7fea579c1 Binary files /dev/null and b/inference/core/utils/__pycache__/hash.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/image_utils.cpython-310.pyc b/inference/core/utils/__pycache__/image_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25aa2015d5dea5cfaf2991b24b6d7588344c6c3c Binary files /dev/null and b/inference/core/utils/__pycache__/image_utils.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/notebooks.cpython-310.pyc b/inference/core/utils/__pycache__/notebooks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0e01b99acc9ac0101c6fb622b99a3b8f763de4d Binary files /dev/null and b/inference/core/utils/__pycache__/notebooks.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/onnx.cpython-310.pyc b/inference/core/utils/__pycache__/onnx.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..156b8b61fb5005f0b42c0257cdc08d29d48facdd Binary files /dev/null and b/inference/core/utils/__pycache__/onnx.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/postprocess.cpython-310.pyc b/inference/core/utils/__pycache__/postprocess.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48972045b050f4b6f146060aa981d9364d711db1 Binary files /dev/null and b/inference/core/utils/__pycache__/postprocess.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/preprocess.cpython-310.pyc b/inference/core/utils/__pycache__/preprocess.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf0e46568bc0f5568331565ee86029a38001d92b Binary files /dev/null and b/inference/core/utils/__pycache__/preprocess.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/requests.cpython-310.pyc b/inference/core/utils/__pycache__/requests.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0789aa7283a64ff48df5930f0e9e4c75a0fdd150 Binary files /dev/null and b/inference/core/utils/__pycache__/requests.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/roboflow.cpython-310.pyc b/inference/core/utils/__pycache__/roboflow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..697d60087f295aefd7aa30400c0691f01d340e27 Binary files /dev/null and b/inference/core/utils/__pycache__/roboflow.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/s3.cpython-310.pyc b/inference/core/utils/__pycache__/s3.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4be7730ff623263ab1f3ee769785c3529a611d90 Binary files /dev/null and b/inference/core/utils/__pycache__/s3.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/url_utils.cpython-310.pyc b/inference/core/utils/__pycache__/url_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2066d34f33012dfdba1a721289ebd194a6167593 Binary files /dev/null and b/inference/core/utils/__pycache__/url_utils.cpython-310.pyc differ diff --git a/inference/core/utils/__pycache__/visualisation.cpython-310.pyc b/inference/core/utils/__pycache__/visualisation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2603973e9970306793d2dd5f82327f0505951d93 Binary files /dev/null and b/inference/core/utils/__pycache__/visualisation.cpython-310.pyc differ diff --git a/inference/core/utils/environment.py b/inference/core/utils/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..234084bd6ed93f11a3c0682d0a2b088e3c107d92 --- /dev/null +++ b/inference/core/utils/environment.py @@ -0,0 +1,69 @@ +import os +from typing import Any, Callable, List, Optional, Type, TypeVar, Union + +from inference.core.exceptions import InvalidEnvironmentVariableError + +T = TypeVar("T") + + +def safe_env_to_type( + variable_name: str, + default_value: Optional[T] = None, + type_constructor: Optional[Union[Type[T], Callable[[str], T]]] = None, +) -> Optional[T]: + """ + Converts env variable to specified type, but only if variable is set - otherwise default is returned. + If `type_constructor` is not given - value of type str will be returned. + """ + if variable_name not in os.environ: + return default_value + variable_value = os.environ[variable_name] + if type_constructor is None: + return variable_value + return type_constructor(variable_value) + + +def str2bool(value: Any) -> bool: + """ + Converts an environment variable to a boolean value. + + Args: + value (str or bool): The environment variable value to be converted. + + Returns: + bool: The converted boolean value. + + Raises: + InvalidEnvironmentVariableError: If the value is not 'true', 'false', or a boolean. + """ + if isinstance(value, bool): + return value + if not issubclass(type(value), str): + raise InvalidEnvironmentVariableError( + f"Expected a boolean environment variable (true or false) but got '{value}'" + ) + if value.lower() == "true": + return True + elif value.lower() == "false": + return False + else: + raise InvalidEnvironmentVariableError( + f"Expected a boolean environment variable (true or false) but got '{value}'" + ) + + +def safe_split_value(value: Optional[str], delimiter: str = ",") -> Optional[List[str]]: + """ + Splits a separated environment variable into a list. + + Args: + value (str): The environment variable value to be split. + delimiter(str): Delimiter to be used + + Returns: + list or None: The split values as a list, or None if the input is None. + """ + if value is None: + return None + else: + return value.split(delimiter) diff --git a/inference/core/utils/file_system.py b/inference/core/utils/file_system.py new file mode 100644 index 0000000000000000000000000000000000000000..100832ed5f57961f3e77c195f00a500c0fd702e7 --- /dev/null +++ b/inference/core/utils/file_system.py @@ -0,0 +1,62 @@ +import json +import os.path +from typing import List, Optional, Union + + +def read_text_file( + path: str, + split_lines: bool = False, + strip_white_chars: bool = False, +) -> Union[str, List[str]]: + with open(path) as f: + if split_lines: + lines = list(f.readlines()) + if strip_white_chars: + return [line.strip() for line in lines if len(line.strip()) > 0] + else: + return lines + content = f.read() + if strip_white_chars: + content = content.strip() + return content + + +def read_json(path: str, **kwargs) -> Optional[Union[dict, list]]: + with open(path) as f: + return json.load(f, **kwargs) + + +def dump_json( + path: str, content: Union[dict, list], allow_override: bool = False, **kwargs +) -> None: + ensure_write_is_allowed(path=path, allow_override=allow_override) + ensure_parent_dir_exists(path=path) + with open(path, "w") as f: + json.dump(content, fp=f, **kwargs) + + +def dump_text_lines( + path: str, content: List[str], allow_override: bool = False +) -> None: + ensure_write_is_allowed(path=path, allow_override=allow_override) + ensure_parent_dir_exists(path=path) + with open(path, "w") as f: + f.write("\n".join(content)) + + +def dump_bytes(path: str, content: bytes, allow_override: bool = False) -> None: + ensure_write_is_allowed(path=path, allow_override=allow_override) + ensure_parent_dir_exists(path=path) + with open(path, "wb") as f: + f.write(content) + + +def ensure_parent_dir_exists(path: str) -> None: + absolute_path = os.path.abspath(path) + parent_dir = os.path.dirname(absolute_path) + os.makedirs(parent_dir, exist_ok=True) + + +def ensure_write_is_allowed(path: str, allow_override: bool) -> None: + if os.path.exists(path) and not allow_override: + raise RuntimeError(f"File {path} exists and override is forbidden.") diff --git a/inference/core/utils/hash.py b/inference/core/utils/hash.py new file mode 100644 index 0000000000000000000000000000000000000000..6f1835e5b77366e4a9655448b37e64a7c9669d1b --- /dev/null +++ b/inference/core/utils/hash.py @@ -0,0 +1,15 @@ +import hashlib + + +def get_string_list_hash(text: list) -> str: + """Get the hash of a list of strings. + + Args: + text (list): The list of strings. + + Returns: + str: The hash of the list of strings. + """ + text_string = ", ".join([f"{idx}:{t}" for idx, t in enumerate(text)]) + text_hash = hashlib.md5(text_string.encode("utf-8")).hexdigest() + return text_hash diff --git a/inference/core/utils/image_utils.py b/inference/core/utils/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c0436a9d3bcd25fa5c09acf54cf32f2bbc022406 --- /dev/null +++ b/inference/core/utils/image_utils.py @@ -0,0 +1,443 @@ +import binascii +import os +import pickle +import re +from enum import Enum +from io import BytesIO +from typing import Any, Optional, Tuple, Union + +import cv2 +import numpy as np +import pybase64 +import requests +from _io import _IOBase +from PIL import Image +from requests import RequestException + +from inference.core.entities.requests.inference import InferenceRequestImage +from inference.core.env import ALLOW_NUMPY_INPUT +from inference.core.exceptions import ( + InputFormatInferenceFailed, + InputImageLoadError, + InvalidImageTypeDeclared, + InvalidNumpyInput, +) +from inference.core.utils.requests import api_key_safe_raise_for_status + +BASE64_DATA_TYPE_PATTERN = re.compile(r"^data:image\/[a-z]+;base64,") + + +class ImageType(Enum): + BASE64 = "base64" + FILE = "file" + MULTIPART = "multipart" + NUMPY = "numpy" + NUMPY_OBJECT = "numpy_object" + PILLOW = "pil" + URL = "url" + + +def load_image_rgb(value: Any, disable_preproc_auto_orient: bool = False) -> np.ndarray: + np_image, is_bgr = load_image( + value=value, disable_preproc_auto_orient=disable_preproc_auto_orient + ) + if is_bgr: + np_image = cv2.cvtColor(np_image, cv2.COLOR_BGR2RGB) + return np_image + + +def load_image( + value: Any, + disable_preproc_auto_orient: bool = False, +) -> Tuple[np.ndarray, bool]: + """Loads an image based on the specified type and value. + + Args: + value (Any): Image value which could be an instance of InferenceRequestImage, + a dict with 'type' and 'value' keys, or inferred based on the value's content. + + Returns: + Image.Image: The loaded PIL image, converted to RGB. + + Raises: + NotImplementedError: If the specified image type is not supported. + InvalidNumpyInput: If the numpy input method is used and the input data is invalid. + """ + cv_imread_flags = choose_image_decoding_flags( + disable_preproc_auto_orient=disable_preproc_auto_orient + ) + value, image_type = extract_image_payload_and_type(value=value) + if image_type is not None: + np_image, is_bgr = load_image_with_known_type( + value=value, + image_type=image_type, + cv_imread_flags=cv_imread_flags, + ) + else: + np_image, is_bgr = load_image_with_inferred_type( + value, cv_imread_flags=cv_imread_flags + ) + np_image = convert_gray_image_to_bgr(image=np_image) + return np_image, is_bgr + + +def choose_image_decoding_flags(disable_preproc_auto_orient: bool) -> int: + """Choose the appropriate OpenCV image decoding flags. + + Args: + disable_preproc_auto_orient (bool): Flag to disable preprocessing auto-orientation. + + Returns: + int: OpenCV image decoding flags. + """ + cv_imread_flags = cv2.IMREAD_COLOR + if disable_preproc_auto_orient: + cv_imread_flags = cv_imread_flags | cv2.IMREAD_IGNORE_ORIENTATION + return cv_imread_flags + + +def extract_image_payload_and_type(value: Any) -> Tuple[Any, Optional[ImageType]]: + """Extract the image payload and type from the given value. + + This function supports different types of image inputs (e.g., InferenceRequestImage, dict, etc.) + and extracts the relevant data and image type for further processing. + + Args: + value (Any): The input value which can be an image or information to derive the image. + + Returns: + Tuple[Any, Optional[ImageType]]: A tuple containing the extracted image data and the corresponding image type. + """ + image_type = None + if issubclass(type(value), InferenceRequestImage): + image_type = value.type + value = value.value + elif issubclass(type(value), dict): + image_type = value.get("type") + value = value.get("value") + allowed_payload_types = {e.value for e in ImageType} + if image_type is None: + return value, image_type + if image_type.lower() not in allowed_payload_types: + raise InvalidImageTypeDeclared( + f"Declared image type: {image_type.lower()} which is not in allowed types: {allowed_payload_types}." + ) + return value, ImageType(image_type.lower()) + + +def load_image_with_known_type( + value: Any, + image_type: ImageType, + cv_imread_flags: int = cv2.IMREAD_COLOR, +) -> Tuple[np.ndarray, bool]: + """Load an image using the known image type. + + Supports various image types (e.g., NUMPY, PILLOW, etc.) and loads them into a numpy array format. + + Args: + value (Any): The image data. + image_type (ImageType): The type of the image. + cv_imread_flags (int): Flags used for OpenCV's imread function. + + Returns: + Tuple[np.ndarray, bool]: A tuple of the loaded image as a numpy array and a boolean indicating if the image is in BGR format. + """ + if image_type is ImageType.NUMPY and not ALLOW_NUMPY_INPUT: + raise InvalidImageTypeDeclared( + f"NumPy image type is not supported in this configuration of `inference`." + ) + loader = IMAGE_LOADERS[image_type] + is_bgr = True if image_type is not ImageType.PILLOW else False + image = loader(value, cv_imread_flags) + return image, is_bgr + + +def load_image_with_inferred_type( + value: Any, + cv_imread_flags: int = cv2.IMREAD_COLOR, +) -> Tuple[np.ndarray, bool]: + """Load an image by inferring its type. + + Args: + value (Any): The image data. + cv_imread_flags (int): Flags used for OpenCV's imread function. + + Returns: + Tuple[np.ndarray, bool]: Loaded image as a numpy array and a boolean indicating if the image is in BGR format. + + Raises: + NotImplementedError: If the image type could not be inferred. + """ + if isinstance(value, (np.ndarray, np.generic)): + validate_numpy_image(data=value) + return value, True + elif isinstance(value, Image.Image): + return np.asarray(value.convert("RGB")), False + elif isinstance(value, str) and (value.startswith("http")): + return load_image_from_url(value=value, cv_imread_flags=cv_imread_flags), True + elif isinstance(value, str) and os.path.isfile(value): + return cv2.imread(value, cv_imread_flags), True + else: + return attempt_loading_image_from_string( + value=value, cv_imread_flags=cv_imread_flags + ) + + +def attempt_loading_image_from_string( + value: Union[str, bytes, bytearray, _IOBase], + cv_imread_flags: int = cv2.IMREAD_COLOR, +) -> Tuple[np.ndarray, bool]: + """ + Attempt to load an image from a string. + + Args: + value (Union[str, bytes, bytearray, _IOBase]): The image data in string format. + cv_imread_flags (int): OpenCV flags used for image reading. + + Returns: + Tuple[np.ndarray, bool]: A tuple of the loaded image in numpy array format and a boolean flag indicating if the image is in BGR format. + """ + try: + return load_image_base64(value=value, cv_imread_flags=cv_imread_flags), True + except: + pass + try: + return ( + load_image_from_encoded_bytes(value=value, cv_imread_flags=cv_imread_flags), + True, + ) + except: + pass + try: + return ( + load_image_from_buffer(value=value, cv_imread_flags=cv_imread_flags), + True, + ) + except: + pass + try: + return load_image_from_numpy_str(value=value), True + except InvalidNumpyInput as error: + raise InputFormatInferenceFailed( + "Input image format could not be inferred from string." + ) from error + + +def load_image_base64( + value: Union[str, bytes], cv_imread_flags=cv2.IMREAD_COLOR +) -> np.ndarray: + """Loads an image from a base64 encoded string using OpenCV. + + Args: + value (str): Base64 encoded string representing the image. + + Returns: + np.ndarray: The loaded image as a numpy array. + """ + # New routes accept images via json body (str), legacy routes accept bytes which need to be decoded as strings + if not isinstance(value, str): + value = value.decode("utf-8") + value = BASE64_DATA_TYPE_PATTERN.sub("", value) + value = pybase64.b64decode(value) + image_np = np.frombuffer(value, np.uint8) + result = cv2.imdecode(image_np, cv_imread_flags) + if result is None: + raise InputImageLoadError("Could not load valid image from base64 string.") + return result + + +def load_image_from_buffer( + value: _IOBase, + cv_imread_flags: int = cv2.IMREAD_COLOR, +) -> np.ndarray: + """Loads an image from a multipart-encoded input. + + Args: + value (Any): Multipart-encoded input representing the image. + + Returns: + Image.Image: The loaded PIL image. + """ + value.seek(0) + image_np = np.frombuffer(value.read(), np.uint8) + result = cv2.imdecode(image_np, cv_imread_flags) + if result is None: + raise InputImageLoadError("Could not load valid image from buffer.") + return result + + +def load_image_from_numpy_str(value: Union[bytes, str]) -> np.ndarray: + """Loads an image from a numpy array string. + + Args: + value (Union[bytes, str]): Base64 string or byte sequence representing the pickled numpy array of the image. + + Returns: + Image.Image: The loaded PIL image. + + Raises: + InvalidNumpyInput: If the numpy data is invalid. + """ + try: + if isinstance(value, str): + value = pybase64.b64decode(value) + data = pickle.loads(value) + except (EOFError, TypeError, pickle.UnpicklingError, binascii.Error) as error: + raise InvalidNumpyInput( + f"Could not unpickle image data. Cause: {error}" + ) from error + validate_numpy_image(data=data) + return data + + +def load_image_from_numpy_object(value: np.ndarray) -> np.ndarray: + validate_numpy_image(data=value) + return value + + +def validate_numpy_image(data: np.ndarray) -> None: + """ + Validate if the provided data is a valid numpy image. + + Args: + data (np.ndarray): The numpy array representing an image. + + Raises: + InvalidNumpyInput: If the provided data is not a valid numpy image. + """ + if not issubclass(type(data), np.ndarray): + raise InvalidNumpyInput( + f"Data provided as input could not be decoded into np.ndarray object." + ) + if len(data.shape) != 3 and len(data.shape) != 2: + raise InvalidNumpyInput( + f"For image given as np.ndarray expected 2 or 3 dimensions, got {len(data.shape)} dimensions." + ) + if data.shape[-1] != 3 and data.shape[-1] != 1: + raise InvalidNumpyInput( + f"For image given as np.ndarray expected 1 or 3 channels, got {data.shape[-1]} channels." + ) + + +def load_image_from_url( + value: str, cv_imread_flags: int = cv2.IMREAD_COLOR +) -> np.ndarray: + """Loads an image from a given URL. + + Args: + value (str): URL of the image. + + Returns: + Image.Image: The loaded PIL image. + """ + try: + response = requests.get(value, stream=True) + api_key_safe_raise_for_status(response=response) + return load_image_from_encoded_bytes( + value=response.content, cv_imread_flags=cv_imread_flags + ) + except (RequestException, ConnectionError) as error: + raise InputImageLoadError( + f"Error while loading image from url: {value}. Details: {error}" + ) + + +def load_image_from_encoded_bytes( + value: bytes, cv_imread_flags: int = cv2.IMREAD_COLOR +) -> np.ndarray: + """ + Load an image from encoded bytes. + + Args: + value (bytes): The byte sequence representing the image. + cv_imread_flags (int): OpenCV flags used for image reading. + + Returns: + np.ndarray: The loaded image as a numpy array. + """ + image_np = np.asarray(bytearray(value), dtype=np.uint8) + image = cv2.imdecode(image_np, cv_imread_flags) + if image is None: + raise InputImageLoadError( + f"Could not parse response content from url {value} into image." + ) + return image + + +IMAGE_LOADERS = { + ImageType.BASE64: load_image_base64, + ImageType.FILE: cv2.imread, + ImageType.MULTIPART: load_image_from_buffer, + ImageType.NUMPY: lambda v, _: load_image_from_numpy_str(v), + ImageType.NUMPY_OBJECT: lambda v, _: load_image_from_numpy_object(v), + ImageType.PILLOW: lambda v, _: np.asarray(v.convert("RGB")), + ImageType.URL: load_image_from_url, +} + + +def convert_gray_image_to_bgr(image: np.ndarray) -> np.ndarray: + """ + Convert a grayscale image to BGR format. + + Args: + image (np.ndarray): The grayscale image. + + Returns: + np.ndarray: The converted BGR image. + """ + + if len(image.shape) == 2 or image.shape[2] == 1: + image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + return image + + +def np_image_to_base64(image: np.ndarray) -> bytes: + """ + Convert a numpy image to a base64 encoded byte string. + + Args: + image (np.ndarray): The numpy array representing an image. + + Returns: + bytes: The base64 encoded image. + """ + image = Image.fromarray(image) + with BytesIO() as buffer: + image = image.convert("RGB") + image.save(buffer, format="JPEG") + buffer.seek(0) + return buffer.getvalue() + + +def xyxy_to_xywh(xyxy): + """ + Convert bounding box format from (xmin, ymin, xmax, ymax) to (xcenter, ycenter, width, height). + + Args: + xyxy (List[int]): List containing the coordinates in (xmin, ymin, xmax, ymax) format. + + Returns: + List[int]: List containing the converted coordinates in (xcenter, ycenter, width, height) format. + """ + x_temp = (xyxy[0] + xyxy[2]) / 2 + y_temp = (xyxy[1] + xyxy[3]) / 2 + w_temp = abs(xyxy[0] - xyxy[2]) + h_temp = abs(xyxy[1] - xyxy[3]) + + return [int(x_temp), int(y_temp), int(w_temp), int(h_temp)] + + +def encode_image_to_jpeg_bytes(image: np.ndarray, jpeg_quality: int = 90) -> bytes: + """ + Encode a numpy image to JPEG format in bytes. + + Args: + image (np.ndarray): The numpy array representing an image. + jpeg_quality (int): Quality of the JPEG image. + + Returns: + bytes: The JPEG encoded image. + """ + encoding_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality] + _, img_encoded = cv2.imencode(".jpg", image, encoding_param) + return np.array(img_encoded).tobytes() diff --git a/inference/core/utils/notebooks.py b/inference/core/utils/notebooks.py new file mode 100644 index 0000000000000000000000000000000000000000..77c5aa497a78954b19d8d771479cc93d293cca67 --- /dev/null +++ b/inference/core/utils/notebooks.py @@ -0,0 +1,24 @@ +import os +import subprocess + +import requests + +from inference.core.env import NOTEBOOK_PASSWORD, NOTEBOOK_PORT + + +def check_notebook_is_running(): + try: + response = requests.get(f"http://localhost:{NOTEBOOK_PORT}/") + return response.status_code == 200 + except: + return False + + +def start_notebook(): + if not check_notebook_is_running(): + os.makedirs("/notebooks", exist_ok=True) + subprocess.Popen( + f"jupyter-lab --allow-root --port={NOTEBOOK_PORT} --ip=0.0.0.0 --notebook-dir=/notebooks --NotebookApp.token='{NOTEBOOK_PASSWORD}' --NotebookApp.password='{NOTEBOOK_PASSWORD}'".split( + " " + ) + ) diff --git a/inference/core/utils/onnx.py b/inference/core/utils/onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..ae49fe4bbd7581506ed5e50c3c64e19c95bbbe93 --- /dev/null +++ b/inference/core/utils/onnx.py @@ -0,0 +1,19 @@ +from typing import List + + +def get_onnxruntime_execution_providers(value: str) -> List[str]: + """Extracts the ONNX runtime execution providers from the given string. + + The input string is expected to be a comma-separated list, possibly enclosed + within square brackets and containing single quotes. + + Args: + value (str): The string containing the list of ONNX runtime execution providers. + + Returns: + List[str]: A list of strings representing each execution provider. + """ + if len(value) == 0: + return [] + value = value.replace("[", "").replace("]", "").replace("'", "").replace(" ", "") + return value.split(",") diff --git a/inference/core/utils/postprocess.py b/inference/core/utils/postprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..46554ccf0e601855aee7931a47388cd8b05dc7bd --- /dev/null +++ b/inference/core/utils/postprocess.py @@ -0,0 +1,618 @@ +from copy import deepcopy +from typing import Dict, List, Tuple, Union + +import cv2 +import numpy as np + +from inference.core.exceptions import PostProcessingError +from inference.core.utils.preprocess import ( + STATIC_CROP_KEY, + static_crop_should_be_applied, +) + + +def cosine_similarity(a: np.ndarray, b: np.ndarray) -> Union[np.number, np.ndarray]: + """ + Compute the cosine similarity between two vectors. + + Args: + a (np.ndarray): Vector A. + b (np.ndarray): Vector B. + + Returns: + float: Cosine similarity between vectors A and B. + """ + return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) + + +def masks2poly(masks: np.ndarray) -> List[np.ndarray]: + """Converts binary masks to polygonal segments. + + Args: + masks (numpy.ndarray): A set of binary masks, where masks are multiplied by 255 and converted to uint8 type. + + Returns: + list: A list of segments, where each segment is obtained by converting the corresponding mask. + """ + segments = [] + masks = (masks * 255.0).astype(np.uint8) + for mask in masks: + segments.append(mask2poly(mask)) + return segments + + +def mask2poly(mask: np.ndarray) -> np.ndarray: + """ + Find contours in the mask and return them as a float32 array. + + Args: + mask (np.ndarray): A binary mask. + + Returns: + np.ndarray: Contours represented as a float32 array. + """ + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0] + if contours: + contours = np.array( + contours[np.array([len(x) for x in contours]).argmax()] + ).reshape(-1, 2) + else: + contours = np.zeros((0, 2)) + return contours.astype("float32") + + +def post_process_bboxes( + predictions: List[List[List[float]]], + infer_shape: Tuple[int, int], + img_dims: List[Tuple[int, int]], + preproc: dict, + disable_preproc_static_crop: bool = False, + resize_method: str = "Stretch to", +) -> List[List[List[float]]]: + """ + Postprocesses each patch of detections by scaling them to the original image coordinates and by shifting them based on a static crop preproc (if applied). + + Args: + predictions (List[List[List[float]]]): The predictions output from NMS, indices are: batch x prediction x [x1, y1, x2, y2, ...]. + infer_shape (Tuple[int, int]): The shape of the inference image. + img_dims (List[Tuple[int, int]]): The dimensions of the original image for each batch, indices are: batch x [height, width]. + preproc (dict): Preprocessing configuration dictionary. + disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False. + resize_method (str, optional): Resize method for image. Defaults to "Stretch to". + + Returns: + List[List[List[float]]]: The scaled and shifted predictions, indices are: batch x prediction x [x1, y1, x2, y2, ...]. + """ + + # Get static crop params + scaled_predictions = [] + # Loop through batches + for i, batch_predictions in enumerate(predictions): + if len(batch_predictions) == 0: + scaled_predictions.append([]) + continue + np_batch_predictions = np.array(batch_predictions) + # Get bboxes from predictions (x1,y1,x2,y2) + predicted_bboxes = np_batch_predictions[:, :4] + (crop_shift_x, crop_shift_y), origin_shape = get_static_crop_dimensions( + img_dims[i], + preproc, + disable_preproc_static_crop=disable_preproc_static_crop, + ) + if resize_method == "Stretch to": + predicted_bboxes = stretch_bboxes( + predicted_bboxes=predicted_bboxes, + infer_shape=infer_shape, + origin_shape=origin_shape, + ) + elif ( + resize_method == "Fit (black edges) in" + or resize_method == "Fit (white edges) in" + ): + predicted_bboxes = undo_image_padding_for_predicted_boxes( + predicted_bboxes=predicted_bboxes, + infer_shape=infer_shape, + origin_shape=origin_shape, + ) + predicted_bboxes = clip_boxes_coordinates( + predicted_bboxes=predicted_bboxes, + origin_shape=origin_shape, + ) + predicted_bboxes = shift_bboxes( + bboxes=predicted_bboxes, + shift_x=crop_shift_x, + shift_y=crop_shift_y, + ) + np_batch_predictions[:, :4] = predicted_bboxes + scaled_predictions.append(np_batch_predictions.tolist()) + return scaled_predictions + + +def stretch_bboxes( + predicted_bboxes: np.ndarray, + infer_shape: Tuple[int, int], + origin_shape: Tuple[int, int], +) -> np.ndarray: + scale_height = origin_shape[0] / infer_shape[0] + scale_width = origin_shape[1] / infer_shape[1] + return scale_bboxes( + bboxes=predicted_bboxes, + scale_x=scale_width, + scale_y=scale_height, + ) + + +def undo_image_padding_for_predicted_boxes( + predicted_bboxes: np.ndarray, + infer_shape: Tuple[int, int], + origin_shape: Tuple[int, int], +) -> np.ndarray: + scale = min(infer_shape[0] / origin_shape[0], infer_shape[1] / origin_shape[1]) + inter_h = round(origin_shape[0] * scale) + inter_w = round(origin_shape[1] * scale) + pad_x = (infer_shape[0] - inter_w) / 2 + pad_y = (infer_shape[1] - inter_h) / 2 + predicted_bboxes = shift_bboxes( + bboxes=predicted_bboxes, shift_x=-pad_x, shift_y=-pad_y + ) + predicted_bboxes /= scale + return predicted_bboxes + + +def clip_boxes_coordinates( + predicted_bboxes: np.ndarray, + origin_shape: Tuple[int, int], +) -> np.ndarray: + predicted_bboxes[:, 0] = np.round( + np.clip(predicted_bboxes[:, 0], a_min=0, a_max=origin_shape[1]) + ) + predicted_bboxes[:, 2] = np.round( + np.clip(predicted_bboxes[:, 2], a_min=0, a_max=origin_shape[1]) + ) + predicted_bboxes[:, 1] = np.round( + np.clip(predicted_bboxes[:, 1], a_min=0, a_max=origin_shape[0]) + ) + predicted_bboxes[:, 3] = np.round( + np.clip(predicted_bboxes[:, 3], a_min=0, a_max=origin_shape[0]) + ) + return predicted_bboxes + + +def shift_bboxes( + bboxes: np.ndarray, + shift_x: Union[int, float], + shift_y: Union[int, float], +) -> np.ndarray: + bboxes[:, 0] += shift_x + bboxes[:, 2] += shift_x + bboxes[:, 1] += shift_y + bboxes[:, 3] += shift_y + return bboxes + + +def process_mask_accurate( + protos: np.ndarray, + masks_in: np.ndarray, + bboxes: np.ndarray, + shape: Tuple[int, int], +) -> np.ndarray: + """Returns masks that are the size of the original image. + + Args: + protos (numpy.ndarray): Prototype masks. + masks_in (numpy.ndarray): Input masks. + bboxes (numpy.ndarray): Bounding boxes. + shape (tuple): Target shape. + + Returns: + numpy.ndarray: Processed masks. + """ + masks = preprocess_segmentation_masks( + protos=protos, + masks_in=masks_in, + shape=shape, + ) + + # Order = 1 -> bilinear + if len(masks.shape) == 2: + masks = np.expand_dims(masks, axis=0) + masks = masks.transpose((1, 2, 0)) + masks = cv2.resize(masks, (shape[1], shape[0]), cv2.INTER_LINEAR) + if len(masks.shape) == 2: + masks = np.expand_dims(masks, axis=2) + masks = masks.transpose((2, 0, 1)) + masks = crop_mask(masks, bboxes) + masks[masks < 0.5] = 0 + return masks + + +def process_mask_tradeoff( + protos: np.ndarray, + masks_in: np.ndarray, + bboxes: np.ndarray, + shape: Tuple[int, int], + tradeoff_factor: float, +) -> np.ndarray: + """Returns masks that are the size of the original image with a tradeoff factor applied. + + Args: + protos (numpy.ndarray): Prototype masks. + masks_in (numpy.ndarray): Input masks. + bboxes (numpy.ndarray): Bounding boxes. + shape (tuple): Target shape. + tradeoff_factor (float): Tradeoff factor for resizing masks. + + Returns: + numpy.ndarray: Processed masks. + """ + c, mh, mw = protos.shape # CHW + masks = preprocess_segmentation_masks( + protos=protos, + masks_in=masks_in, + shape=shape, + ) + + # Order = 1 -> bilinear + if len(masks.shape) == 2: + masks = np.expand_dims(masks, axis=0) + masks = masks.transpose((1, 2, 0)) + ih, iw = shape + h = int(mh * (1 - tradeoff_factor) + ih * tradeoff_factor) + w = int(mw * (1 - tradeoff_factor) + iw * tradeoff_factor) + size = (h, w) + if tradeoff_factor != 0: + masks = cv2.resize(masks, size, cv2.INTER_LINEAR) + if len(masks.shape) == 2: + masks = np.expand_dims(masks, axis=2) + masks = masks.transpose((2, 0, 1)) + c, mh, mw = masks.shape + down_sampled_boxes = scale_bboxes( + bboxes=deepcopy(bboxes), + scale_x=mw / iw, + scale_y=mh / ih, + ) + masks = crop_mask(masks, down_sampled_boxes) + masks[masks < 0.5] = 0 + return masks + + +def process_mask_fast( + protos: np.ndarray, + masks_in: np.ndarray, + bboxes: np.ndarray, + shape: Tuple[int, int], +) -> np.ndarray: + """Returns masks in their original size. + + Args: + protos (numpy.ndarray): Prototype masks. + masks_in (numpy.ndarray): Input masks. + bboxes (numpy.ndarray): Bounding boxes. + shape (tuple): Target shape. + + Returns: + numpy.ndarray: Processed masks. + """ + ih, iw = shape + c, mh, mw = protos.shape # CHW + masks = preprocess_segmentation_masks( + protos=protos, + masks_in=masks_in, + shape=shape, + ) + down_sampled_boxes = scale_bboxes( + bboxes=deepcopy(bboxes), + scale_x=mw / iw, + scale_y=mh / ih, + ) + masks = crop_mask(masks, down_sampled_boxes) + masks[masks < 0.5] = 0 + return masks + + +def preprocess_segmentation_masks( + protos: np.ndarray, + masks_in: np.ndarray, + shape: Tuple[int, int], +) -> np.ndarray: + c, mh, mw = protos.shape # CHW + masks = protos.astype(np.float32) + masks = masks.reshape((c, -1)) + masks = masks_in @ masks + masks = sigmoid(masks) + masks = masks.reshape((-1, mh, mw)) + gain = min(mh / shape[0], mw / shape[1]) # gain = old / new + pad = (mw - shape[1] * gain) / 2, (mh - shape[0] * gain) / 2 # wh padding + top, left = int(pad[1]), int(pad[0]) # y, x + bottom, right = int(mh - pad[1]), int(mw - pad[0]) + return masks[:, top:bottom, left:right] + + +def scale_bboxes(bboxes: np.ndarray, scale_x: float, scale_y: float) -> np.ndarray: + bboxes[:, 0] *= scale_x + bboxes[:, 2] *= scale_x + bboxes[:, 1] *= scale_y + bboxes[:, 3] *= scale_y + return bboxes + + +def crop_mask(masks: np.ndarray, boxes: np.ndarray) -> np.ndarray: + """ + "Crop" predicted masks by zeroing out everything not in the predicted bbox. + Vectorized by Chong (thanks Chong). + + Args: + - masks should be a size [h, w, n] tensor of masks + - boxes should be a size [n, 4] tensor of bbox coords in relative point form + """ + + n, h, w = masks.shape + x1, y1, x2, y2 = np.split(boxes[:, :, None], 4, 1) # x1 shape(1,1,n) + r = np.arange(w, dtype=x1.dtype)[None, None, :] # rows shape(1,w,1) + c = np.arange(h, dtype=x1.dtype)[None, :, None] # cols shape(h,1,1) + + masks = masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2)) + return masks + + +def post_process_polygons( + origin_shape: Tuple[int, int], + polys: List[List[Tuple[float, float]]], + infer_shape: Tuple[int, int], + preproc: dict, + resize_method: str = "Stretch to", +) -> List[List[Tuple[float, float]]]: + """Scales and shifts polygons based on the given image shapes and preprocessing method. + + This function performs polygon scaling and shifting based on the specified resizing method and + pre-processing steps. The polygons are transformed according to the ratio and padding between two images. + + Args: + origin_shape (tuple of int): Shape of the source image (height, width). + infer_shape (tuple of int): Shape of the target image (height, width). + polys (list of list of tuple): List of polygons, where each polygon is represented by a list of (x, y) coordinates. + preproc (object): Preprocessing details used for generating the transformation. + resize_method (str, optional): Resizing method, either "Stretch to", "Fit (black edges) in", or "Fit (white edges) in". Defaults to "Stretch to". + + Returns: + list of list of tuple: A list of shifted and scaled polygons. + """ + (crop_shift_x, crop_shift_y), origin_shape = get_static_crop_dimensions( + origin_shape, preproc + ) + new_polys = [] + if resize_method == "Stretch to": + width_ratio = origin_shape[1] / infer_shape[1] + height_ratio = origin_shape[0] / infer_shape[0] + new_polys = scale_polygons( + polygons=polys, + x_scale=width_ratio, + y_scale=height_ratio, + ) + elif resize_method in {"Fit (black edges) in", "Fit (white edges) in"}: + new_polys = undo_image_padding_for_predicted_polygons( + polygons=polys, + infer_shape=infer_shape, + origin_shape=origin_shape, + ) + shifted_polys = [] + for poly in new_polys: + poly = [(p[0] + crop_shift_x, p[1] + crop_shift_y) for p in poly] + shifted_polys.append(poly) + return shifted_polys + + +def scale_polygons( + polygons: List[List[Tuple[float, float]]], + x_scale: float, + y_scale: float, +) -> List[List[Tuple[float, float]]]: + result = [] + for poly in polygons: + poly = [(p[0] * x_scale, p[1] * y_scale) for p in poly] + result.append(poly) + return result + + +def undo_image_padding_for_predicted_polygons( + polygons: List[List[Tuple[float, float]]], + origin_shape: Tuple[int, int], + infer_shape: Tuple[int, int], +) -> List[List[Tuple[float, float]]]: + scale = min(infer_shape[0] / origin_shape[0], infer_shape[1] / origin_shape[1]) + inter_w = int(origin_shape[1] * scale) + inter_h = int(origin_shape[0] * scale) + pad_x = (infer_shape[1] - inter_w) / 2 + pad_y = (infer_shape[0] - inter_h) / 2 + result = [] + for poly in polygons: + poly = [((p[0] - pad_x) / scale, (p[1] - pad_y) / scale) for p in poly] + result.append(poly) + return result + + +def get_static_crop_dimensions( + orig_shape: Tuple[int, int], + preproc: dict, + disable_preproc_static_crop: bool = False, +) -> Tuple[Tuple[int, int], Tuple[int, int]]: + """ + Generates a transformation based on preprocessing configuration. + + Args: + orig_shape (tuple): The original shape of the object (e.g., image) - (height, width). + preproc (dict): Preprocessing configuration dictionary, containing information such as static cropping. + disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False. + + Returns: + tuple: A tuple containing the shift in the x and y directions, and the updated original shape after cropping. + """ + try: + if static_crop_should_be_applied( + preprocessing_config=preproc, + disable_preproc_static_crop=disable_preproc_static_crop, + ): + x_min, y_min, x_max, y_max = standardise_static_crop( + static_crop_config=preproc[STATIC_CROP_KEY] + ) + else: + x_min, y_min, x_max, y_max = 0, 0, 1, 1 + crop_shift_x, crop_shift_y = ( + round(x_min * orig_shape[1]), + round(y_min * orig_shape[0]), + ) + cropped_percent_x = x_max - x_min + cropped_percent_y = y_max - y_min + orig_shape = ( + round(orig_shape[0] * cropped_percent_y), + round(orig_shape[1] * cropped_percent_x), + ) + return (crop_shift_x, crop_shift_y), orig_shape + except KeyError as error: + raise PostProcessingError( + f"Could not find a proper configuration key {error} in post-processing." + ) + + +def standardise_static_crop( + static_crop_config: Dict[str, int] +) -> Tuple[float, float, float, float]: + return tuple(static_crop_config[key] / 100 for key in ["x_min", "y_min", "x_max", "y_max"]) # type: ignore + + +def post_process_keypoints( + predictions: List[List[List[float]]], + keypoints_start_index: int, + infer_shape: Tuple[int, int], + img_dims: List[Tuple[int, int]], + preproc: dict, + disable_preproc_static_crop: bool = False, + resize_method: str = "Stretch to", +) -> List[List[List[float]]]: + """Scales and shifts keypoints based on the given image shapes and preprocessing method. + + This function performs polygon scaling and shifting based on the specified resizing method and + pre-processing steps. The polygons are transformed according to the ratio and padding between two images. + + Args: + predictions: predictions from model + keypoints_start_index: offset in the 3rd dimension pointing where in the prediction start keypoints [(x, y, cfg), ...] for each keypoint class + img_dims list of (tuple of int): Shape of the source image (height, width). + infer_shape (tuple of int): Shape of the target image (height, width). + preproc (object): Preprocessing details used for generating the transformation. + resize_method (str, optional): Resizing method, either "Stretch to", "Fit (black edges) in", or "Fit (white edges) in". Defaults to "Stretch to". + disable_preproc_static_crop: flag to disable static crop + Returns: + list of list of list: predictions with post-processed keypoints + """ + # Get static crop params + scaled_predictions = [] + # Loop through batches + for i, batch_predictions in enumerate(predictions): + if len(batch_predictions) == 0: + scaled_predictions.append([]) + continue + np_batch_predictions = np.array(batch_predictions) + keypoints = np_batch_predictions[:, keypoints_start_index:] + (crop_shift_x, crop_shift_y), origin_shape = get_static_crop_dimensions( + img_dims[i], + preproc, + disable_preproc_static_crop=disable_preproc_static_crop, + ) + if resize_method == "Stretch to": + keypoints = stretch_keypoints( + keypoints=keypoints, + infer_shape=infer_shape, + origin_shape=origin_shape, + ) + elif ( + resize_method == "Fit (black edges) in" + or resize_method == "Fit (white edges) in" + ): + keypoints = undo_image_padding_for_predicted_keypoints( + keypoints=keypoints, + infer_shape=infer_shape, + origin_shape=origin_shape, + ) + keypoints = clip_keypoints_coordinates( + keypoints=keypoints, origin_shape=origin_shape + ) + keypoints = shift_keypoints( + keypoints=keypoints, shift_x=crop_shift_x, shift_y=crop_shift_y + ) + np_batch_predictions[:, keypoints_start_index:] = keypoints + scaled_predictions.append(np_batch_predictions.tolist()) + return scaled_predictions + + +def stretch_keypoints( + keypoints: np.ndarray, + infer_shape: Tuple[int, int], + origin_shape: Tuple[int, int], +) -> np.ndarray: + scale_width = origin_shape[1] / infer_shape[1] + scale_height = origin_shape[0] / infer_shape[0] + for keypoint_id in range(keypoints.shape[1] // 3): + keypoints[:, keypoint_id * 3] *= scale_width + keypoints[:, keypoint_id * 3 + 1] *= scale_height + return keypoints + + +def undo_image_padding_for_predicted_keypoints( + keypoints: np.ndarray, + infer_shape: Tuple[int, int], + origin_shape: Tuple[int, int], +) -> np.ndarray: + # Undo scaling and padding from letterbox resize preproc operation + scale = min(infer_shape[0] / origin_shape[0], infer_shape[1] / origin_shape[1]) + inter_w = int(origin_shape[1] * scale) + inter_h = int(origin_shape[0] * scale) + + pad_x = (infer_shape[1] - inter_w) / 2 + pad_y = (infer_shape[0] - inter_h) / 2 + for coord_id in range(keypoints.shape[1] // 3): + keypoints[:, coord_id * 3] -= pad_x + keypoints[:, coord_id * 3] /= scale + keypoints[:, coord_id * 3 + 1] -= pad_y + keypoints[:, coord_id * 3 + 1] /= scale + return keypoints + + +def clip_keypoints_coordinates( + keypoints: np.ndarray, + origin_shape: Tuple[int, int], +) -> np.ndarray: + for keypoint_id in range(keypoints.shape[1] // 3): + keypoints[:, keypoint_id * 3] = np.round( + np.clip(keypoints[:, keypoint_id * 3], a_min=0, a_max=origin_shape[1]) + ) + keypoints[:, keypoint_id * 3 + 1] = np.round( + np.clip(keypoints[:, keypoint_id * 3 + 1], a_min=0, a_max=origin_shape[0]) + ) + return keypoints + + +def shift_keypoints( + keypoints: np.ndarray, + shift_x: Union[int, float], + shift_y: Union[int, float], +) -> np.ndarray: + for keypoint_id in range(keypoints.shape[1] // 3): + keypoints[:, keypoint_id * 3] += shift_x + keypoints[:, keypoint_id * 3 + 1] += shift_y + return keypoints + + +def sigmoid(x: Union[float, np.ndarray]) -> Union[float, np.number, np.ndarray]: + """Computes the sigmoid function for the given input. + + The sigmoid function is defined as: + f(x) = 1 / (1 + exp(-x)) + + Args: + x (float or numpy.ndarray): Input value or array for which the sigmoid function is to be computed. + + Returns: + float or numpy.ndarray: The computed sigmoid value(s). + """ + return 1 / (1 + np.exp(-x)) diff --git a/inference/core/utils/preprocess.py b/inference/core/utils/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..1e9b6691bfcfd4aae5d9321e28da215575c9a773 --- /dev/null +++ b/inference/core/utils/preprocess.py @@ -0,0 +1,243 @@ +from enum import Enum +from typing import Dict, Tuple + +import cv2 +import numpy as np +from skimage.exposure import rescale_intensity + +from inference.core.env import ( + DISABLE_PREPROC_CONTRAST, + DISABLE_PREPROC_GRAYSCALE, + DISABLE_PREPROC_STATIC_CROP, +) +from inference.core.exceptions import PreProcessingError + +STATIC_CROP_KEY = "static-crop" +CONTRAST_KEY = "contrast" +GRAYSCALE_KEY = "grayscale" +ENABLED_KEY = "enabled" +TYPE_KEY = "type" + + +class ContrastAdjustmentType(Enum): + CONTRAST_STRETCHING = "Contrast Stretching" + HISTOGRAM_EQUALISATION = "Histogram Equalization" + ADAPTIVE_EQUALISATION = "Adaptive Equalization" + + +def prepare( + image: np.ndarray, + preproc, + disable_preproc_contrast: bool = False, + disable_preproc_grayscale: bool = False, + disable_preproc_static_crop: bool = False, +) -> Tuple[np.ndarray, Tuple[int, int]]: + """ + Prepares an image by applying a series of preprocessing steps defined in the `preproc` dictionary. + + Args: + image (PIL.Image.Image): The input PIL image object. + preproc (dict): Dictionary containing preprocessing steps. Example: + { + "resize": {"enabled": true, "width": 416, "height": 416, "format": "Stretch to"}, + "static-crop": {"y_min": 25, "x_max": 75, "y_max": 75, "enabled": true, "x_min": 25}, + "auto-orient": {"enabled": true}, + "grayscale": {"enabled": true}, + "contrast": {"enabled": true, "type": "Adaptive Equalization"} + } + disable_preproc_contrast (bool, optional): If true, the contrast preprocessing step is disabled for this call. Default is False. + disable_preproc_grayscale (bool, optional): If true, the grayscale preprocessing step is disabled for this call. Default is False. + disable_preproc_static_crop (bool, optional): If true, the static crop preprocessing step is disabled for this call. Default is False. + + Returns: + PIL.Image.Image: The preprocessed image object. + tuple: The dimensions of the image. + + Note: + The function uses global flags like `DISABLE_PREPROC_AUTO_ORIENT`, `DISABLE_PREPROC_STATIC_CROP`, etc. + to conditionally enable or disable certain preprocessing steps. + """ + try: + h, w = image.shape[0:2] + img_dims = (h, w) + if static_crop_should_be_applied( + preprocessing_config=preproc, + disable_preproc_static_crop=disable_preproc_static_crop, + ): + image = take_static_crop( + image=image, crop_parameters=preproc[STATIC_CROP_KEY] + ) + if contrast_adjustments_should_be_applied( + preprocessing_config=preproc, + disable_preproc_contrast=disable_preproc_contrast, + ): + adjustment_type = ContrastAdjustmentType(preproc[CONTRAST_KEY][TYPE_KEY]) + image = apply_contrast_adjustment( + image=image, adjustment_type=adjustment_type + ) + if grayscale_conversion_should_be_applied( + preprocessing_config=preproc, + disable_preproc_grayscale=disable_preproc_grayscale, + ): + image = apply_grayscale_conversion(image=image) + return image, img_dims + except KeyError as error: + raise PreProcessingError( + f"Pre-processing of image failed due to misconfiguration. Missing key: {error}." + ) from error + + +def static_crop_should_be_applied( + preprocessing_config: dict, + disable_preproc_static_crop: bool, +) -> bool: + return ( + STATIC_CROP_KEY in preprocessing_config.keys() + and not DISABLE_PREPROC_STATIC_CROP + and not disable_preproc_static_crop + and preprocessing_config[STATIC_CROP_KEY][ENABLED_KEY] + ) + + +def take_static_crop(image: np.ndarray, crop_parameters: Dict[str, int]) -> np.ndarray: + height, width = image.shape[0:2] + x_min = int(crop_parameters["x_min"] / 100 * width) + y_min = int(crop_parameters["y_min"] / 100 * height) + x_max = int(crop_parameters["x_max"] / 100 * width) + y_max = int(crop_parameters["y_max"] / 100 * height) + return image[y_min:y_max, x_min:x_max, :] + + +def contrast_adjustments_should_be_applied( + preprocessing_config: dict, + disable_preproc_contrast: bool, +) -> bool: + return ( + CONTRAST_KEY in preprocessing_config.keys() + and not DISABLE_PREPROC_CONTRAST + and not disable_preproc_contrast + and preprocessing_config[CONTRAST_KEY][ENABLED_KEY] + ) + + +def apply_contrast_adjustment( + image: np.ndarray, + adjustment_type: ContrastAdjustmentType, +) -> np.ndarray: + adjustment = CONTRAST_ADJUSTMENTS_METHODS[adjustment_type] + return adjustment(image) + + +def apply_contrast_stretching(image: np.ndarray) -> np.ndarray: + p2, p98 = np.percentile(image, (2, 98)) + return rescale_intensity(image, in_range=(p2, p98)) # type: ignore + + +def apply_histogram_equalisation(image: np.ndarray) -> np.ndarray: + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + image = cv2.equalizeHist(image) + return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + + +def apply_adaptive_equalisation(image: np.ndarray) -> np.ndarray: + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + clahe = cv2.createCLAHE(clipLimit=0.03, tileGridSize=(8, 8)) + image = clahe.apply(image) + return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + + +CONTRAST_ADJUSTMENTS_METHODS = { + ContrastAdjustmentType.CONTRAST_STRETCHING: apply_contrast_stretching, + ContrastAdjustmentType.HISTOGRAM_EQUALISATION: apply_histogram_equalisation, + ContrastAdjustmentType.ADAPTIVE_EQUALISATION: apply_adaptive_equalisation, +} + + +def grayscale_conversion_should_be_applied( + preprocessing_config: dict, + disable_preproc_grayscale: bool, +) -> bool: + return ( + GRAYSCALE_KEY in preprocessing_config.keys() + and not DISABLE_PREPROC_GRAYSCALE + and not disable_preproc_grayscale + and preprocessing_config[GRAYSCALE_KEY][ENABLED_KEY] + ) + + +def apply_grayscale_conversion(image: np.ndarray) -> np.ndarray: + image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) + + +def letterbox_image( + image: np.ndarray, + desired_size: Tuple[int, int], + color: Tuple[int, int, int] = (0, 0, 0), +) -> np.ndarray: + """ + Resize and pad image to fit the desired size, preserving its aspect ratio. + + Parameters: + - image: numpy array representing the image. + - desired_size: tuple (width, height) representing the target dimensions. + - color: tuple (B, G, R) representing the color to pad with. + + Returns: + - letterboxed image. + """ + resized_img = resize_image_keeping_aspect_ratio( + image=image, + desired_size=desired_size, + ) + new_height, new_width = resized_img.shape[:2] + top_padding = (desired_size[1] - new_height) // 2 + bottom_padding = desired_size[1] - new_height - top_padding + left_padding = (desired_size[0] - new_width) // 2 + right_padding = desired_size[0] - new_width - left_padding + return cv2.copyMakeBorder( + resized_img, + top_padding, + bottom_padding, + left_padding, + right_padding, + cv2.BORDER_CONSTANT, + value=color, + ) + + +def downscale_image_keeping_aspect_ratio( + image: np.ndarray, + desired_size: Tuple[int, int], +) -> np.ndarray: + if image.shape[0] <= desired_size[1] and image.shape[1] <= desired_size[0]: + return image + return resize_image_keeping_aspect_ratio(image=image, desired_size=desired_size) + + +def resize_image_keeping_aspect_ratio( + image: np.ndarray, + desired_size: Tuple[int, int], +) -> np.ndarray: + """ + Resize reserving its aspect ratio. + + Parameters: + - image: numpy array representing the image. + - desired_size: tuple (width, height) representing the target dimensions. + """ + img_ratio = image.shape[1] / image.shape[0] + desired_ratio = desired_size[0] / desired_size[1] + + # Determine the new dimensions + if img_ratio >= desired_ratio: + # Resize by width + new_width = desired_size[0] + new_height = int(desired_size[0] / img_ratio) + else: + # Resize by height + new_height = desired_size[1] + new_width = int(desired_size[1] * img_ratio) + + # Resize the image to new dimensions + return cv2.resize(image, (new_width, new_height)) diff --git a/inference/core/utils/requests.py b/inference/core/utils/requests.py new file mode 100644 index 0000000000000000000000000000000000000000..abf9abc219a8df57cd4666f7388380028f73a440 --- /dev/null +++ b/inference/core/utils/requests.py @@ -0,0 +1,24 @@ +import re + +from requests import Response + +API_KEY_PATTERN = re.compile(r"api_key=(.[^&]*)") +KEY_VALUE_GROUP = 1 +MIN_KEY_LENGTH_TO_REVEAL_PREFIX = 8 + + +def api_key_safe_raise_for_status(response: Response) -> None: + request_is_successful = response.status_code < 400 + if request_is_successful: + return None + response.url = API_KEY_PATTERN.sub(deduct_api_key, response.url) + response.raise_for_status() + + +def deduct_api_key(match: re.Match) -> str: + key_value = match.group(KEY_VALUE_GROUP) + if len(key_value) < MIN_KEY_LENGTH_TO_REVEAL_PREFIX: + return f"api_key=***" + key_prefix = key_value[:2] + key_postfix = key_value[-2:] + return f"api_key={key_prefix}***{key_postfix}" diff --git a/inference/core/utils/roboflow.py b/inference/core/utils/roboflow.py new file mode 100644 index 0000000000000000000000000000000000000000..0fe05c2f358da0214da092fb7aced73eb2c59ce4 --- /dev/null +++ b/inference/core/utils/roboflow.py @@ -0,0 +1,11 @@ +from typing import Tuple + +from inference.core.entities.types import DatasetID, VersionID +from inference.core.exceptions import InvalidModelIDError + + +def get_model_id_chunks(model_id: str) -> Tuple[DatasetID, VersionID]: + model_id_chunks = model_id.split("/") + if len(model_id_chunks) != 2: + raise InvalidModelIDError(f"Model ID: `{model_id}` is invalid.") + return model_id_chunks[0], model_id_chunks[1] diff --git a/inference/core/utils/s3.py b/inference/core/utils/s3.py new file mode 100644 index 0000000000000000000000000000000000000000..d27c18a9af734dd45cbc071be02b5feb745e491c --- /dev/null +++ b/inference/core/utils/s3.py @@ -0,0 +1,20 @@ +import os +from typing import List + +from botocore.client import BaseClient + + +def download_s3_files_to_directory( + bucket: str, + keys: List[str], + target_dir: str, + s3_client: BaseClient, +) -> None: + os.makedirs(target_dir, exist_ok=True) + for key in keys: + target_path = os.path.join(target_dir, key) + s3_client.download_file( + bucket, + key, + target_path, + ) diff --git a/inference/core/utils/url_utils.py b/inference/core/utils/url_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..bff55ba1d01cfd8162c908ca7c98cb66536183d5 --- /dev/null +++ b/inference/core/utils/url_utils.py @@ -0,0 +1,11 @@ +import urllib + +from inference.core.env import LICENSE_SERVER + + +def wrap_url(url: str) -> str: + if not LICENSE_SERVER: + return url + return f"http://{LICENSE_SERVER}/proxy?url=" + urllib.parse.quote( + url, safe="~()*!'" + ) diff --git a/inference/core/utils/visualisation.py b/inference/core/utils/visualisation.py new file mode 100644 index 0000000000000000000000000000000000000000..82c4c0d14cc9ef4a7d6ace7043c220dee2035f97 --- /dev/null +++ b/inference/core/utils/visualisation.py @@ -0,0 +1,156 @@ +from typing import Dict, List, Tuple, Union + +import cv2 +import numpy as np + +from inference.core.entities.requests.inference import ( + InstanceSegmentationInferenceRequest, + KeypointsDetectionInferenceRequest, + ObjectDetectionInferenceRequest, +) +from inference.core.entities.responses.inference import ( + InstanceSegmentationPrediction, + Keypoint, + KeypointsPrediction, + ObjectDetectionInferenceResponse, + ObjectDetectionPrediction, + Point, +) +from inference.core.utils.image_utils import load_image_rgb, np_image_to_base64 + + +def draw_detection_predictions( + inference_request: Union[ + ObjectDetectionInferenceRequest, + InstanceSegmentationInferenceRequest, + KeypointsDetectionInferenceRequest, + ], + inference_response: Union[ + ObjectDetectionInferenceResponse, + InstanceSegmentationPrediction, + KeypointsPrediction, + ], + colors: Dict[str, str], +) -> bytes: + image = load_image_rgb(inference_request.image) + for box in inference_response.predictions: + color = tuple( + int(colors.get(box.class_name, "#4892EA")[i : i + 2], 16) for i in (1, 3, 5) + ) + image = draw_bbox( + image=image, + box=box, + color=color, + thickness=inference_request.visualization_stroke_width, + ) + if hasattr(box, "points"): + image = draw_instance_segmentation_points( + image=image, + points=box.points, + color=color, + thickness=inference_request.visualization_stroke_width, + ) + if hasattr(box, "keypoints"): + draw_keypoints( + image=image, + keypoints=box.keypoints, + color=color, + thickness=inference_request.visualization_stroke_width, + ) + if inference_request.visualization_labels: + image = draw_labels( + image=image, + box=box, + color=color, + ) + return np_image_to_base64(image=image) + + +def draw_bbox( + image: np.ndarray, + box: ObjectDetectionPrediction, + color: Tuple[int, ...], + thickness: int, +) -> np.ndarray: + left_top, right_bottom = bbox_to_points(box=box) + return cv2.rectangle( + image, + left_top, + right_bottom, + color=color, + thickness=thickness, + ) + + +def draw_instance_segmentation_points( + image: np.ndarray, + points: List[Point], + color: Tuple[int, ...], + thickness: int, +) -> np.ndarray: + points_array = np.array([(int(p.x), int(p.y)) for p in points], np.int32) + if len(points) > 2: + image = cv2.polylines( + image, + [points_array], + isClosed=True, + color=color, + thickness=thickness, + ) + return image + + +def draw_keypoints( + image: np.ndarray, + keypoints: List[Keypoint], + color: Tuple[int, ...], + thickness: int, +) -> None: + for keypoint in keypoints: + center_coordinates = (round(keypoint.x), round(keypoint.y)) + image = cv2.circle( + image, + center_coordinates, + thickness, + color, + -1, + ) + + +def draw_labels( + image: np.ndarray, + box: Union[ObjectDetectionPrediction, InstanceSegmentationPrediction], + color: Tuple[int, ...], +) -> np.ndarray: + (x1, y1), _ = bbox_to_points(box=box) + text = f"{box.class_name} {box.confidence:.2f}" + (text_width, text_height), _ = cv2.getTextSize( + text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1 + ) + button_size = (text_width + 20, text_height + 20) + button_img = np.full( + (button_size[1], button_size[0], 3), color[::-1], dtype=np.uint8 + ) + cv2.putText( + button_img, + text, + (10, 10 + text_height), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 255, 255), + 1, + ) + end_x = min(x1 + button_size[0], image.shape[1]) + end_y = min(y1 + button_size[1], image.shape[0]) + image[y1:end_y, x1:end_x] = button_img[: end_y - y1, : end_x - x1] + return image + + +def bbox_to_points( + box: Union[ObjectDetectionPrediction, InstanceSegmentationPrediction], +) -> Tuple[Tuple[int, int], Tuple[int, int]]: + x1 = int(box.x - box.width / 2) + x2 = int(box.x + box.width / 2) + y1 = int(box.y - box.height / 2) + y2 = int(box.y + box.height / 2) + return (x1, y1), (x2, y2) diff --git a/inference/core/version.py b/inference/core/version.py new file mode 100644 index 0000000000000000000000000000000000000000..4dca9902ac1235eb292427b28f813502f2fe2dc6 --- /dev/null +++ b/inference/core/version.py @@ -0,0 +1,5 @@ +__version__ = "0.9.13" + + +if __name__ == "__main__": + print(__version__) diff --git a/inference/enterprise/__init__.py b/inference/enterprise/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/enterprise/__pycache__/__init__.cpython-310.pyc b/inference/enterprise/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f510b88ff00d486a4669e33fc14f768b1eee5e35 Binary files /dev/null and b/inference/enterprise/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/enterprise/device_manager/__init__.py b/inference/enterprise/device_manager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/enterprise/device_manager/__pycache__/__init__.cpython-310.pyc b/inference/enterprise/device_manager/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59889da3b77fd919e8720ea547e8fbdffc101e16 Binary files /dev/null and b/inference/enterprise/device_manager/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/enterprise/device_manager/__pycache__/command_handler.cpython-310.pyc b/inference/enterprise/device_manager/__pycache__/command_handler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff403ec6a7b6bf0e6b7451572f5c8a9dc9ee72bf Binary files /dev/null and b/inference/enterprise/device_manager/__pycache__/command_handler.cpython-310.pyc differ diff --git a/inference/enterprise/device_manager/__pycache__/container_service.cpython-310.pyc b/inference/enterprise/device_manager/__pycache__/container_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efdcf65f828f866973897f9a241b4eba6199c2e4 Binary files /dev/null and b/inference/enterprise/device_manager/__pycache__/container_service.cpython-310.pyc differ diff --git a/inference/enterprise/device_manager/__pycache__/device_manager.cpython-310.pyc b/inference/enterprise/device_manager/__pycache__/device_manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd7bfef1cc24a354575a0ba1b163311628be8d34 Binary files /dev/null and b/inference/enterprise/device_manager/__pycache__/device_manager.cpython-310.pyc differ diff --git a/inference/enterprise/device_manager/__pycache__/helpers.cpython-310.pyc b/inference/enterprise/device_manager/__pycache__/helpers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..990b8f0e34adfbe761c3e171764d007860397fcb Binary files /dev/null and b/inference/enterprise/device_manager/__pycache__/helpers.cpython-310.pyc differ diff --git a/inference/enterprise/device_manager/__pycache__/metrics_service.cpython-310.pyc b/inference/enterprise/device_manager/__pycache__/metrics_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9ca6c48d71c149f5921bbf1a2772b48b6aa8214 Binary files /dev/null and b/inference/enterprise/device_manager/__pycache__/metrics_service.cpython-310.pyc differ diff --git a/inference/enterprise/device_manager/command_handler.py b/inference/enterprise/device_manager/command_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..f11da5ec94d039d8bec70db660ab39f24c3a00f5 --- /dev/null +++ b/inference/enterprise/device_manager/command_handler.py @@ -0,0 +1,89 @@ +from typing import Literal, Optional + +import requests +from pydantic import BaseModel + +import docker +from inference.core.devices.utils import GLOBAL_DEVICE_ID +from inference.core.env import API_BASE_URL, API_KEY +from inference.core.logger import logger +from inference.core.utils.url_utils import wrap_url +from inference.enterprise.device_manager.container_service import get_container_by_id + + +class Command(BaseModel): + id: str + containerId: str + command: Literal["restart", "stop", "ping", "snapshot", "update_version"] + deviceId: str + requested_on: Optional[int] = None + + +def fetch_commands(): + url = wrap_url( + f"{API_BASE_URL}/devices/{GLOBAL_DEVICE_ID}/commands?api_key={API_KEY}" + ) + resp = requests.get(url).json() + for cmd in resp.get("data", []): + handle_command(cmd) + + +def handle_command(cmd_payload: dict): + was_processed = False + container_id = cmd_payload.get("containerId") + container = get_container_by_id(container_id) + if not container: + logger.warn(f"Container with id {container_id} not found") + ack_command(cmd_payload.get("id"), was_processed) + return + cmd = cmd_payload.get("command") + data = None + match cmd: + case "restart": + was_processed, data = container.restart() + case "stop": + was_processed, data = container.stop() + case "ping": + was_processed, data = container.ping() + case "snapshot": + was_processed, data = container.snapshot() + case "start": + was_processed, data = container.start() + case "update_version": + was_processed, data = handle_version_update(container) + case _: + logger.error("Unknown command: {}".format(cmd)) + return ack_command(cmd_payload.get("id"), was_processed, data=data) + + +def ack_command(command_id, was_processed, data=None): + post_body = dict() + post_body["api_key"] = API_KEY + post_body["commandId"] = command_id + post_body["wasProcessed"] = was_processed + if data: + post_body["data"] = data + url = wrap_url(f"{API_BASE_URL}/devices/{GLOBAL_DEVICE_ID}/commands/ack") + requests.post(url, json=post_body) + + +def handle_version_update(container): + try: + config = container.get_startup_config() + image_name = config["image"].split(":")[0] + container.kill() + client = docker.from_env() + new_container = client.containers.run( + image=f"{image_name}:latest", + detach=config["detach"], + privileged=config["privileged"], + labels=config["labels"], + ports=config["port_bindings"], + environment=config["env"], + network="host", + ) + logger.info(f"New container started {new_container}") + return True, None + except Exception as e: + logger.error(e) + return False, None diff --git a/inference/enterprise/device_manager/container_service.py b/inference/enterprise/device_manager/container_service.py new file mode 100644 index 0000000000000000000000000000000000000000..32aba4382d150b3d551c444664081e77d950e54e --- /dev/null +++ b/inference/enterprise/device_manager/container_service.py @@ -0,0 +1,282 @@ +import base64 +import time +from dataclasses import dataclass +from datetime import datetime + +import requests + +import docker +from inference.core.cache import cache +from inference.core.env import METRICS_INTERVAL +from inference.core.logger import logger +from inference.core.utils.image_utils import load_image_rgb +from inference.enterprise.device_manager.helpers import get_cache_model_items + + +@dataclass +class InferServerContainer: + status: str + id: str + port: int + host: str + startup_time: float + version: str + + def __init__(self, docker_container, details): + self.container = docker_container + self.status = details.get("status") + self.id = details.get("uuid") + self.port = details.get("port") + self.host = details.get("host") + self.version = details.get("version") + t = details.get("startup_time_ts").split(".")[0] + self.startup_time = ( + datetime.strptime(t, "%Y-%m-%dT%H:%M:%S").timestamp() + if t is not None + else datetime.now().timestamp() + ) + + def kill(self): + try: + self.container.kill() + return True, None + except Exception as e: + logger.error(e) + return False, None + + def restart(self): + try: + self.container.restart() + return True, None + except Exception as e: + logger.error(e) + return False, None + + def stop(self): + try: + self.container.stop() + return True, None + except Exception as e: + logger.error(e) + return False, None + + def start(self): + try: + self.container.start() + return True, None + except Exception as e: + logger.error(e) + return False, None + + def inspect(self): + try: + info = requests.get(f"http://{self.host}:{self.port}/info").json() + return True, info + except Exception as e: + logger.error(e) + return False, None + + def snapshot(self): + try: + snapshot = self.get_latest_inferred_images() + snapshot.update({"container_id": self.id}) + return True, snapshot + except Exception as e: + logger.error(e) + return False, None + + def get_latest_inferred_images(self, max=4): + """ + Retrieve the latest inferred images and associated information for this container. + + This method fetches the most recent inferred images within the time interval defined by METRICS_INTERVAL. + + Args: + max (int, optional): The maximum number of inferred images to retrieve. + Defaults to 4. + + Returns: + dict: A dictionary where each key represents a model ID associated with this + container, and the corresponding value is a list of dictionaries containing + information about the latest inferred images. Each dictionary has the following keys: + - "image" (str): The base64-encoded image data. + - "dimensions" (dict): Image dimensions (width and height). + - "predictions" (list): A list of predictions or results associated with the image. + + Notes: + - This method uses the global constant METRICS_INTERVAL to specify the time interval. + """ + + now = time.time() + start = now - METRICS_INTERVAL + api_keys = get_cache_model_items().get(self.id, dict()).keys() + model_ids = [] + for api_key in api_keys: + mids = get_cache_model_items().get(self.id, dict()).get(api_key, []) + model_ids.extend(mids) + num_images = 0 + latest_inferred_images = dict() + for model_id in model_ids: + if num_images >= max: + break + latest_reqs = cache.zrangebyscore( + f"inference:{self.id}:{model_id}", min=start, max=now + ) + for req in latest_reqs: + images = req["request"]["image"] + image_dims = req.get("response", {}).get("image", dict()) + predictions = req.get("response", {}).get("predictions", []) + if images is None or len(images) == 0: + continue + if type(images) is not list: + images = [images] + for image in images: + value = None + if image["type"] == "base64": + value = image["value"] + else: + loaded_image = load_image_rgb(image) + image_bytes = loaded_image.tobytes() + image_base64 = base64.b64encode(image_bytes).decode("utf-8") + value = image_base64 + if latest_inferred_images.get(model_id) is None: + latest_inferred_images[model_id] = [] + inference = dict( + image=value, dimensions=image_dims, predictions=predictions + ) + latest_inferred_images[model_id].append(inference) + num_images += 1 + return latest_inferred_images + + def get_startup_config(self): + """ + Get the startup configuration for this container. + + Returns: + dict: A dictionary containing the startup configuration for this container. + """ + env_vars = self.container.attrs.get("Config", {}).get("Env", {}) + port_bindings = self.container.attrs.get("HostConfig", {}).get( + "PortBindings", {} + ) + detached = self.container.attrs.get("HostConfig", {}).get("Detached", False) + image = self.container.attrs.get("Config", {}).get("Image", "") + privileged = self.container.attrs.get("HostConfig", {}).get("Privileged", False) + labels = self.container.attrs.get("Config", {}).get("Labels", {}) + env = [] + for var in env_vars: + name, value = var.split("=") + env.append(f"{name}={value}") + return { + "env": env, + "port_bindings": port_bindings, + "detach": detached, + "image": image, + "privileged": privileged, + "labels": labels, + # TODO: add device requests + } + + +def is_inference_server_container(container): + """ + Checks if a container is an inference server container + + Args: + container (any): A container object from the Docker SDK + + Returns: + boolean: True if the container is an inference server container, False otherwise + """ + image_tags = container.image.tags + for t in image_tags: + if t.startswith("roboflow/roboflow-inference-server"): + return True + return False + + +def get_inference_containers(): + """ + Discovers inference server containers running on the host + and parses their information into a list of InferServerContainer objects + """ + client = docker.from_env() + containers = client.containers.list() + inference_containers = [] + for c in containers: + if is_inference_server_container(c): + details = parse_container_info(c) + info = {} + try: + info = requests.get( + f"http://{details['host']}:{details['port']}/info", timeout=3 + ).json() + except Exception as e: + logger.error(f"Failed to get info from container {c.id} {details} {e}") + details.update(info) + infer_container = InferServerContainer(c, details) + if len(inference_containers) == 0: + inference_containers.append(infer_container) + continue + for ic in inference_containers: + if ic.id == infer_container.id: + continue + inference_containers.append(infer_container) + return inference_containers + + +def parse_container_info(c): + """ + Parses the container information into a dictionary + + Args: + c (any): Docker SDK Container object + + Returns: + dict: A dictionary containing the container information + """ + env = c.attrs.get("Config", {}).get("Env", {}) + info = {"container_id": c.id, "port": 9001, "host": "0.0.0.0"} + for var in env: + if var.startswith("PORT="): + info["port"] = var.split("=")[1] + elif var.startswith("HOST="): + info["host"] = var.split("=")[1] + status = c.attrs.get("State", {}).get("Status") + if status: + info["status"] = status + container_name = c.attrs.get("Name") + if container_name: + info["container_name_on_host"] = container_name + startup_time = c.attrs.get("State", {}).get("StartedAt") + if startup_time: + info["startup_time_ts"] = startup_time + return info + + +def get_container_by_id(id): + """ + Gets an inference server container by its id + + Args: + id (string): The id of the container + + Returns: + container: The container object if found, None otherwise + """ + containers = get_inference_containers() + for c in containers: + if c.id == id: + return c + return None + + +def get_container_ids(): + """ + Gets the ids of the inference server containers + + Returns: + list: A list of container ids + """ + containers = get_inference_containers() + return [c.id for c in containers] diff --git a/inference/enterprise/device_manager/device_manager.py b/inference/enterprise/device_manager/device_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..8c6275d760417ee5c71bcb90e002e1c9add11d1f --- /dev/null +++ b/inference/enterprise/device_manager/device_manager.py @@ -0,0 +1,62 @@ +from apscheduler.schedulers.background import BackgroundScheduler +from fastapi import FastAPI + +from inference.core.env import METRICS_INTERVAL +from inference.core.version import __version__ +from inference.enterprise.device_manager.command_handler import ( + Command, + fetch_commands, + handle_command, +) +from inference.enterprise.device_manager.metrics_service import ( + report_metrics_and_handle_commands, +) + +app = FastAPI( + title="Roboflow Device Manager", + description="The device manager enables remote control and monitoring of Roboflow inference server containers", + version=__version__, + terms_of_service="https://roboflow.com/terms", + contact={ + "name": "Roboflow Inc.", + "url": "https://roboflow.com/contact", + "email": "help@roboflow.com", + }, + license_info={ + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0.html", + }, + root_path="/", +) + + +@app.get("/") +def root(): + return { + "name": "Roboflow Device Manager", + "version": __version__, + "terms_of_service": "https://roboflow.com/terms", + "contact": { + "name": "Roboflow Inc.", + "url": "https://roboflow.com/contact", + "email": "help@roboflow.com", + }, + "license_info": { + "name": "Apache 2.0", + "url": "https://www.apache.org/licenses/LICENSE-2.0.html", + }, + } + + +@app.post("/exec_command") +async def exec_command(command: Command): + handle_command(command.dict()) + return {"status": "ok"} + + +scheduler = BackgroundScheduler(job_defaults={"coalesce": True, "max_instances": 3}) +scheduler.add_job( + report_metrics_and_handle_commands, "interval", seconds=METRICS_INTERVAL +) +scheduler.add_job(fetch_commands, "interval", seconds=3) +scheduler.start() diff --git a/inference/enterprise/device_manager/helpers.py b/inference/enterprise/device_manager/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..c3c23cf8c0bcbc2724fb0c94099aaa68a2c9154e --- /dev/null +++ b/inference/enterprise/device_manager/helpers.py @@ -0,0 +1,38 @@ +import time + +from inference.core.cache import cache +from inference.core.env import METRICS_INTERVAL + + +def get_cache_model_items(): + """ + Retrieve and organize cached model items within a specified time interval. + + This method queries a cache for model items and retrieves those that fall + within the time interval defined by the global constant METRICS_INTERVAL. + It organizes the retrieved items into a hierarchical dictionary structure + for efficient access. + + Returns: + dict: A dictionary containing model items organized by server ID, API key, + and model ID. The structure is as follows: + - Keys: Server IDs associated with models. + - Sub-keys: API keys associated with models on the server. + - Values: Lists of model IDs associated with each API key on the server. + + Notes: + - This method relies on a cache system for storing and retrieving model items. + - It uses the global constant METRICS_INTERVAL to specify the time interval. + """ + now = time.time() + start = now - METRICS_INTERVAL + models = cache.zrangebyscore("models", min=start, max=now) + model_items = dict() + for model in models: + server_id, api_key, model_id = model.split(":") + if server_id not in model_items: + model_items[server_id] = dict() + if api_key not in model_items[server_id]: + model_items[server_id][api_key] = [] + model_items[server_id][api_key].append(model_id) + return model_items diff --git a/inference/enterprise/device_manager/metrics_service.py b/inference/enterprise/device_manager/metrics_service.py new file mode 100644 index 0000000000000000000000000000000000000000..c193b8a4ca132ae3e7eda775c5ab65616ed852e4 --- /dev/null +++ b/inference/enterprise/device_manager/metrics_service.py @@ -0,0 +1,136 @@ +import time + +import requests + +from inference.core.devices.utils import GLOBAL_DEVICE_ID +from inference.core.env import API_KEY, METRICS_INTERVAL, METRICS_URL, TAGS +from inference.core.logger import logger +from inference.core.managers.metrics import get_model_metrics, get_system_info +from inference.core.utils.requests import api_key_safe_raise_for_status +from inference.core.version import __version__ +from inference.enterprise.device_manager.command_handler import handle_command +from inference.enterprise.device_manager.container_service import ( + get_container_by_id, + get_container_ids, +) +from inference.enterprise.device_manager.helpers import get_cache_model_items + + +def aggregate_model_stats(container_id): + """ + Aggregate statistics for models within a specified container. + + This function retrieves and aggregates performance metrics for all models + associated with the given container within a specified time interval. + + Args: + container_id (str): The unique identifier of the container for which + model statistics are to be aggregated. + + Returns: + list: A list of dictionaries, where each dictionary represents a model's + statistics with the following keys: + - "dataset_id" (str): The ID of the dataset associated with the model. + - "version" (str): The version of the model. + - "api_key" (str): The API key that was used to make an inference against this model + - "metrics" (dict): A dictionary containing performance metrics for the model: + - "num_inferences" (int): Number of inferences made + - "num_errors" (int): Number of errors + - "avg_inference_time" (float): Average inference time in seconds + + Notes: + - The function calculates statistics over a time interval defined by + the global constant METRICS_INTERVAL, passed in when starting up the container. + """ + now = time.time() + start = now - METRICS_INTERVAL + models = [] + api_keys = get_cache_model_items().get(container_id, dict()).keys() + for api_key in api_keys: + model_ids = get_cache_model_items().get(container_id, dict()).get(api_key, []) + for model_id in model_ids: + model = { + "dataset_id": model_id.split("/")[0], + "version": model_id.split("/")[1], + "api_key": api_key, + "metrics": get_model_metrics( + container_id, model_id, min=start, max=now + ), + } + models.append(model) + return models + + +def build_container_stats(): + """ + Build statistics for containers and their associated models. + + Returns: + list: A list of dictionaries, where each dictionary represents statistics + for a container and its associated models with the following keys: + - "uuid" (str): The unique identifier (UUID) of the container. + - "startup_time" (float): The timestamp representing the container's startup time. + - "models" (list): A list of dictionaries representing statistics for each + model associated with the container (see `aggregate_model_stats` for format). + + Notes: + - This method relies on a singleton `container_service` for container information. + """ + containers = [] + for id in get_container_ids(): + container = get_container_by_id(id) + if container: + container_stats = {} + models = aggregate_model_stats(id) + container_stats["uuid"] = container.id + container_stats["version"] = container.version + container_stats["startup_time"] = container.startup_time + container_stats["models"] = models + if container.status == "running" or container.status == "restarting": + container_stats["status"] = "running" + elif container.status == "exited": + container_stats["status"] = "stopped" + elif container.status == "paused": + container_stats["status"] = "idle" + else: + container_stats["status"] = "processing" + containers.append(container_stats) + return containers + + +def aggregate_device_stats(): + """ + Aggregate statistics for the device. + """ + window_start_timestamp = str(int(time.time())) + all_data = { + "api_key": API_KEY, + "timestamp": window_start_timestamp, + "device": { + "id": GLOBAL_DEVICE_ID, + "name": GLOBAL_DEVICE_ID, + "type": f"roboflow-device-manager=={__version__}", + "tags": TAGS, + "system_info": get_system_info(), + "containers": build_container_stats(), + }, + } + return all_data + + +def report_metrics_and_handle_commands(): + """ + Report metrics to Roboflow. + + This function aggregates statistics for the device and its containers and + sends them to Roboflow. If Roboflow sends back any commands, they are + handled by the `handle_command` function. + """ + all_data = aggregate_device_stats() + logger.info(f"Sending metrics to Roboflow {str(all_data)}.") + res = requests.post(METRICS_URL, json=all_data) + api_key_safe_raise_for_status(response=res) + response = res.json() + for cmd in response.get("data", []): + if cmd: + handle_command(cmd) diff --git a/inference/enterprise/parallel/__init__.py b/inference/enterprise/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/enterprise/parallel/__pycache__/__init__.cpython-310.pyc b/inference/enterprise/parallel/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b5708944ced12b75ee6189152e9b317b371ecc9 Binary files /dev/null and b/inference/enterprise/parallel/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/enterprise/parallel/__pycache__/celeryconfig.cpython-310.pyc b/inference/enterprise/parallel/__pycache__/celeryconfig.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ad57c29ad6105486fbf71318548818f61c9d06e Binary files /dev/null and b/inference/enterprise/parallel/__pycache__/celeryconfig.cpython-310.pyc differ diff --git a/inference/enterprise/parallel/__pycache__/dispatch_manager.cpython-310.pyc b/inference/enterprise/parallel/__pycache__/dispatch_manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..752feb1a902c56921d002dc510accb28fb104349 Binary files /dev/null and b/inference/enterprise/parallel/__pycache__/dispatch_manager.cpython-310.pyc differ diff --git a/inference/enterprise/parallel/__pycache__/entrypoint.cpython-310.pyc b/inference/enterprise/parallel/__pycache__/entrypoint.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24e7b78efbed0d2cc33126c4af6b844638c1f74c Binary files /dev/null and b/inference/enterprise/parallel/__pycache__/entrypoint.cpython-310.pyc differ diff --git a/inference/enterprise/parallel/__pycache__/infer.cpython-310.pyc b/inference/enterprise/parallel/__pycache__/infer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83979a76bcb4c8c8ea10dc4c85093db630d6b1e0 Binary files /dev/null and b/inference/enterprise/parallel/__pycache__/infer.cpython-310.pyc differ diff --git a/inference/enterprise/parallel/__pycache__/parallel_http_api.cpython-310.pyc b/inference/enterprise/parallel/__pycache__/parallel_http_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53256f2ba4783bd717530e649a29f210b39742dd Binary files /dev/null and b/inference/enterprise/parallel/__pycache__/parallel_http_api.cpython-310.pyc differ diff --git a/inference/enterprise/parallel/__pycache__/parallel_http_config.cpython-310.pyc b/inference/enterprise/parallel/__pycache__/parallel_http_config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ac150fe7c288cf55c905186aab6176a79b49a3d Binary files /dev/null and b/inference/enterprise/parallel/__pycache__/parallel_http_config.cpython-310.pyc differ diff --git a/inference/enterprise/parallel/__pycache__/tasks.cpython-310.pyc b/inference/enterprise/parallel/__pycache__/tasks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49cb17dbb4e95e2823107109273f048cea52c40b Binary files /dev/null and b/inference/enterprise/parallel/__pycache__/tasks.cpython-310.pyc differ diff --git a/inference/enterprise/parallel/__pycache__/utils.cpython-310.pyc b/inference/enterprise/parallel/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68d6a847d3ff0443457187fb9eb185de2f44c54c Binary files /dev/null and b/inference/enterprise/parallel/__pycache__/utils.cpython-310.pyc differ diff --git a/inference/enterprise/parallel/celeryconfig.py b/inference/enterprise/parallel/celeryconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..f506d3e533c663682d07ff6c843da6c391155e2f --- /dev/null +++ b/inference/enterprise/parallel/celeryconfig.py @@ -0,0 +1,2 @@ +broker_pool_limit = 32 +redis_socket_keepalive = True diff --git a/inference/enterprise/parallel/dispatch_manager.py b/inference/enterprise/parallel/dispatch_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..d43d134a5d1e37204473df00f80c731ab475e922 --- /dev/null +++ b/inference/enterprise/parallel/dispatch_manager.py @@ -0,0 +1,151 @@ +import asyncio +from asyncio import BoundedSemaphore +from time import perf_counter, time +from typing import Any, Dict, List, Optional + +import orjson +from redis.asyncio import Redis + +from inference.core.entities.requests.inference import ( + InferenceRequest, + request_from_type, +) +from inference.core.entities.responses.inference import response_from_type +from inference.core.env import NUM_PARALLEL_TASKS +from inference.core.managers.base import ModelManager +from inference.core.registries.base import ModelRegistry +from inference.core.registries.roboflow import get_model_type +from inference.enterprise.parallel.tasks import preprocess +from inference.enterprise.parallel.utils import FAILURE_STATE, SUCCESS_STATE + + +class ResultsChecker: + """ + Class responsible for queuing asyncronous inference runs, + keeping track of running requests, and awaiting their results. + """ + + def __init__(self, redis: Redis): + self.tasks: Dict[str, asyncio.Event] = {} + self.dones = dict() + self.errors = dict() + self.running = True + self.redis = redis + self.semaphore: BoundedSemaphore = BoundedSemaphore(NUM_PARALLEL_TASKS) + + async def add_task(self, task_id: str, request: InferenceRequest): + """ + Wait until there's available cylce to queue a task. + When there are cycles, add the task's id to a list to keep track of its results, + launch the preprocess celeryt task, set the task's status to in progress in redis. + """ + await self.semaphore.acquire() + self.tasks[task_id] = asyncio.Event() + preprocess.s(request.dict()).delay() + + def get_result(self, task_id: str) -> Any: + """ + Check the done tasks and errored tasks for this task id. + """ + if task_id in self.dones: + return self.dones.pop(task_id) + elif task_id in self.errors: + message = self.errors.pop(task_id) + raise Exception(message) + else: + raise RuntimeError( + "Task result not found in either success or error dict. Unreachable" + ) + + async def loop(self): + """ + Main loop. Check all in progress tasks for their status, and if their status is final, + (either failure or success) then add their results to the appropriate results dictionary. + """ + async with self.redis.pubsub() as pubsub: + await pubsub.subscribe("results") + async for message in pubsub.listen(): + if message["type"] != "message": + continue + message = orjson.loads(message["data"]) + task_id = message.pop("task_id") + if task_id not in self.tasks: + continue + self.semaphore.release() + status = message.pop("status") + if status == FAILURE_STATE: + self.errors[task_id] = message["payload"] + elif status == SUCCESS_STATE: + self.dones[task_id] = message["payload"] + else: + raise RuntimeError( + "Task result not found in possible states. Unreachable" + ) + self.tasks[task_id].set() + await asyncio.sleep(0) + + async def wait_for_response(self, key: str): + event = self.tasks[key] + await event.wait() + del self.tasks[key] + return self.get_result(key) + + +class DispatchModelManager(ModelManager): + def __init__( + self, + model_registry: ModelRegistry, + checker: ResultsChecker, + models: Optional[dict] = None, + ): + super().__init__(model_registry, models) + self.checker = checker + + async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs): + if request.visualize_predictions: + raise NotImplementedError("Visualisation of prediction is not supported") + request.start = time() + t = perf_counter() + task_type = self.get_task_type(model_id, request.api_key) + + list_mode = False + if isinstance(request.image, list): + list_mode = True + request_dict = request.dict() + images = request_dict.pop("image") + del request_dict["id"] + requests = [ + request_from_type(task_type, dict(**request_dict, image=image)) + for image in images + ] + else: + requests = [request] + + start_task_awaitables = [] + results_awaitables = [] + for r in requests: + start_task_awaitables.append(self.checker.add_task(r.id, r)) + results_awaitables.append(self.checker.wait_for_response(r.id)) + + await asyncio.gather(*start_task_awaitables) + response_jsons = await asyncio.gather(*results_awaitables) + responses = [] + for response_json in response_jsons: + response = response_from_type(task_type, response_json) + response.time = perf_counter() - t + responses.append(response) + + if list_mode: + return responses + return responses[0] + + def add_model( + self, model_id: str, api_key: str, model_id_alias: str = None + ) -> None: + pass + + def __contains__(self, model_id: str) -> bool: + return True + + def get_task_type(self, model_id: str, api_key: str = None) -> str: + return get_model_type(model_id, api_key)[0] diff --git a/inference/enterprise/parallel/entrypoint.py b/inference/enterprise/parallel/entrypoint.py new file mode 100644 index 0000000000000000000000000000000000000000..8b612e19e8a8bb810a9d3fec6eb12ecb1fa2754e --- /dev/null +++ b/inference/enterprise/parallel/entrypoint.py @@ -0,0 +1,18 @@ +import os + +from inference.core.env import ( + CELERY_LOG_LEVEL, + HOST, + NUM_CELERY_WORKERS, + NUM_WORKERS, + PORT, + REDIS_PORT, +) + +os.system( + f'redis-server --io-threads 8 --save ""--port {REDIS_PORT} &' + f"celery -A inference.enterprise.parallel.tasks worker --prefetch-multiplier=4 --concurrency={NUM_CELERY_WORKERS} -Q pre --loglevel={CELERY_LOG_LEVEL} &" + f"celery -A inference.enterprise.parallel.tasks worker --prefetch-multiplier=4 --concurrency={NUM_CELERY_WORKERS} -Q post --loglevel={CELERY_LOG_LEVEL} &" + f"python3 inference/enterprise/parallel/infer.py &" + f"gunicorn parallel_http:app --workers={NUM_WORKERS} --bind={HOST}:{PORT} -k uvicorn.workers.UvicornWorker && fg " +) diff --git a/inference/enterprise/parallel/infer.py b/inference/enterprise/parallel/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..716d358dd88a245c6aa9a51c3c637c0af58bffcf --- /dev/null +++ b/inference/enterprise/parallel/infer.py @@ -0,0 +1,207 @@ +import logging +import time +from asyncio import Queue as AioQueue +from dataclasses import asdict +from multiprocessing import shared_memory +from queue import Queue +from threading import Thread +from typing import Dict, List, Tuple + +import numpy as np +import orjson +from redis import ConnectionPool, Redis + +from inference.core.entities.requests.inference import ( + InferenceRequest, + request_from_type, +) +from inference.core.env import MAX_ACTIVE_MODELS, MAX_BATCH_SIZE, REDIS_HOST, REDIS_PORT +from inference.core.managers.base import ModelManager +from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache +from inference.core.models.roboflow import RoboflowInferenceModel +from inference.core.registries.roboflow import RoboflowModelRegistry +from inference.enterprise.parallel.tasks import postprocess +from inference.enterprise.parallel.utils import ( + SharedMemoryMetadata, + failure_handler, + shm_manager, +) + +logging.basicConfig(level=logging.WARNING) +logger = logging.getLogger() + +from inference.models.utils import ROBOFLOW_MODEL_TYPES + +BATCH_SIZE = MAX_BATCH_SIZE +if BATCH_SIZE == float("inf"): + BATCH_SIZE = 32 +AGE_TRADEOFF_SECONDS_FACTOR = 30 + + +class InferServer: + def __init__(self, redis: Redis) -> None: + self.redis = redis + model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES) + model_manager = ModelManager(model_registry) + self.model_manager = WithFixedSizeCache( + model_manager, max_size=MAX_ACTIVE_MODELS + ) + self.running = True + self.response_queue = Queue() + self.write_thread = Thread(target=self.write_responses) + self.write_thread.start() + self.batch_queue = Queue(maxsize=1) + self.infer_thread = Thread(target=self.infer) + self.infer_thread.start() + + def write_responses(self): + while True: + try: + response = self.response_queue.get() + write_infer_arrays_and_launch_postprocess(*response) + except Exception as error: + logger.warning( + f"Encountered error while writiing response:\n" + str(error) + ) + + def infer_loop(self): + while self.running: + try: + model_names = get_requested_model_names(self.redis) + if not model_names: + time.sleep(0.001) + continue + self.get_batch(model_names) + except Exception as error: + logger.warning("Encountered error in infer loop:\n" + str(error)) + continue + + def infer(self): + while True: + model_id, images, batch, preproc_return_metadatas = self.batch_queue.get() + outputs = self.model_manager.predict(model_id, images) + for output, b, metadata in zip( + zip(*outputs), batch, preproc_return_metadatas + ): + self.response_queue.put_nowait((output, b["request"], metadata)) + + def get_batch(self, model_names): + start = time.perf_counter() + batch, model_id = get_batch(self.redis, model_names) + logger.info(f"Inferring: model<{model_id}> batch_size<{len(batch)}>") + with failure_handler(self.redis, *[b["request"]["id"] for b in batch]): + self.model_manager.add_model(model_id, batch[0]["request"]["api_key"]) + model_type = self.model_manager.get_task_type(model_id) + for b in batch: + request = request_from_type(model_type, b["request"]) + b["request"] = request + b["shm_metadata"] = SharedMemoryMetadata(**b["shm_metadata"]) + + metadata_processed = time.perf_counter() + logger.info( + f"Took {(metadata_processed - start):3f} seconds to process metadata" + ) + with shm_manager( + *[b["shm_metadata"].shm_name for b in batch], unlink_on_success=True + ) as shms: + images, preproc_return_metadatas = load_batch(batch, shms) + loaded = time.perf_counter() + logger.info( + f"Took {(loaded - metadata_processed):3f} seconds to load batch" + ) + self.batch_queue.put( + (model_id, images, batch, preproc_return_metadatas) + ) + + +def get_requested_model_names(redis: Redis) -> List[str]: + request_counts = redis.hgetall("requests") + model_names = [ + model_name for model_name, count in request_counts.items() if int(count) > 0 + ] + return model_names + + +def get_batch(redis: Redis, model_names: List[str]) -> Tuple[List[Dict], str]: + """ + Run a heuristic to select the best batch to infer on + redis[Redis]: redis client + model_names[List[str]]: list of models with nonzero number of requests + returns: + Tuple[List[Dict], str] + List[Dict] represents a batch of request dicts + str is the model id + """ + batch_sizes = [ + RoboflowInferenceModel.model_metadata_from_memcache_endpoint(m)["batch_size"] + for m in model_names + ] + batch_sizes = [b if not isinstance(b, str) else BATCH_SIZE for b in batch_sizes] + batches = [ + redis.zrange(f"infer:{m}", 0, b - 1, withscores=True) + for m, b in zip(model_names, batch_sizes) + ] + model_index = select_best_inference_batch(batches, batch_sizes) + batch = batches[model_index] + selected_model = model_names[model_index] + redis.zrem(f"infer:{selected_model}", *[b[0] for b in batch]) + redis.hincrby(f"requests", selected_model, -len(batch)) + batch = [orjson.loads(b[0]) for b in batch] + return batch, selected_model + + +def select_best_inference_batch(batches, batch_sizes): + now = time.time() + average_ages = [np.mean([float(b[1]) - now for b in batch]) for batch in batches] + lengths = [ + len(batch) / batch_size for batch, batch_size in zip(batches, batch_sizes) + ] + fitnesses = [ + age / AGE_TRADEOFF_SECONDS_FACTOR + length + for age, length in zip(average_ages, lengths) + ] + model_index = fitnesses.index(max(fitnesses)) + return model_index + + +def load_batch( + batch: List[Dict[str, str]], shms: List[shared_memory.SharedMemory] +) -> Tuple[List[np.ndarray], List[Dict]]: + images = [] + preproc_return_metadatas = [] + for b, shm in zip(batch, shms): + shm_metadata: SharedMemoryMetadata = b["shm_metadata"] + image = np.ndarray( + shm_metadata.array_shape, dtype=shm_metadata.array_dtype, buffer=shm.buf + ).copy() + images.append(image) + preproc_return_metadatas.append(b["preprocess_metadata"]) + return images, preproc_return_metadatas + + +def write_infer_arrays_and_launch_postprocess( + arrs: Tuple[np.ndarray, ...], + request: InferenceRequest, + preproc_return_metadata: Dict, +): + """Write inference results to shared memory and launch the postprocessing task""" + shms = [shared_memory.SharedMemory(create=True, size=arr.nbytes) for arr in arrs] + with shm_manager(*shms): + shm_metadatas = [] + for arr, shm in zip(arrs, shms): + shared = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf) + shared[:] = arr[:] + shm_metadata = SharedMemoryMetadata( + shm_name=shm.name, array_shape=arr.shape, array_dtype=arr.dtype.name + ) + shm_metadatas.append(asdict(shm_metadata)) + + postprocess.s( + tuple(shm_metadatas), request.dict(), preproc_return_metadata + ).delay() + + +if __name__ == "__main__": + pool = ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True) + redis = Redis(connection_pool=pool) + InferServer(redis).infer_loop() diff --git a/inference/enterprise/parallel/parallel_http_api.py b/inference/enterprise/parallel/parallel_http_api.py new file mode 100644 index 0000000000000000000000000000000000000000..c06f45940266f02c4ce52fd9ad1bc5488a96d8b6 --- /dev/null +++ b/inference/enterprise/parallel/parallel_http_api.py @@ -0,0 +1,28 @@ +import asyncio +from threading import Thread + +from redis.asyncio import Redis as AsyncRedis + +from inference.core.env import REDIS_HOST, REDIS_PORT +from inference.core.interfaces.http.http_api import HttpInterface +from inference.core.registries.roboflow import RoboflowModelRegistry +from inference.enterprise.parallel.dispatch_manager import ( + DispatchModelManager, + ResultsChecker, +) +from inference.models.utils import ROBOFLOW_MODEL_TYPES + + +class ParallelHttpInterface(HttpInterface): + def __init__(self, model_manager: DispatchModelManager, root_path: str = None): + super().__init__(model_manager, root_path) + + @self.app.on_event("startup") + async def app_startup(): + model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES) + checker = ResultsChecker(AsyncRedis(host=REDIS_HOST, port=REDIS_PORT)) + self.model_manager = DispatchModelManager(model_registry, checker) + self.model_manager.init_pingback() + task = asyncio.create_task(self.model_manager.checker.loop()) + # keep checker loop reference so it doesn't get gc'd + self.checker_loop = task diff --git a/inference/enterprise/parallel/parallel_http_config.py b/inference/enterprise/parallel/parallel_http_config.py new file mode 100644 index 0000000000000000000000000000000000000000..93c80d8792872f957419796d67347a561fef7df8 --- /dev/null +++ b/inference/enterprise/parallel/parallel_http_config.py @@ -0,0 +1,20 @@ +from redis import ConnectionPool, Redis +from redis.asyncio import Redis as AsyncRedis + +from inference.core.env import REDIS_HOST, REDIS_PORT +from inference.core.registries.roboflow import RoboflowModelRegistry +from inference.enterprise.parallel.dispatch_manager import ( + DispatchModelManager, + ResultsChecker, +) +from inference.enterprise.parallel.parallel_http_api import ParallelHttpInterface +from inference.models.utils import ROBOFLOW_MODEL_TYPES + +model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES) +if REDIS_HOST is None: + raise RuntimeError("Redis must be configured to use async inference") +pool = ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True) +model_manager = None +interface = ParallelHttpInterface(model_manager) + +app = interface.app diff --git a/inference/enterprise/parallel/tasks.py b/inference/enterprise/parallel/tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..a8e0426b6958b1a52349131dfd0aaf6c23e94729 --- /dev/null +++ b/inference/enterprise/parallel/tasks.py @@ -0,0 +1,134 @@ +import json +from dataclasses import asdict +from multiprocessing import shared_memory +from typing import Dict, List, Tuple + +import numpy as np +from celery import Celery +from redis import ConnectionPool, Redis + +import inference.enterprise.parallel.celeryconfig +from inference.core.entities.requests.inference import ( + InferenceRequest, + request_from_type, +) +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.env import REDIS_HOST, REDIS_PORT, STUB_CACHE_SIZE +from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache +from inference.core.managers.decorators.locked_load import ( + LockedLoadModelManagerDecorator, +) +from inference.core.managers.stub_loader import StubLoaderManager +from inference.core.registries.roboflow import RoboflowModelRegistry +from inference.enterprise.parallel.utils import ( + SUCCESS_STATE, + SharedMemoryMetadata, + failure_handler, + shm_manager, +) +from inference.models.utils import ROBOFLOW_MODEL_TYPES + +pool = ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True) +app = Celery("tasks", broker=f"redis://{REDIS_HOST}:{REDIS_PORT}") +app.config_from_object(inference.enterprise.parallel.celeryconfig) +model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES) +model_manager = StubLoaderManager(model_registry) +model_manager = WithFixedSizeCache( + LockedLoadModelManagerDecorator(model_manager), max_size=STUB_CACHE_SIZE +) + + +@app.task(queue="pre") +def preprocess(request: Dict): + redis_client = Redis(connection_pool=pool) + with failure_handler(redis_client, request["id"]): + model_manager.add_model(request["model_id"], request["api_key"]) + model_type = model_manager.get_task_type(request["model_id"]) + request = request_from_type(model_type, request) + image, preprocess_return_metadata = model_manager.preprocess( + request.model_id, request + ) + # multi image requests are split into single image requests upstream and rebatched later + image = image[0] + request.image.value = None # avoid writing image again since it's in memory + shm = shared_memory.SharedMemory(create=True, size=image.nbytes) + with shm_manager(shm): + shared = np.ndarray(image.shape, dtype=image.dtype, buffer=shm.buf) + shared[:] = image[:] + shm_metadata = SharedMemoryMetadata(shm.name, image.shape, image.dtype.name) + queue_infer_task( + redis_client, shm_metadata, request, preprocess_return_metadata + ) + + +@app.task(queue="post") +def postprocess( + shm_info_list: Tuple[Dict], request: Dict, preproc_return_metadata: Dict +): + redis_client = Redis(connection_pool=pool) + shm_info_list: List[SharedMemoryMetadata] = [ + SharedMemoryMetadata(**metadata) for metadata in shm_info_list + ] + with failure_handler(redis_client, request["id"]): + with shm_manager( + *[shm_metadata.shm_name for shm_metadata in shm_info_list], + unlink_on_success=True, + ) as shms: + model_manager.add_model(request["model_id"], request["api_key"]) + model_type = model_manager.get_task_type(request["model_id"]) + request = request_from_type(model_type, request) + + outputs = load_outputs(shm_info_list, shms) + + request_dict = dict(**request.dict()) + model_id = request_dict.pop("model_id") + + response = model_manager.postprocess( + model_id, + outputs, + preproc_return_metadata, + **request_dict, + return_image_dims=True, + )[0] + + write_response(redis_client, response, request.id) + + +def load_outputs( + shm_info_list: List[SharedMemoryMetadata], shms: List[shared_memory.SharedMemory] +) -> Tuple[np.ndarray, ...]: + outputs = [] + for args, shm in zip(shm_info_list, shms): + output = np.ndarray( + [1] + args.array_shape, dtype=args.array_dtype, buffer=shm.buf + ) + outputs.append(output) + return tuple(outputs) + + +def queue_infer_task( + redis: Redis, + shm_metadata: SharedMemoryMetadata, + request: InferenceRequest, + preprocess_return_metadata: Dict, +): + return_vals = { + "shm_metadata": asdict(shm_metadata), + "request": request.dict(), + "preprocess_metadata": preprocess_return_metadata, + } + return_vals = json.dumps(return_vals) + pipe = redis.pipeline() + pipe.zadd(f"infer:{request.model_id}", {return_vals: request.start}) + pipe.hincrby(f"requests", request.model_id, 1) + pipe.execute() + + +def write_response(redis: Redis, response: InferenceResponse, request_id: str): + response = response.dict(exclude_none=True, by_alias=True) + redis.publish( + f"results", + json.dumps( + {"status": SUCCESS_STATE, "task_id": request_id, "payload": response} + ), + ) diff --git a/inference/enterprise/parallel/utils.py b/inference/enterprise/parallel/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..307f91494bc456f1e1b5c852d4b1da9cb90d4936 --- /dev/null +++ b/inference/enterprise/parallel/utils.py @@ -0,0 +1,70 @@ +import json +from contextlib import contextmanager +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import List, Union + +from redis import Redis + +SUCCESS_STATE = 1 +FAILURE_STATE = -1 + + +@contextmanager +def failure_handler(redis: Redis, *request_ids: str): + """ + Context manager that updates the status/results key in redis with exception + info on failure. + """ + try: + yield + except Exception as error: + message = type(error).__name__ + ": " + str(error) + for request_id in request_ids: + redis.publish( + "results", + json.dumps( + {"task_id": request_id, "status": FAILURE_STATE, "payload": message} + ), + ) + raise + + +@contextmanager +def shm_manager( + *shms: Union[str, shared_memory.SharedMemory], unlink_on_success: bool = False +): + """Context manager that closes and frees shared memory objects.""" + try: + loaded_shms = [] + for shm in shms: + errors = [] + try: + if isinstance(shm, str): + shm = shared_memory.SharedMemory(name=shm) + loaded_shms.append(shm) + except BaseException as error: + errors.append(error) + if errors: + raise Exception(errors) + + yield loaded_shms + except: + for shm in loaded_shms: + shm.close() + shm.unlink() + raise + else: + for shm in loaded_shms: + shm.close() + if unlink_on_success: + shm.unlink() + + +@dataclass +class SharedMemoryMetadata: + """Info needed to load array from shared memory""" + + shm_name: str + array_shape: List[int] + array_dtype: str diff --git a/inference/enterprise/stream_management/__init__.py b/inference/enterprise/stream_management/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/enterprise/stream_management/__pycache__/__init__.cpython-310.pyc b/inference/enterprise/stream_management/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86700201631731ee00927c5d541c10448f967e16 Binary files /dev/null and b/inference/enterprise/stream_management/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/api/__init__.py b/inference/enterprise/stream_management/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/enterprise/stream_management/api/__pycache__/__init__.cpython-310.pyc b/inference/enterprise/stream_management/api/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27fdd405aa36b6cfa92156ad7f85ce93d81b005d Binary files /dev/null and b/inference/enterprise/stream_management/api/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/api/__pycache__/app.cpython-310.pyc b/inference/enterprise/stream_management/api/__pycache__/app.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b91a246c68fc80a153a8bcfdcdd934f4a89acc7 Binary files /dev/null and b/inference/enterprise/stream_management/api/__pycache__/app.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/api/__pycache__/entities.cpython-310.pyc b/inference/enterprise/stream_management/api/__pycache__/entities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c7450a2065addf7ddaef69633d8a084c391ce1e Binary files /dev/null and b/inference/enterprise/stream_management/api/__pycache__/entities.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/api/__pycache__/errors.cpython-310.pyc b/inference/enterprise/stream_management/api/__pycache__/errors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..435898589ff7eb9b0222df2a62a15fdc1a473244 Binary files /dev/null and b/inference/enterprise/stream_management/api/__pycache__/errors.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/api/__pycache__/stream_manager_client.cpython-310.pyc b/inference/enterprise/stream_management/api/__pycache__/stream_manager_client.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19cfd25f804d426362a3122491bd1dd50640d041 Binary files /dev/null and b/inference/enterprise/stream_management/api/__pycache__/stream_manager_client.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/api/app.py b/inference/enterprise/stream_management/api/app.py new file mode 100644 index 0000000000000000000000000000000000000000..0b228d7e0d38ece98485d93c7d80f18251a8af0e --- /dev/null +++ b/inference/enterprise/stream_management/api/app.py @@ -0,0 +1,178 @@ +import os +from functools import wraps +from typing import Any, Awaitable, Callable + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from inference.core import logger +from inference.enterprise.stream_management.api.entities import ( + CommandResponse, + InferencePipelineStatusResponse, + ListPipelinesResponse, + PipelineInitialisationRequest, +) +from inference.enterprise.stream_management.api.errors import ( + ConnectivityError, + ProcessesManagerAuthorisationError, + ProcessesManagerClientError, + ProcessesManagerInvalidPayload, + ProcessesManagerNotFoundError, +) +from inference.enterprise.stream_management.api.stream_manager_client import ( + StreamManagerClient, +) +from inference.enterprise.stream_management.manager.entities import ( + STATUS_KEY, + OperationStatus, +) + +API_HOST = os.getenv("STREAM_MANAGEMENT_API_HOST", "127.0.0.1") +API_PORT = int(os.getenv("STREAM_MANAGEMENT_API_PORT", "8080")) + +OPERATIONS_TIMEOUT = os.getenv("STREAM_MANAGER_OPERATIONS_TIMEOUT") +if OPERATIONS_TIMEOUT is not None: + OPERATIONS_TIMEOUT = float(OPERATIONS_TIMEOUT) + +STREAM_MANAGER_CLIENT = StreamManagerClient.init( + host=os.getenv("STREAM_MANAGER_HOST", "127.0.0.1"), + port=int(os.getenv("STREAM_MANAGER_PORT", "7070")), + operations_timeout=OPERATIONS_TIMEOUT, +) + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +def with_route_exceptions(route: callable) -> Callable[[Any], Awaitable[JSONResponse]]: + @wraps(route) + async def wrapped_route(*args, **kwargs): + try: + return await route(*args, **kwargs) + except ProcessesManagerInvalidPayload as error: + resp = JSONResponse( + status_code=400, + content={STATUS_KEY: OperationStatus.FAILURE, "message": str(error)}, + ) + logger.exception("Processes Manager - invalid payload error") + return resp + except ProcessesManagerAuthorisationError as error: + resp = JSONResponse( + status_code=401, + content={STATUS_KEY: OperationStatus.FAILURE, "message": str(error)}, + ) + logger.exception("Processes Manager - authorisation error") + return resp + except ProcessesManagerNotFoundError as error: + resp = JSONResponse( + status_code=404, + content={STATUS_KEY: OperationStatus.FAILURE, "message": str(error)}, + ) + logger.exception("Processes Manager - not found error") + return resp + except ConnectivityError as error: + resp = JSONResponse( + status_code=503, + content={STATUS_KEY: OperationStatus.FAILURE, "message": str(error)}, + ) + logger.exception("Processes Manager connectivity error occurred") + return resp + except ProcessesManagerClientError as error: + resp = JSONResponse( + status_code=500, + content={STATUS_KEY: OperationStatus.FAILURE, "message": str(error)}, + ) + logger.exception("Processes Manager error occurred") + return resp + except Exception: + resp = JSONResponse( + status_code=500, + content={ + STATUS_KEY: OperationStatus.FAILURE, + "message": "Internal error.", + }, + ) + logger.exception("Internal error in API") + return resp + + return wrapped_route + + +@app.get( + "/list_pipelines", + response_model=ListPipelinesResponse, + summary="List active pipelines", + description="Listing all active pipelines in the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def list_pipelines(_: Request) -> ListPipelinesResponse: + return await STREAM_MANAGER_CLIENT.list_pipelines() + + +@app.get( + "/status/{pipeline_id}", + response_model=InferencePipelineStatusResponse, + summary="Get status of pipeline", + description="Returns detailed statis of Inference Pipeline in the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def get_status(pipeline_id: str) -> InferencePipelineStatusResponse: + return await STREAM_MANAGER_CLIENT.get_status(pipeline_id=pipeline_id) + + +@app.post( + "/initialise", + response_model=CommandResponse, + summary="Initialise the pipeline", + description="Starts new Inference Pipeline within the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def initialise(request: PipelineInitialisationRequest) -> CommandResponse: + return await STREAM_MANAGER_CLIENT.initialise_pipeline( + initialisation_request=request + ) + + +@app.post( + "/pause/{pipeline_id}", + response_model=CommandResponse, + summary="Pauses the pipeline processing", + description="Mutes the VideoSource of Inference Pipeline within the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def pause(pipeline_id: str) -> CommandResponse: + return await STREAM_MANAGER_CLIENT.pause_pipeline(pipeline_id=pipeline_id) + + +@app.post( + "/resume/{pipeline_id}", + response_model=CommandResponse, + summary="Resumes the pipeline processing", + description="Resumes the VideoSource of Inference Pipeline within the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def resume(pipeline_id: str) -> CommandResponse: + return await STREAM_MANAGER_CLIENT.resume_pipeline(pipeline_id=pipeline_id) + + +@app.post( + "/terminate/{pipeline_id}", + response_model=CommandResponse, + summary="Terminates the pipeline processing", + description="Terminates the VideoSource of Inference Pipeline within the state of ProcessesManager being queried.", +) +@with_route_exceptions +async def terminate(pipeline_id: str) -> CommandResponse: + return await STREAM_MANAGER_CLIENT.terminate_pipeline(pipeline_id=pipeline_id) + + +if __name__ == "__main__": + uvicorn.run(app, host=API_HOST, port=API_PORT) diff --git a/inference/enterprise/stream_management/api/entities.py b/inference/enterprise/stream_management/api/entities.py new file mode 100644 index 0000000000000000000000000000000000000000..16648a6e876f5202f94cfd1121f2cdd96b55c71b --- /dev/null +++ b/inference/enterprise/stream_management/api/entities.py @@ -0,0 +1,95 @@ +from typing import List, Optional, Union + +from pydantic import BaseModel, Field + +from inference.core.interfaces.camera.video_source import ( + BufferConsumptionStrategy, + BufferFillingStrategy, +) + + +class UDPSinkConfiguration(BaseModel): + type: str = Field( + description="Type identifier field. Must be `udp_sink`", default="udp_sink" + ) + host: str = Field(description="Host of UDP sink.") + port: int = Field(description="Port of UDP sink.") + + +class ObjectDetectionModelConfiguration(BaseModel): + type: str = Field( + description="Type identifier field. Must be `object-detection`", + default="object-detection", + ) + class_agnostic_nms: Optional[bool] = Field( + description="Flag to decide if class agnostic NMS to be applied. If not given, default or InferencePipeline host env will be used.", + default=None, + ) + confidence: Optional[float] = Field( + description="Confidence threshold for predictions. If not given, default or InferencePipeline host env will be used.", + default=None, + ) + iou_threshold: Optional[float] = Field( + description="IoU threshold of post-processing. If not given, default or InferencePipeline host env will be used.", + default=None, + ) + max_candidates: Optional[int] = Field( + description="Max candidates in post-processing. If not given, default or InferencePipeline host env will be used.", + default=None, + ) + max_detections: Optional[int] = Field( + description="Max detections in post-processing. If not given, default or InferencePipeline host env will be used.", + default=None, + ) + + +class PipelineInitialisationRequest(BaseModel): + model_id: str = Field(description="Roboflow model id") + video_reference: Union[str, int] = Field( + description="Reference to video source - either stream, video file or device. It must be accessible from the host running inference stream" + ) + sink_configuration: UDPSinkConfiguration = Field( + description="Configuration of the sink." + ) + api_key: Optional[str] = Field(description="Roboflow API key", default=None) + max_fps: Optional[Union[float, int]] = Field( + description="Limit of FPS in video processing.", default=None + ) + source_buffer_filling_strategy: Optional[str] = Field( + description=f"`source_buffer_filling_strategy` parameter of Inference Pipeline (see docs). One of {[e.value for e in BufferFillingStrategy]}", + default=None, + ) + source_buffer_consumption_strategy: Optional[str] = Field( + description=f"`source_buffer_consumption_strategy` parameter of Inference Pipeline (see docs). One of {[e.value for e in BufferConsumptionStrategy]}", + default=None, + ) + model_configuration: ObjectDetectionModelConfiguration = Field( + description="Configuration of the model", + default_factory=ObjectDetectionModelConfiguration, + ) + active_learning_enabled: Optional[bool] = Field( + description="Flag to decide if Active Learning middleware should be enabled. If not given - env variable `ACTIVE_LEARNING_ENABLED` will be used (with default `True`).", + default=None, + ) + + +class CommandContext(BaseModel): + request_id: Optional[str] = Field( + description="Server-side request ID", default=None + ) + pipeline_id: Optional[str] = Field( + description="Identifier of pipeline connected to operation", default=None + ) + + +class CommandResponse(BaseModel): + status: str = Field(description="Operation status") + context: CommandContext = Field(description="Context of the command.") + + +class InferencePipelineStatusResponse(CommandResponse): + report: dict + + +class ListPipelinesResponse(CommandResponse): + pipelines: List[str] = Field(description="List IDs of active pipelines") diff --git a/inference/enterprise/stream_management/api/errors.py b/inference/enterprise/stream_management/api/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..3669bac2a06b88c6bd6cb6d7f0d74e376a5c4f62 --- /dev/null +++ b/inference/enterprise/stream_management/api/errors.py @@ -0,0 +1,26 @@ +class ProcessesManagerClientError(Exception): + pass + + +class ConnectivityError(ProcessesManagerClientError): + pass + + +class ProcessesManagerInternalError(ProcessesManagerClientError): + pass + + +class ProcessesManagerOperationError(ProcessesManagerClientError): + pass + + +class ProcessesManagerInvalidPayload(ProcessesManagerClientError): + pass + + +class ProcessesManagerNotFoundError(ProcessesManagerClientError): + pass + + +class ProcessesManagerAuthorisationError(ProcessesManagerClientError): + pass diff --git a/inference/enterprise/stream_management/api/stream_manager_client.py b/inference/enterprise/stream_management/api/stream_manager_client.py new file mode 100644 index 0000000000000000000000000000000000000000..e89e7e914eebea734e52108de9b495d5c3ef7ce8 --- /dev/null +++ b/inference/enterprise/stream_management/api/stream_manager_client.py @@ -0,0 +1,288 @@ +import asyncio +import json +from asyncio import StreamReader, StreamWriter +from json import JSONDecodeError +from typing import Optional, Tuple + +from inference.core import logger +from inference.enterprise.stream_management.api.entities import ( + CommandContext, + CommandResponse, + InferencePipelineStatusResponse, + ListPipelinesResponse, + PipelineInitialisationRequest, +) +from inference.enterprise.stream_management.api.errors import ( + ConnectivityError, + ProcessesManagerAuthorisationError, + ProcessesManagerClientError, + ProcessesManagerInternalError, + ProcessesManagerInvalidPayload, + ProcessesManagerNotFoundError, + ProcessesManagerOperationError, +) +from inference.enterprise.stream_management.manager.entities import ( + ERROR_TYPE_KEY, + PIPELINE_ID_KEY, + REQUEST_ID_KEY, + RESPONSE_KEY, + STATUS_KEY, + TYPE_KEY, + CommandType, + ErrorType, + OperationStatus, +) +from inference.enterprise.stream_management.manager.errors import ( + CommunicationProtocolError, + MalformedHeaderError, + MalformedPayloadError, + MessageToBigError, + TransmissionChannelClosed, +) + +BUFFER_SIZE = 16384 +HEADER_SIZE = 4 + +ERRORS_MAPPING = { + ErrorType.INTERNAL_ERROR.value: ProcessesManagerInternalError, + ErrorType.INVALID_PAYLOAD.value: ProcessesManagerInvalidPayload, + ErrorType.NOT_FOUND.value: ProcessesManagerNotFoundError, + ErrorType.OPERATION_ERROR.value: ProcessesManagerOperationError, + ErrorType.AUTHORISATION_ERROR.value: ProcessesManagerAuthorisationError, +} + + +class StreamManagerClient: + @classmethod + def init( + cls, + host: str, + port: int, + operations_timeout: Optional[float] = None, + header_size: int = HEADER_SIZE, + buffer_size: int = BUFFER_SIZE, + ) -> "StreamManagerClient": + return cls( + host=host, + port=port, + operations_timeout=operations_timeout, + header_size=header_size, + buffer_size=buffer_size, + ) + + def __init__( + self, + host: str, + port: int, + operations_timeout: Optional[float], + header_size: int, + buffer_size: int, + ): + self._host = host + self._port = port + self._operations_timeout = operations_timeout + self._header_size = header_size + self._buffer_size = buffer_size + + async def list_pipelines(self) -> ListPipelinesResponse: + command = { + TYPE_KEY: CommandType.LIST_PIPELINES, + } + response = await self._handle_command(command=command) + status = response[RESPONSE_KEY][STATUS_KEY] + context = CommandContext( + request_id=response.get(REQUEST_ID_KEY), + pipeline_id=response.get(PIPELINE_ID_KEY), + ) + pipelines = response[RESPONSE_KEY]["pipelines"] + return ListPipelinesResponse( + status=status, + context=context, + pipelines=pipelines, + ) + + async def initialise_pipeline( + self, initialisation_request: PipelineInitialisationRequest + ) -> CommandResponse: + command = initialisation_request.dict(exclude_none=True) + command[TYPE_KEY] = CommandType.INIT + response = await self._handle_command(command=command) + return build_response(response=response) + + async def terminate_pipeline(self, pipeline_id: str) -> CommandResponse: + command = { + TYPE_KEY: CommandType.TERMINATE, + PIPELINE_ID_KEY: pipeline_id, + } + response = await self._handle_command(command=command) + return build_response(response=response) + + async def pause_pipeline(self, pipeline_id: str) -> CommandResponse: + command = { + TYPE_KEY: CommandType.MUTE, + PIPELINE_ID_KEY: pipeline_id, + } + response = await self._handle_command(command=command) + return build_response(response=response) + + async def resume_pipeline(self, pipeline_id: str) -> CommandResponse: + command = { + TYPE_KEY: CommandType.RESUME, + PIPELINE_ID_KEY: pipeline_id, + } + response = await self._handle_command(command=command) + return build_response(response=response) + + async def get_status(self, pipeline_id: str) -> InferencePipelineStatusResponse: + command = { + TYPE_KEY: CommandType.STATUS, + PIPELINE_ID_KEY: pipeline_id, + } + response = await self._handle_command(command=command) + status = response[RESPONSE_KEY][STATUS_KEY] + context = CommandContext( + request_id=response.get(REQUEST_ID_KEY), + pipeline_id=response.get(PIPELINE_ID_KEY), + ) + report = response[RESPONSE_KEY]["report"] + return InferencePipelineStatusResponse( + status=status, + context=context, + report=report, + ) + + async def _handle_command(self, command: dict) -> dict: + response = await send_command( + host=self._host, + port=self._port, + command=command, + header_size=self._header_size, + buffer_size=self._buffer_size, + timeout=self._operations_timeout, + ) + if is_request_unsuccessful(response=response): + dispatch_error(error_response=response) + return response + + +async def send_command( + host: str, + port: int, + command: dict, + header_size: int, + buffer_size: int, + timeout: Optional[float] = None, +) -> dict: + try: + reader, writer = await establish_socket_connection( + host=host, port=port, timeout=timeout + ) + await send_message( + writer=writer, message=command, header_size=header_size, timeout=timeout + ) + data = await receive_message( + reader, header_size=header_size, buffer_size=buffer_size, timeout=timeout + ) + writer.close() + await writer.wait_closed() + return json.loads(data) + except JSONDecodeError as error: + raise MalformedPayloadError( + f"Could not decode response. Cause: {error}" + ) from error + except (OSError, asyncio.TimeoutError) as errors: + raise ConnectivityError( + f"Could not communicate with Process Manager" + ) from errors + + +async def establish_socket_connection( + host: str, port: int, timeout: Optional[float] = None +) -> Tuple[StreamReader, StreamWriter]: + return await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout) + + +async def send_message( + writer: StreamWriter, + message: dict, + header_size: int, + timeout: Optional[float] = None, +) -> None: + try: + body = json.dumps(message).encode("utf-8") + header = len(body).to_bytes(length=header_size, byteorder="big") + payload = header + body + writer.write(payload) + await asyncio.wait_for(writer.drain(), timeout=timeout) + except TypeError as error: + raise MalformedPayloadError(f"Could not serialise message. Details: {error}") + except OverflowError as error: + raise MessageToBigError( + f"Could not send message due to size overflow. Details: {error}" + ) + except asyncio.TimeoutError as error: + raise ConnectivityError( + f"Could not communicate with Process Manager" + ) from error + except Exception as error: + raise CommunicationProtocolError( + f"Could not send message. Cause: {error}" + ) from error + + +async def receive_message( + reader: StreamReader, + header_size: int, + buffer_size: int, + timeout: Optional[float] = None, +) -> bytes: + header = await asyncio.wait_for(reader.read(header_size), timeout=timeout) + if len(header) != header_size: + raise MalformedHeaderError("Header size missmatch") + payload_size = int.from_bytes(bytes=header, byteorder="big") + received = b"" + while len(received) < payload_size: + chunk = await asyncio.wait_for(reader.read(buffer_size), timeout=timeout) + if len(chunk) == 0: + raise TransmissionChannelClosed( + "Socket was closed to read before payload was decoded." + ) + received += chunk + return received + + +def is_request_unsuccessful(response: dict) -> bool: + return ( + response.get(RESPONSE_KEY, {}).get(STATUS_KEY, OperationStatus.FAILURE.value) + != OperationStatus.SUCCESS.value + ) + + +def dispatch_error(error_response: dict) -> None: + response_payload = error_response.get(RESPONSE_KEY, {}) + error_type = response_payload.get(ERROR_TYPE_KEY) + error_class = response_payload.get("error_class", "N/A") + error_message = response_payload.get("error_message", "N/A") + logger.error( + f"Error in ProcessesManagerClient. error_type={error_type} error_class={error_class} " + f"error_message={error_message}" + ) + if error_type in ERRORS_MAPPING: + raise ERRORS_MAPPING[error_type]( + f"Error in ProcessesManagerClient. Error type: {error_type}. Details: {error_message}" + ) + raise ProcessesManagerClientError( + f"Error in ProcessesManagerClient. Error type: {error_type}. Details: {error_message}" + ) + + +def build_response(response: dict) -> CommandResponse: + status = response[RESPONSE_KEY][STATUS_KEY] + context = CommandContext( + request_id=response.get(REQUEST_ID_KEY), + pipeline_id=response.get(PIPELINE_ID_KEY), + ) + return CommandResponse( + status=status, + context=context, + ) diff --git a/inference/enterprise/stream_management/manager/__init__.py b/inference/enterprise/stream_management/manager/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/enterprise/stream_management/manager/__pycache__/__init__.cpython-310.pyc b/inference/enterprise/stream_management/manager/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b8fcba12692763e08667b2c169fd15a9266a5d8 Binary files /dev/null and b/inference/enterprise/stream_management/manager/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/manager/__pycache__/app.cpython-310.pyc b/inference/enterprise/stream_management/manager/__pycache__/app.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6af3f479299c1b2942014aa1fb1ff0140d9e7533 Binary files /dev/null and b/inference/enterprise/stream_management/manager/__pycache__/app.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/manager/__pycache__/communication.cpython-310.pyc b/inference/enterprise/stream_management/manager/__pycache__/communication.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b382f11e6921c594e12dc956eeeb118a9666f05 Binary files /dev/null and b/inference/enterprise/stream_management/manager/__pycache__/communication.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/manager/__pycache__/entities.cpython-310.pyc b/inference/enterprise/stream_management/manager/__pycache__/entities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2ae88aad62fbaa9d72f868a04ac84d2956b84ed Binary files /dev/null and b/inference/enterprise/stream_management/manager/__pycache__/entities.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/manager/__pycache__/errors.cpython-310.pyc b/inference/enterprise/stream_management/manager/__pycache__/errors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..252110b07b52dd035ea9a12f772cb07d9f97a0bd Binary files /dev/null and b/inference/enterprise/stream_management/manager/__pycache__/errors.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/manager/__pycache__/inference_pipeline_manager.cpython-310.pyc b/inference/enterprise/stream_management/manager/__pycache__/inference_pipeline_manager.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2180f2be0546824e3f262d8323378f444ac7af8 Binary files /dev/null and b/inference/enterprise/stream_management/manager/__pycache__/inference_pipeline_manager.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/manager/__pycache__/serialisation.cpython-310.pyc b/inference/enterprise/stream_management/manager/__pycache__/serialisation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a7fe848014d558a2c1924588637782a07b73146 Binary files /dev/null and b/inference/enterprise/stream_management/manager/__pycache__/serialisation.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/manager/__pycache__/tcp_server.cpython-310.pyc b/inference/enterprise/stream_management/manager/__pycache__/tcp_server.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..903f7d8707b64a3b762fa7759cb804969990bc1d Binary files /dev/null and b/inference/enterprise/stream_management/manager/__pycache__/tcp_server.cpython-310.pyc differ diff --git a/inference/enterprise/stream_management/manager/app.py b/inference/enterprise/stream_management/manager/app.py new file mode 100644 index 0000000000000000000000000000000000000000..7c5f836ba8cfe5c52131f1a85076a81a402f6e0a --- /dev/null +++ b/inference/enterprise/stream_management/manager/app.py @@ -0,0 +1,273 @@ +import os +import signal +import socket +import sys +from functools import partial +from multiprocessing import Process, Queue +from socketserver import BaseRequestHandler, BaseServer +from types import FrameType +from typing import Any, Dict, Optional, Tuple +from uuid import uuid4 + +from inference.core import logger +from inference.enterprise.stream_management.manager.communication import ( + receive_socket_data, + send_data_trough_socket, +) +from inference.enterprise.stream_management.manager.entities import ( + PIPELINE_ID_KEY, + STATUS_KEY, + TYPE_KEY, + CommandType, + ErrorType, + OperationStatus, +) +from inference.enterprise.stream_management.manager.errors import MalformedPayloadError +from inference.enterprise.stream_management.manager.inference_pipeline_manager import ( + InferencePipelineManager, +) +from inference.enterprise.stream_management.manager.serialisation import ( + describe_error, + prepare_error_response, + prepare_response, +) +from inference.enterprise.stream_management.manager.tcp_server import RoboflowTCPServer + +PROCESSES_TABLE: Dict[str, Tuple[Process, Queue, Queue]] = {} +HEADER_SIZE = 4 +SOCKET_BUFFER_SIZE = 16384 +HOST = os.getenv("STREAM_MANAGER_HOST", "127.0.0.1") +PORT = int(os.getenv("STREAM_MANAGER_PORT", "7070")) +SOCKET_TIMEOUT = float(os.getenv("STREAM_MANAGER_SOCKET_TIMEOUT", "5.0")) + + +class InferencePipelinesManagerHandler(BaseRequestHandler): + def __init__( + self, + request: socket.socket, + client_address: Any, + server: BaseServer, + processes_table: Dict[str, Tuple[Process, Queue, Queue]], + ): + self._processes_table = processes_table # in this case it's required to set the state of class before superclass init - as it invokes handle() + super().__init__(request, client_address, server) + + def handle(self) -> None: + pipeline_id: Optional[str] = None + request_id = str(uuid4()) + try: + data = receive_socket_data( + source=self.request, + header_size=HEADER_SIZE, + buffer_size=SOCKET_BUFFER_SIZE, + ) + data[TYPE_KEY] = CommandType(data[TYPE_KEY]) + if data[TYPE_KEY] is CommandType.LIST_PIPELINES: + return self._list_pipelines(request_id=request_id) + if data[TYPE_KEY] is CommandType.INIT: + return self._initialise_pipeline(request_id=request_id, command=data) + pipeline_id = data[PIPELINE_ID_KEY] + if data[TYPE_KEY] is CommandType.TERMINATE: + self._terminate_pipeline( + request_id=request_id, pipeline_id=pipeline_id, command=data + ) + else: + response = handle_command( + processes_table=self._processes_table, + request_id=request_id, + pipeline_id=pipeline_id, + command=data, + ) + serialised_response = prepare_response( + request_id=request_id, response=response, pipeline_id=pipeline_id + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=serialised_response, + request_id=request_id, + pipeline_id=pipeline_id, + ) + except (KeyError, ValueError, MalformedPayloadError) as error: + logger.error( + f"Invalid payload in processes manager. error={error} request_id={request_id}..." + ) + payload = prepare_error_response( + request_id=request_id, + error=error, + error_type=ErrorType.INVALID_PAYLOAD, + pipeline_id=pipeline_id, + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=payload, + request_id=request_id, + pipeline_id=pipeline_id, + ) + except Exception as error: + logger.error( + f"Internal error in processes manager. error={error} request_id={request_id}..." + ) + payload = prepare_error_response( + request_id=request_id, + error=error, + error_type=ErrorType.INTERNAL_ERROR, + pipeline_id=pipeline_id, + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=payload, + request_id=request_id, + pipeline_id=pipeline_id, + ) + + def _list_pipelines(self, request_id: str) -> None: + serialised_response = prepare_response( + request_id=request_id, + response={ + "pipelines": list(self._processes_table.keys()), + STATUS_KEY: OperationStatus.SUCCESS, + }, + pipeline_id=None, + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=serialised_response, + request_id=request_id, + ) + + def _initialise_pipeline(self, request_id: str, command: dict) -> None: + pipeline_id = str(uuid4()) + command_queue = Queue() + responses_queue = Queue() + inference_pipeline_manager = InferencePipelineManager.init( + command_queue=command_queue, + responses_queue=responses_queue, + ) + inference_pipeline_manager.start() + self._processes_table[pipeline_id] = ( + inference_pipeline_manager, + command_queue, + responses_queue, + ) + command_queue.put((request_id, command)) + response = get_response_ignoring_thrash( + responses_queue=responses_queue, matching_request_id=request_id + ) + serialised_response = prepare_response( + request_id=request_id, response=response, pipeline_id=pipeline_id + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=serialised_response, + request_id=request_id, + pipeline_id=pipeline_id, + ) + + def _terminate_pipeline( + self, request_id: str, pipeline_id: str, command: dict + ) -> None: + response = handle_command( + processes_table=self._processes_table, + request_id=request_id, + pipeline_id=pipeline_id, + command=command, + ) + if response[STATUS_KEY] is OperationStatus.SUCCESS: + logger.info( + f"Joining inference pipeline. pipeline_id={pipeline_id} request_id={request_id}" + ) + join_inference_pipeline( + processes_table=self._processes_table, pipeline_id=pipeline_id + ) + logger.info( + f"Joined inference pipeline. pipeline_id={pipeline_id} request_id={request_id}" + ) + serialised_response = prepare_response( + request_id=request_id, response=response, pipeline_id=pipeline_id + ) + send_data_trough_socket( + target=self.request, + header_size=HEADER_SIZE, + data=serialised_response, + request_id=request_id, + pipeline_id=pipeline_id, + ) + + +def handle_command( + processes_table: Dict[str, Tuple[Process, Queue, Queue]], + request_id: str, + pipeline_id: str, + command: dict, +) -> dict: + if pipeline_id not in processes_table: + return describe_error(exception=None, error_type=ErrorType.NOT_FOUND) + _, command_queue, responses_queue = processes_table[pipeline_id] + command_queue.put((request_id, command)) + return get_response_ignoring_thrash( + responses_queue=responses_queue, matching_request_id=request_id + ) + + +def get_response_ignoring_thrash( + responses_queue: Queue, matching_request_id: str +) -> dict: + while True: + response = responses_queue.get() + if response[0] == matching_request_id: + return response[1] + logger.warning( + f"Dropping response for request_id={response[0]} with payload={response[1]}" + ) + + +def execute_termination( + signal_number: int, + frame: FrameType, + processes_table: Dict[str, Tuple[Process, Queue, Queue]], +) -> None: + pipeline_ids = list(processes_table.keys()) + for pipeline_id in pipeline_ids: + logger.info(f"Terminating pipeline: {pipeline_id}") + processes_table[pipeline_id][0].terminate() + logger.info(f"Pipeline: {pipeline_id} terminated.") + logger.info(f"Joining pipeline: {pipeline_id}") + processes_table[pipeline_id][0].join() + logger.info(f"Pipeline: {pipeline_id} joined.") + logger.info(f"Termination handler completed.") + sys.exit(0) + + +def join_inference_pipeline( + processes_table: Dict[str, Tuple[Process, Queue, Queue]], pipeline_id: str +) -> None: + inference_pipeline_manager, command_queue, responses_queue = processes_table[ + pipeline_id + ] + inference_pipeline_manager.join() + del processes_table[pipeline_id] + + +if __name__ == "__main__": + signal.signal( + signal.SIGINT, partial(execute_termination, processes_table=PROCESSES_TABLE) + ) + signal.signal( + signal.SIGTERM, partial(execute_termination, processes_table=PROCESSES_TABLE) + ) + with RoboflowTCPServer( + server_address=(HOST, PORT), + handler_class=partial( + InferencePipelinesManagerHandler, processes_table=PROCESSES_TABLE + ), + socket_operations_timeout=SOCKET_TIMEOUT, + ) as tcp_server: + logger.info( + f"Inference Pipeline Processes Manager is ready to accept connections at {(HOST, PORT)}" + ) + tcp_server.serve_forever() diff --git a/inference/enterprise/stream_management/manager/communication.py b/inference/enterprise/stream_management/manager/communication.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5f8884960d05b0220f45a86d70fea7b6572eb5 --- /dev/null +++ b/inference/enterprise/stream_management/manager/communication.py @@ -0,0 +1,76 @@ +import json +import socket +from typing import Optional + +from inference.core import logger +from inference.enterprise.stream_management.manager.entities import ErrorType +from inference.enterprise.stream_management.manager.errors import ( + MalformedHeaderError, + MalformedPayloadError, + TransmissionChannelClosed, +) +from inference.enterprise.stream_management.manager.serialisation import ( + prepare_error_response, +) + + +def receive_socket_data( + source: socket.socket, header_size: int, buffer_size: int +) -> dict: + header = source.recv(header_size) + if len(header) != header_size: + raise MalformedHeaderError( + f"Expected header size: {header_size}, received: {header}" + ) + payload_size = int.from_bytes(bytes=header, byteorder="big") + if payload_size <= 0: + raise MalformedHeaderError( + f"Header is indicating non positive payload size: {payload_size}" + ) + received = b"" + while len(received) < payload_size: + chunk = source.recv(buffer_size) + if len(chunk) == 0: + raise TransmissionChannelClosed( + "Socket was closed to read before payload was decoded." + ) + received += chunk + try: + return json.loads(received) + except ValueError: + raise MalformedPayloadError("Received payload that is not in a JSON format") + + +def send_data_trough_socket( + target: socket.socket, + header_size: int, + data: bytes, + request_id: str, + recover_from_overflow: bool = True, + pipeline_id: Optional[str] = None, +) -> None: + try: + data_size = len(data) + header = data_size.to_bytes(length=header_size, byteorder="big") + payload = header + data + target.sendall(payload) + except OverflowError as error: + if not recover_from_overflow: + logger.error(f"OverflowError was suppressed. {error}") + return None + error_response = prepare_error_response( + request_id=request_id, + error=error, + error_type=ErrorType.INTERNAL_ERROR, + pipeline_id=pipeline_id, + ) + send_data_trough_socket( + target=target, + header_size=header_size, + data=error_response, + request_id=request_id, + recover_from_overflow=False, + pipeline_id=pipeline_id, + ) + except Exception as error: + logger.error(f"Could not send the response through socket. Error: {error}") diff --git a/inference/enterprise/stream_management/manager/entities.py b/inference/enterprise/stream_management/manager/entities.py new file mode 100644 index 0000000000000000000000000000000000000000..cd02ee66ca7567d0e2d6cf8b40967bafb91cc63d --- /dev/null +++ b/inference/enterprise/stream_management/manager/entities.py @@ -0,0 +1,32 @@ +from enum import Enum + +STATUS_KEY = "status" +TYPE_KEY = "type" +ERROR_TYPE_KEY = "error_type" +REQUEST_ID_KEY = "request_id" +PIPELINE_ID_KEY = "pipeline_id" +COMMAND_KEY = "command" +RESPONSE_KEY = "response" +ENCODING = "utf-8" + + +class OperationStatus(str, Enum): + SUCCESS = "success" + FAILURE = "failure" + + +class ErrorType(str, Enum): + INTERNAL_ERROR = "internal_error" + INVALID_PAYLOAD = "invalid_payload" + NOT_FOUND = "not_found" + OPERATION_ERROR = "operation_error" + AUTHORISATION_ERROR = "authorisation_error" + + +class CommandType(str, Enum): + INIT = "init" + MUTE = "mute" + RESUME = "resume" + STATUS = "status" + TERMINATE = "terminate" + LIST_PIPELINES = "list_pipelines" diff --git a/inference/enterprise/stream_management/manager/errors.py b/inference/enterprise/stream_management/manager/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3df9a14029425476245fc9ebb8a8e61a7d9a29 --- /dev/null +++ b/inference/enterprise/stream_management/manager/errors.py @@ -0,0 +1,18 @@ +class CommunicationProtocolError(Exception): + pass + + +class MessageToBigError(CommunicationProtocolError): + pass + + +class MalformedHeaderError(CommunicationProtocolError): + pass + + +class TransmissionChannelClosed(CommunicationProtocolError): + pass + + +class MalformedPayloadError(CommunicationProtocolError): + pass diff --git a/inference/enterprise/stream_management/manager/inference_pipeline_manager.py b/inference/enterprise/stream_management/manager/inference_pipeline_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..36c0ba7edb2928c8cab75c170f875c838aeb9928 --- /dev/null +++ b/inference/enterprise/stream_management/manager/inference_pipeline_manager.py @@ -0,0 +1,258 @@ +import os +import signal +from dataclasses import asdict +from multiprocessing import Process, Queue +from types import FrameType +from typing import Callable, Optional, Tuple + +from inference.core import logger +from inference.core.exceptions import ( + MissingApiKeyError, + RoboflowAPINotAuthorizedError, + RoboflowAPINotNotFoundError, +) +from inference.core.interfaces.camera.entities import VideoFrame +from inference.core.interfaces.camera.exceptions import StreamOperationNotAllowedError +from inference.core.interfaces.camera.video_source import ( + BufferConsumptionStrategy, + BufferFillingStrategy, +) +from inference.core.interfaces.stream.entities import ObjectDetectionPrediction +from inference.core.interfaces.stream.inference_pipeline import InferencePipeline +from inference.core.interfaces.stream.sinks import UDPSink +from inference.core.interfaces.stream.watchdog import ( + BasePipelineWatchDog, + PipelineWatchDog, +) +from inference.enterprise.stream_management.manager.entities import ( + STATUS_KEY, + TYPE_KEY, + CommandType, + ErrorType, + OperationStatus, +) +from inference.enterprise.stream_management.manager.serialisation import describe_error + + +def ignore_signal(signal_number: int, frame: FrameType) -> None: + pid = os.getpid() + logger.info( + f"Ignoring signal {signal_number} in InferencePipelineManager in process:{pid}" + ) + + +class InferencePipelineManager(Process): + @classmethod + def init( + cls, command_queue: Queue, responses_queue: Queue + ) -> "InferencePipelineManager": + return cls(command_queue=command_queue, responses_queue=responses_queue) + + def __init__(self, command_queue: Queue, responses_queue: Queue): + super().__init__() + self._command_queue = command_queue + self._responses_queue = responses_queue + self._inference_pipeline: Optional[InferencePipeline] = None + self._watchdog: Optional[PipelineWatchDog] = None + self._stop = False + + def run(self) -> None: + signal.signal(signal.SIGINT, ignore_signal) + signal.signal(signal.SIGTERM, self._handle_termination_signal) + while not self._stop: + command: Optional[Tuple[str, dict]] = self._command_queue.get() + if command is None: + break + request_id, payload = command + self._handle_command(request_id=request_id, payload=payload) + + def _handle_command(self, request_id: str, payload: dict) -> None: + try: + logger.info(f"Processing request={request_id}...") + command_type = payload[TYPE_KEY] + if command_type is CommandType.INIT: + return self._initialise_pipeline(request_id=request_id, payload=payload) + if command_type is CommandType.TERMINATE: + return self._terminate_pipeline(request_id=request_id) + if command_type is CommandType.MUTE: + return self._mute_pipeline(request_id=request_id) + if command_type is CommandType.RESUME: + return self._resume_pipeline(request_id=request_id) + if command_type is CommandType.STATUS: + return self._get_pipeline_status(request_id=request_id) + raise NotImplementedError( + f"Command type `{command_type}` cannot be handled" + ) + except (KeyError, NotImplementedError) as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.INVALID_PAYLOAD + ) + except Exception as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.INTERNAL_ERROR + ) + + def _initialise_pipeline(self, request_id: str, payload: dict) -> None: + try: + watchdog = BasePipelineWatchDog() + sink = assembly_pipeline_sink(sink_config=payload["sink_configuration"]) + source_buffer_filling_strategy, source_buffer_consumption_strategy = ( + None, + None, + ) + if "source_buffer_filling_strategy" in payload: + source_buffer_filling_strategy = BufferFillingStrategy( + payload["source_buffer_filling_strategy"].upper() + ) + if "source_buffer_consumption_strategy" in payload: + source_buffer_consumption_strategy = BufferConsumptionStrategy( + payload["source_buffer_consumption_strategy"].upper() + ) + model_configuration = payload["model_configuration"] + if model_configuration["type"] != "object-detection": + raise NotImplementedError("Only object-detection models are supported") + self._inference_pipeline = InferencePipeline.init( + model_id=payload["model_id"], + video_reference=payload["video_reference"], + on_prediction=sink, + api_key=payload.get("api_key"), + max_fps=payload.get("max_fps"), + watchdog=watchdog, + source_buffer_filling_strategy=source_buffer_filling_strategy, + source_buffer_consumption_strategy=source_buffer_consumption_strategy, + class_agnostic_nms=model_configuration.get("class_agnostic_nms"), + confidence=model_configuration.get("confidence"), + iou_threshold=model_configuration.get("iou_threshold"), + max_candidates=model_configuration.get("max_candidates"), + max_detections=model_configuration.get("max_detections"), + active_learning_enabled=payload.get("active_learning_enabled"), + ) + self._watchdog = watchdog + self._inference_pipeline.start(use_main_thread=False) + self._responses_queue.put( + (request_id, {STATUS_KEY: OperationStatus.SUCCESS}) + ) + logger.info(f"Pipeline initialised. request_id={request_id}...") + except (MissingApiKeyError, KeyError, NotImplementedError) as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.INVALID_PAYLOAD + ) + except RoboflowAPINotAuthorizedError as error: + self._handle_error( + request_id=request_id, + error=error, + error_type=ErrorType.AUTHORISATION_ERROR, + ) + except RoboflowAPINotNotFoundError as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.NOT_FOUND + ) + + def _terminate_pipeline(self, request_id: str) -> None: + if self._inference_pipeline is None: + self._responses_queue.put( + (request_id, {STATUS_KEY: OperationStatus.SUCCESS}) + ) + self._stop = True + return None + try: + self._execute_termination() + logger.info(f"Pipeline terminated. request_id={request_id}...") + self._responses_queue.put( + (request_id, {STATUS_KEY: OperationStatus.SUCCESS}) + ) + except StreamOperationNotAllowedError as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR + ) + + def _handle_termination_signal(self, signal_number: int, frame: FrameType) -> None: + try: + pid = os.getpid() + logger.info(f"Terminating pipeline in process:{pid}...") + if self._inference_pipeline is not None: + self._execute_termination() + self._command_queue.put(None) + logger.info(f"Termination successful in process:{pid}...") + except Exception as error: + logger.warning(f"Could not terminate pipeline gracefully. Error: {error}") + + def _execute_termination(self) -> None: + self._inference_pipeline.terminate() + self._inference_pipeline.join() + self._stop = True + + def _mute_pipeline(self, request_id: str) -> None: + if self._inference_pipeline is None: + return self._handle_error( + request_id=request_id, error_type=ErrorType.OPERATION_ERROR + ) + try: + self._inference_pipeline.mute_stream() + logger.info(f"Pipeline muted. request_id={request_id}...") + self._responses_queue.put( + (request_id, {STATUS_KEY: OperationStatus.SUCCESS}) + ) + except StreamOperationNotAllowedError as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR + ) + + def _resume_pipeline(self, request_id: str) -> None: + if self._inference_pipeline is None: + return self._handle_error( + request_id=request_id, error_type=ErrorType.OPERATION_ERROR + ) + try: + self._inference_pipeline.resume_stream() + logger.info(f"Pipeline resumed. request_id={request_id}...") + self._responses_queue.put( + (request_id, {STATUS_KEY: OperationStatus.SUCCESS}) + ) + except StreamOperationNotAllowedError as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR + ) + + def _get_pipeline_status(self, request_id: str) -> None: + if self._watchdog is None: + return self._handle_error( + request_id=request_id, error_type=ErrorType.OPERATION_ERROR + ) + try: + report = self._watchdog.get_report() + if report is None: + return self._handle_error( + request_id=request_id, error_type=ErrorType.OPERATION_ERROR + ) + response_payload = { + STATUS_KEY: OperationStatus.SUCCESS, + "report": asdict(report), + } + self._responses_queue.put((request_id, response_payload)) + logger.info(f"Pipeline status returned. request_id={request_id}...") + except StreamOperationNotAllowedError as error: + self._handle_error( + request_id=request_id, error=error, error_type=ErrorType.OPERATION_ERROR + ) + + def _handle_error( + self, + request_id: str, + error: Optional[Exception] = None, + error_type: ErrorType = ErrorType.INTERNAL_ERROR, + ): + logger.error( + f"Could not handle Command. request_id={request_id}, error={error}, error_type={error_type}" + ) + response_payload = describe_error(error, error_type=error_type) + self._responses_queue.put((request_id, response_payload)) + + +def assembly_pipeline_sink( + sink_config: dict, +) -> Callable[[ObjectDetectionPrediction, VideoFrame], None]: + if sink_config["type"] != "udp_sink": + raise NotImplementedError("Only `udp_socket` sink type is supported") + sink = UDPSink.init(ip_address=sink_config["host"], port=sink_config["port"]) + return sink.send_predictions diff --git a/inference/enterprise/stream_management/manager/serialisation.py b/inference/enterprise/stream_management/manager/serialisation.py new file mode 100644 index 0000000000000000000000000000000000000000..69a2c87870c21baece5015d486a23c6f4bec7ca3 --- /dev/null +++ b/inference/enterprise/stream_management/manager/serialisation.py @@ -0,0 +1,60 @@ +import json +from datetime import date, datetime +from enum import Enum +from typing import Any, Optional + +from inference.enterprise.stream_management.manager.entities import ( + ENCODING, + ERROR_TYPE_KEY, + PIPELINE_ID_KEY, + REQUEST_ID_KEY, + RESPONSE_KEY, + STATUS_KEY, + ErrorType, + OperationStatus, +) + + +def serialise_to_json(obj: Any) -> Any: + if isinstance(obj, (datetime, date)): + return obj.isoformat() + if issubclass(type(obj), Enum): + return obj.value + raise TypeError(f"Type {type(obj)} not serializable") + + +def describe_error( + exception: Optional[Exception] = None, + error_type: ErrorType = ErrorType.INTERNAL_ERROR, +) -> dict: + payload = { + STATUS_KEY: OperationStatus.FAILURE, + ERROR_TYPE_KEY: error_type, + } + if exception is not None: + payload["error_class"] = exception.__class__.__name__ + payload["error_message"] = str(exception) + return payload + + +def prepare_error_response( + request_id: str, error: Exception, error_type: ErrorType, pipeline_id: Optional[str] +) -> bytes: + error_description = describe_error(exception=error, error_type=error_type) + return prepare_response( + request_id=request_id, response=error_description, pipeline_id=pipeline_id + ) + + +def prepare_response( + request_id: str, response: dict, pipeline_id: Optional[str] +) -> bytes: + payload = json.dumps( + { + REQUEST_ID_KEY: request_id, + RESPONSE_KEY: response, + PIPELINE_ID_KEY: pipeline_id, + }, + default=serialise_to_json, + ) + return payload.encode(ENCODING) diff --git a/inference/enterprise/stream_management/manager/tcp_server.py b/inference/enterprise/stream_management/manager/tcp_server.py new file mode 100644 index 0000000000000000000000000000000000000000..5f951b1b6582707a3f70d5e0259fca8823267f81 --- /dev/null +++ b/inference/enterprise/stream_management/manager/tcp_server.py @@ -0,0 +1,19 @@ +import socket +from socketserver import BaseRequestHandler, TCPServer +from typing import Any, Optional, Tuple, Type + + +class RoboflowTCPServer(TCPServer): + def __init__( + self, + server_address: Tuple[str, int], + handler_class: Type[BaseRequestHandler], + socket_operations_timeout: Optional[float] = None, + ): + TCPServer.__init__(self, server_address, handler_class) + self._socket_operations_timeout = socket_operations_timeout + + def get_request(self) -> Tuple[socket.socket, Any]: + connection, address = self.socket.accept() + connection.settimeout(self._socket_operations_timeout) + return connection, address diff --git a/inference/enterprise/workflows/__init__.py b/inference/enterprise/workflows/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/enterprise/workflows/__pycache__/__init__.cpython-310.pyc b/inference/enterprise/workflows/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a317bcf2984afbe31424b1fb51554673a89c578 Binary files /dev/null and b/inference/enterprise/workflows/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/__pycache__/constants.cpython-310.pyc b/inference/enterprise/workflows/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f10c2cf33a9742260f7b108278dd59e8e1d9d3c Binary files /dev/null and b/inference/enterprise/workflows/__pycache__/constants.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/__pycache__/errors.cpython-310.pyc b/inference/enterprise/workflows/__pycache__/errors.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54408a6b9d7bfdda70238a6cfe2ef1ea22ef14ac Binary files /dev/null and b/inference/enterprise/workflows/__pycache__/errors.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/__init__.py b/inference/enterprise/workflows/complier/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/enterprise/workflows/complier/__pycache__/__init__.cpython-310.pyc b/inference/enterprise/workflows/complier/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2962d3a51b6e1946e01ea1c0537d5c6643decd1d Binary files /dev/null and b/inference/enterprise/workflows/complier/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/__pycache__/core.cpython-310.pyc b/inference/enterprise/workflows/complier/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa0121052288f1cf74ab5c5f36c4ab7f518450c8 Binary files /dev/null and b/inference/enterprise/workflows/complier/__pycache__/core.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/__pycache__/entities.cpython-310.pyc b/inference/enterprise/workflows/complier/__pycache__/entities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..334bb5fc0b0a8727f430db891e3868b8f3e87463 Binary files /dev/null and b/inference/enterprise/workflows/complier/__pycache__/entities.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/__pycache__/execution_engine.cpython-310.pyc b/inference/enterprise/workflows/complier/__pycache__/execution_engine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2eea139ca7d6a160f3d5ce0d5e5f8a7db1b9995 Binary files /dev/null and b/inference/enterprise/workflows/complier/__pycache__/execution_engine.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/__pycache__/flow_coordinator.cpython-310.pyc b/inference/enterprise/workflows/complier/__pycache__/flow_coordinator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb9123c48df93bac9745eced1bcfc2eca34067ec Binary files /dev/null and b/inference/enterprise/workflows/complier/__pycache__/flow_coordinator.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/__pycache__/graph_parser.cpython-310.pyc b/inference/enterprise/workflows/complier/__pycache__/graph_parser.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a36790184651d8cd6c91036a8d9d2729154ca33e Binary files /dev/null and b/inference/enterprise/workflows/complier/__pycache__/graph_parser.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/__pycache__/runtime_input_validator.cpython-310.pyc b/inference/enterprise/workflows/complier/__pycache__/runtime_input_validator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a214bb0dc47f13329f558dc929abfeddce2591e9 Binary files /dev/null and b/inference/enterprise/workflows/complier/__pycache__/runtime_input_validator.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/__pycache__/utils.cpython-310.pyc b/inference/enterprise/workflows/complier/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc5dd9fcd177995f8b8f31b34af9d66eb11a340e Binary files /dev/null and b/inference/enterprise/workflows/complier/__pycache__/utils.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/__pycache__/validator.cpython-310.pyc b/inference/enterprise/workflows/complier/__pycache__/validator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a704ae726ff6066b3f93e25ec5ebc4db3e315de Binary files /dev/null and b/inference/enterprise/workflows/complier/__pycache__/validator.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/core.py b/inference/enterprise/workflows/complier/core.py new file mode 100644 index 0000000000000000000000000000000000000000..913280fc92159469c85bd54ecb01d281b17dc37d --- /dev/null +++ b/inference/enterprise/workflows/complier/core.py @@ -0,0 +1,95 @@ +import asyncio +from asyncio import AbstractEventLoop +from typing import Any, Dict, Optional + +from fastapi import BackgroundTasks + +from inference.core.cache import cache +from inference.core.env import API_KEY, MAX_ACTIVE_MODELS +from inference.core.managers.base import ModelManager +from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache +from inference.core.registries.roboflow import RoboflowModelRegistry +from inference.enterprise.workflows.complier.entities import StepExecutionMode +from inference.enterprise.workflows.complier.execution_engine import execute_graph +from inference.enterprise.workflows.complier.graph_parser import prepare_execution_graph +from inference.enterprise.workflows.complier.steps_executors.active_learning_middlewares import ( + WorkflowsActiveLearningMiddleware, +) +from inference.enterprise.workflows.complier.validator import ( + validate_workflow_specification, +) +from inference.enterprise.workflows.entities.workflows_specification import ( + WorkflowSpecification, +) +from inference.enterprise.workflows.errors import InvalidSpecificationVersionError +from inference.models.utils import ROBOFLOW_MODEL_TYPES + + +def compile_and_execute( + workflow_specification: dict, + runtime_parameters: Dict[str, Any], + api_key: Optional[str] = None, + model_manager: Optional[ModelManager] = None, + loop: Optional[AbstractEventLoop] = None, + active_learning_middleware: Optional[WorkflowsActiveLearningMiddleware] = None, + background_tasks: Optional[BackgroundTasks] = None, + max_concurrent_steps: int = 1, + step_execution_mode: StepExecutionMode = StepExecutionMode.LOCAL, +) -> dict: + if loop is None: + loop = asyncio.get_event_loop() + return loop.run_until_complete( + compile_and_execute_async( + workflow_specification=workflow_specification, + runtime_parameters=runtime_parameters, + model_manager=model_manager, + api_key=api_key, + active_learning_middleware=active_learning_middleware, + background_tasks=background_tasks, + max_concurrent_steps=max_concurrent_steps, + step_execution_mode=step_execution_mode, + ) + ) + + +async def compile_and_execute_async( + workflow_specification: dict, + runtime_parameters: Dict[str, Any], + model_manager: Optional[ModelManager] = None, + api_key: Optional[str] = None, + active_learning_middleware: Optional[WorkflowsActiveLearningMiddleware] = None, + background_tasks: Optional[BackgroundTasks] = None, + max_concurrent_steps: int = 1, + step_execution_mode: StepExecutionMode = StepExecutionMode.LOCAL, +) -> dict: + if api_key is None: + api_key = API_KEY + if model_manager is None: + model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES) + model_manager = ModelManager(model_registry=model_registry) + model_manager = WithFixedSizeCache(model_manager, max_size=MAX_ACTIVE_MODELS) + if active_learning_middleware is None: + active_learning_middleware = WorkflowsActiveLearningMiddleware(cache=cache) + parsed_workflow_specification = WorkflowSpecification.parse_obj( + workflow_specification + ) + if parsed_workflow_specification.specification.version != "1.0": + raise InvalidSpecificationVersionError( + f"Only version 1.0 of workflow specification is supported." + ) + validate_workflow_specification( + workflow_specification=parsed_workflow_specification.specification + ) + execution_graph = prepare_execution_graph( + workflow_specification=parsed_workflow_specification.specification + ) + return await execute_graph( + execution_graph=execution_graph, + runtime_parameters=runtime_parameters, + model_manager=model_manager, + active_learning_middleware=active_learning_middleware, + background_tasks=background_tasks, + api_key=api_key, + max_concurrent_steps=max_concurrent_steps, + step_execution_mode=step_execution_mode, + ) diff --git a/inference/enterprise/workflows/complier/entities.py b/inference/enterprise/workflows/complier/entities.py new file mode 100644 index 0000000000000000000000000000000000000000..79f8988dd18d52b83575c1bbdaa9d87f0066a04f --- /dev/null +++ b/inference/enterprise/workflows/complier/entities.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class StepExecutionMode(Enum): + LOCAL = "local" + REMOTE = "remote" diff --git a/inference/enterprise/workflows/complier/execution_engine.py b/inference/enterprise/workflows/complier/execution_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..b90c8f127b1167c751e7dc88a6e62e462b1e6a10 --- /dev/null +++ b/inference/enterprise/workflows/complier/execution_engine.py @@ -0,0 +1,315 @@ +import asyncio +from datetime import datetime +from typing import Any, Dict, List, Optional, Set + +import networkx as nx +from fastapi import BackgroundTasks +from networkx import DiGraph + +from inference.core import logger +from inference.core.managers.base import ModelManager +from inference.enterprise.workflows.complier.entities import StepExecutionMode +from inference.enterprise.workflows.complier.flow_coordinator import ( + ParallelStepExecutionCoordinator, + SerialExecutionCoordinator, +) +from inference.enterprise.workflows.complier.runtime_input_validator import ( + prepare_runtime_parameters, +) +from inference.enterprise.workflows.complier.steps_executors.active_learning_middlewares import ( + WorkflowsActiveLearningMiddleware, +) +from inference.enterprise.workflows.complier.steps_executors.auxiliary import ( + run_active_learning_data_collector, + run_condition_step, + run_crop_step, + run_detection_filter, + run_detection_offset_step, + run_detections_consensus_step, + run_static_crop_step, +) +from inference.enterprise.workflows.complier.steps_executors.constants import ( + PARENT_COORDINATES_SUFFIX, +) +from inference.enterprise.workflows.complier.steps_executors.models import ( + run_clip_comparison_step, + run_ocr_model_step, + run_roboflow_model_step, + run_yolo_world_model_step, +) +from inference.enterprise.workflows.complier.steps_executors.types import OutputsLookup +from inference.enterprise.workflows.complier.steps_executors.utils import make_batches +from inference.enterprise.workflows.complier.utils import ( + get_nodes_of_specific_kind, + get_step_selector_from_its_output, + is_condition_step, +) +from inference.enterprise.workflows.constants import OUTPUT_NODE_KIND +from inference.enterprise.workflows.entities.outputs import CoordinatesSystem +from inference.enterprise.workflows.entities.validators import get_last_selector_chunk +from inference.enterprise.workflows.errors import ( + ExecutionEngineError, + WorkflowsCompilerRuntimeError, +) + +STEP_TYPE2EXECUTOR_MAPPING = { + "ClassificationModel": run_roboflow_model_step, + "MultiLabelClassificationModel": run_roboflow_model_step, + "ObjectDetectionModel": run_roboflow_model_step, + "KeypointsDetectionModel": run_roboflow_model_step, + "InstanceSegmentationModel": run_roboflow_model_step, + "OCRModel": run_ocr_model_step, + "Crop": run_crop_step, + "Condition": run_condition_step, + "DetectionFilter": run_detection_filter, + "DetectionOffset": run_detection_offset_step, + "AbsoluteStaticCrop": run_static_crop_step, + "RelativeStaticCrop": run_static_crop_step, + "ClipComparison": run_clip_comparison_step, + "DetectionsConsensus": run_detections_consensus_step, + "ActiveLearningDataCollector": run_active_learning_data_collector, + "YoloWorld": run_yolo_world_model_step, +} + + +async def execute_graph( + execution_graph: DiGraph, + runtime_parameters: Dict[str, Any], + model_manager: ModelManager, + active_learning_middleware: WorkflowsActiveLearningMiddleware, + background_tasks: Optional[BackgroundTasks] = None, + api_key: Optional[str] = None, + max_concurrent_steps: int = 1, + step_execution_mode: StepExecutionMode = StepExecutionMode.LOCAL, +) -> dict: + runtime_parameters = prepare_runtime_parameters( + execution_graph=execution_graph, + runtime_parameters=runtime_parameters, + ) + outputs_lookup = {} + steps_to_discard = set() + if max_concurrent_steps > 1: + execution_coordinator = ParallelStepExecutionCoordinator.init( + execution_graph=execution_graph + ) + else: + execution_coordinator = SerialExecutionCoordinator.init( + execution_graph=execution_graph + ) + while True: + next_steps = execution_coordinator.get_steps_to_execute_next( + steps_to_discard=steps_to_discard + ) + if next_steps is None: + break + steps_to_discard = await execute_steps( + steps=next_steps, + max_concurrent_steps=max_concurrent_steps, + execution_graph=execution_graph, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + model_manager=model_manager, + api_key=api_key, + step_execution_mode=step_execution_mode, + active_learning_middleware=active_learning_middleware, + background_tasks=background_tasks, + ) + return construct_response( + execution_graph=execution_graph, outputs_lookup=outputs_lookup + ) + + +async def execute_steps( + steps: List[str], + max_concurrent_steps: int, + execution_graph: DiGraph, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, + active_learning_middleware: WorkflowsActiveLearningMiddleware, + background_tasks: Optional[BackgroundTasks], +) -> Set[str]: + """outputs_lookup is mutated while execution, only independent steps may be run together""" + logger.info(f"Executing steps: {steps}. Execution mode: {step_execution_mode}") + nodes_to_discard = set() + steps_batches = list(make_batches(iterable=steps, batch_size=max_concurrent_steps)) + for steps_batch in steps_batches: + logger.info(f"Steps batch: {steps_batch}") + coroutines = [ + safe_execute_step( + step=step, + execution_graph=execution_graph, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + model_manager=model_manager, + api_key=api_key, + step_execution_mode=step_execution_mode, + active_learning_middleware=active_learning_middleware, + background_tasks=background_tasks, + ) + for step in steps_batch + ] + results = await asyncio.gather(*coroutines) + for result in results: + nodes_to_discard.update(result) + return nodes_to_discard + + +async def safe_execute_step( + step: str, + execution_graph: DiGraph, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, + active_learning_middleware: WorkflowsActiveLearningMiddleware, + background_tasks: Optional[BackgroundTasks], +) -> Set[str]: + try: + return await execute_step( + step=step, + execution_graph=execution_graph, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + model_manager=model_manager, + api_key=api_key, + step_execution_mode=step_execution_mode, + active_learning_middleware=active_learning_middleware, + background_tasks=background_tasks, + ) + except Exception as error: + raise ExecutionEngineError( + f"Error during execution of step: {step}. " + f"Type of error: {type(error).__name__}. " + f"Cause: {error}" + ) from error + + +async def execute_step( + step: str, + execution_graph: DiGraph, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, + active_learning_middleware: WorkflowsActiveLearningMiddleware, + background_tasks: Optional[BackgroundTasks], +) -> Set[str]: + logger.info(f"started execution of: {step} - {datetime.now().isoformat()}") + nodes_to_discard = set() + step_definition = execution_graph.nodes[step]["definition"] + executor = STEP_TYPE2EXECUTOR_MAPPING[step_definition.type] + additional_args = {} + if step_definition.type == "ActiveLearningDataCollector": + additional_args["active_learning_middleware"] = active_learning_middleware + additional_args["background_tasks"] = background_tasks + next_step, outputs_lookup = await executor( + step=step_definition, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + model_manager=model_manager, + api_key=api_key, + step_execution_mode=step_execution_mode, + **additional_args, + ) + if is_condition_step(execution_graph=execution_graph, node=step): + if execution_graph.nodes[step]["definition"].step_if_true == next_step: + nodes_to_discard = get_all_nodes_in_execution_path( + execution_graph=execution_graph, + source=execution_graph.nodes[step]["definition"].step_if_false, + ) + else: + nodes_to_discard = get_all_nodes_in_execution_path( + execution_graph=execution_graph, + source=execution_graph.nodes[step]["definition"].step_if_true, + ) + logger.info(f"finished execution of: {step} - {datetime.now().isoformat()}") + return nodes_to_discard + + +def get_all_nodes_in_execution_path( + execution_graph: DiGraph, + source: str, +) -> Set[str]: + nodes = set(nx.descendants(execution_graph, source)) + nodes.add(source) + return nodes + + +def construct_response( + execution_graph: nx.DiGraph, + outputs_lookup: Dict[str, Any], +) -> Dict[str, Any]: + output_nodes = get_nodes_of_specific_kind( + execution_graph=execution_graph, kind=OUTPUT_NODE_KIND + ) + result = {} + for node in output_nodes: + node_definition = execution_graph.nodes[node]["definition"] + fallback_selector = None + node_selector = node_definition.selector + if node_definition.coordinates_system is CoordinatesSystem.PARENT: + fallback_selector = node_selector + node_selector = f"{node_selector}{PARENT_COORDINATES_SUFFIX}" + step_selector = get_step_selector_from_its_output( + step_output_selector=node_selector + ) + step_field = get_last_selector_chunk(selector=node_selector) + fallback_step_field = ( + None + if fallback_selector is None + else get_last_selector_chunk(selector=fallback_selector) + ) + step_result = outputs_lookup.get(step_selector) + if step_result is not None: + if issubclass(type(step_result), list): + step_result = extract_step_result_from_list( + result=step_result, + step_field=step_field, + fallback_step_field=fallback_step_field, + step_selector=step_selector, + ) + else: + step_result = extract_step_result_from_dict( + result=step_result, + step_field=step_field, + fallback_step_field=fallback_step_field, + step_selector=step_selector, + ) + result[execution_graph.nodes[node]["definition"].name] = step_result + return result + + +def extract_step_result_from_list( + result: List[Dict[str, Any]], + step_field: str, + fallback_step_field: Optional[str], + step_selector: str, +) -> List[Any]: + return [ + extract_step_result_from_dict( + result=element, + step_field=step_field, + fallback_step_field=fallback_step_field, + step_selector=step_selector, + ) + for element in result + ] + + +def extract_step_result_from_dict( + result: Dict[str, Any], + step_field: str, + fallback_step_field: Optional[str], + step_selector: str, +) -> Any: + step_result = result.get(step_field, result.get(fallback_step_field)) + if step_result is None: + raise WorkflowsCompilerRuntimeError( + f"Cannot find neither field {step_field} nor {fallback_step_field} in result of step {step_selector}" + ) + return step_result diff --git a/inference/enterprise/workflows/complier/flow_coordinator.py b/inference/enterprise/workflows/complier/flow_coordinator.py new file mode 100644 index 0000000000000000000000000000000000000000..253085f383d4fca7ec456732c851b765e472e2ae --- /dev/null +++ b/inference/enterprise/workflows/complier/flow_coordinator.py @@ -0,0 +1,163 @@ +import abc +from collections import defaultdict +from queue import Queue +from typing import List, Optional, Set + +import networkx as nx + +from inference.enterprise.workflows.complier.utils import get_nodes_of_specific_kind +from inference.enterprise.workflows.constants import STEP_NODE_KIND + + +class StepExecutionCoordinator(metaclass=abc.ABCMeta): + + @classmethod + @abc.abstractmethod + def init(cls, execution_graph: nx.DiGraph) -> "StepExecutionCoordinator": + pass + + @abc.abstractmethod + def get_steps_to_execute_next( + self, steps_to_discard: Set[str] + ) -> Optional[List[str]]: + pass + + +class SerialExecutionCoordinator(StepExecutionCoordinator): + + @classmethod + def init(cls, execution_graph: nx.DiGraph) -> "StepExecutionCoordinator": + return cls(execution_graph=execution_graph) + + def __init__(self, execution_graph: nx.DiGraph): + self._execution_graph = execution_graph.copy() + self._discarded_steps: Set[str] = set() + self.__order: Optional[List[str]] = None + self.__step_pointer = 0 + + def get_steps_to_execute_next( + self, steps_to_discard: Set[str] + ) -> Optional[List[str]]: + if self.__order is None: + self.__establish_execution_order() + self._discarded_steps.update(steps_to_discard) + next_step = None + while self.__step_pointer < len(self.__order): + candidate_step = self.__order[self.__step_pointer] + self.__step_pointer += 1 + if candidate_step in self._discarded_steps: + continue + return [candidate_step] + return next_step + + def __establish_execution_order(self) -> None: + step_nodes = get_nodes_of_specific_kind( + execution_graph=self._execution_graph, kind=STEP_NODE_KIND + ) + self.__order = [ + n for n in nx.topological_sort(self._execution_graph) if n in step_nodes + ] + self.__step_pointer = 0 + + +class ParallelStepExecutionCoordinator(StepExecutionCoordinator): + + @classmethod + def init(cls, execution_graph: nx.DiGraph) -> "StepExecutionCoordinator": + return cls(execution_graph=execution_graph) + + def __init__(self, execution_graph: nx.DiGraph): + self._execution_graph = execution_graph.copy() + self._discarded_steps: Set[str] = set() + self.__execution_order: Optional[List[List[str]]] = None + self.__execution_pointer = 0 + + def get_steps_to_execute_next( + self, steps_to_discard: Set[str] + ) -> Optional[List[str]]: + if self.__execution_order is None: + self.__execution_order = establish_execution_order( + execution_graph=self._execution_graph + ) + self.__execution_pointer = 0 + self._discarded_steps.update(steps_to_discard) + next_step = None + while self.__execution_pointer < len(self.__execution_order): + candidate_steps = [ + e + for e in self.__execution_order[self.__execution_pointer] + if e not in self._discarded_steps + ] + self.__execution_pointer += 1 + if len(candidate_steps) == 0: + continue + return candidate_steps + return next_step + + +def establish_execution_order( + execution_graph: nx.DiGraph, +) -> List[List[str]]: + steps_flow_graph = construct_steps_flow_graph(execution_graph=execution_graph) + steps_flow_graph = assign_max_distances_from_start( + steps_flow_graph=steps_flow_graph + ) + return get_groups_execution_order(steps_flow_graph=steps_flow_graph) + + +def construct_steps_flow_graph(execution_graph: nx.DiGraph) -> nx.DiGraph: + steps_flow_graph = nx.DiGraph() + steps_flow_graph.add_node("start") + steps_flow_graph.add_node("end") + step_nodes = get_nodes_of_specific_kind( + execution_graph=execution_graph, kind=STEP_NODE_KIND + ) + for step_node in step_nodes: + for predecessor in execution_graph.predecessors(step_node): + start_node = predecessor if predecessor in step_nodes else "start" + steps_flow_graph.add_edge(start_node, step_node) + for successor in execution_graph.successors(step_node): + end_node = successor if successor in step_nodes else "end" + steps_flow_graph.add_edge(step_node, end_node) + return steps_flow_graph + + +def assign_max_distances_from_start(steps_flow_graph: nx.DiGraph) -> nx.DiGraph: + nodes_to_consider = Queue() + nodes_to_consider.put("start") + while nodes_to_consider.qsize() > 0: + node_to_consider = nodes_to_consider.get() + predecessors = list(steps_flow_graph.predecessors(node_to_consider)) + if not all( + steps_flow_graph.nodes[p].get("distance") is not None for p in predecessors + ): + # we can proceed to establish distance, only if all parents have distances established + continue + if len(predecessors) == 0: + distance_from_start = 0 + else: + distance_from_start = ( + max(steps_flow_graph.nodes[p]["distance"] for p in predecessors) + 1 + ) + steps_flow_graph.nodes[node_to_consider]["distance"] = distance_from_start + for neighbour in steps_flow_graph.successors(node_to_consider): + nodes_to_consider.put(neighbour) + return steps_flow_graph + + +def get_groups_execution_order(steps_flow_graph: nx.DiGraph) -> List[List[str]]: + distance2steps = defaultdict(list) + for node_name, node_data in steps_flow_graph.nodes(data=True): + if node_name in {"start", "end"}: + continue + distance2steps[node_data["distance"]].append(node_name) + sorted_distances = sorted(list(distance2steps.keys())) + return [distance2steps[d] for d in sorted_distances] + + +def get_next_steps_to_execute( + execution_order: List[List[str]], + execution_pointer: int, + discarded_steps: Set[str], +) -> List[str]: + return [e for e in execution_order[execution_pointer] if e not in discarded_steps] diff --git a/inference/enterprise/workflows/complier/graph_parser.py b/inference/enterprise/workflows/complier/graph_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..f3bfead98551f14f3d984e37bf890f8851e8328c --- /dev/null +++ b/inference/enterprise/workflows/complier/graph_parser.py @@ -0,0 +1,447 @@ +from collections import defaultdict +from typing import Any, Dict, List, Optional, Set, Tuple + +import networkx as nx +from networkx import DiGraph + +from inference.enterprise.workflows.complier.utils import ( + construct_input_selector, + construct_output_name, + construct_step_selector, + get_nodes_of_specific_kind, + get_step_input_selectors, + get_step_selector_from_its_output, + is_condition_step, + is_step_output_selector, +) +from inference.enterprise.workflows.constants import ( + INPUT_NODE_KIND, + OUTPUT_NODE_KIND, + STEP_NODE_KIND, +) +from inference.enterprise.workflows.entities.outputs import JsonField +from inference.enterprise.workflows.entities.steps import StepInterface +from inference.enterprise.workflows.entities.validators import is_selector +from inference.enterprise.workflows.entities.workflows_specification import ( + InputType, + StepType, + WorkflowSpecificationV1, +) +from inference.enterprise.workflows.errors import ( + AmbiguousPathDetected, + NodesNotReachingOutputError, + NotAcyclicGraphError, + SelectorToUndefinedNodeError, +) + + +def prepare_execution_graph(workflow_specification: WorkflowSpecificationV1) -> DiGraph: + execution_graph = construct_graph(workflow_specification=workflow_specification) + if not nx.is_directed_acyclic_graph(execution_graph): + raise NotAcyclicGraphError(f"Detected cycle in execution graph.") + verify_each_node_reach_at_least_one_output(execution_graph=execution_graph) + verify_each_node_step_has_parent_in_the_same_branch(execution_graph=execution_graph) + verify_that_steps_are_connected_with_compatible_inputs( + execution_graph=execution_graph + ) + return execution_graph + + +def construct_graph(workflow_specification: WorkflowSpecificationV1) -> DiGraph: + execution_graph = nx.DiGraph() + execution_graph = add_input_nodes_for_graph( + inputs=workflow_specification.inputs, execution_graph=execution_graph + ) + execution_graph = add_steps_nodes_for_graph( + steps=workflow_specification.steps, execution_graph=execution_graph + ) + execution_graph = add_output_nodes_for_graph( + outputs=workflow_specification.outputs, execution_graph=execution_graph + ) + execution_graph = add_steps_edges( + workflow_specification=workflow_specification, execution_graph=execution_graph + ) + return add_edges_for_outputs( + workflow_specification=workflow_specification, execution_graph=execution_graph + ) + + +def add_input_nodes_for_graph( + inputs: List[InputType], + execution_graph: DiGraph, +) -> DiGraph: + for input_spec in inputs: + input_selector = construct_input_selector(input_name=input_spec.name) + execution_graph.add_node( + input_selector, + kind=INPUT_NODE_KIND, + definition=input_spec, + ) + return execution_graph + + +def add_steps_nodes_for_graph( + steps: List[StepType], + execution_graph: DiGraph, +) -> DiGraph: + for step in steps: + step_selector = construct_step_selector(step_name=step.name) + execution_graph.add_node( + step_selector, + kind=STEP_NODE_KIND, + definition=step, + ) + return execution_graph + + +def add_output_nodes_for_graph( + outputs: List[JsonField], + execution_graph: DiGraph, +) -> DiGraph: + for output_spec in outputs: + execution_graph.add_node( + construct_output_name(name=output_spec.name), + kind=OUTPUT_NODE_KIND, + definition=output_spec, + ) + return execution_graph + + +def add_steps_edges( + workflow_specification: WorkflowSpecificationV1, + execution_graph: DiGraph, +) -> DiGraph: + for step in workflow_specification.steps: + input_selectors = get_step_input_selectors(step=step) + step_selector = construct_step_selector(step_name=step.name) + execution_graph = add_edges_for_step_inputs( + execution_graph=execution_graph, + input_selectors=input_selectors, + step_selector=step_selector, + ) + if step.type == "Condition": + verify_edge_is_created_between_existing_nodes( + execution_graph=execution_graph, + start=step_selector, + end=step.step_if_true, + ) + execution_graph.add_edge(step_selector, step.step_if_true) + verify_edge_is_created_between_existing_nodes( + execution_graph=execution_graph, + start=step_selector, + end=step.step_if_false, + ) + execution_graph.add_edge(step_selector, step.step_if_false) + return execution_graph + + +def add_edges_for_step_inputs( + execution_graph: DiGraph, + input_selectors: Set[str], + step_selector: str, +) -> DiGraph: + for input_selector in input_selectors: + if is_step_output_selector(selector_or_value=input_selector): + input_selector = get_step_selector_from_its_output( + step_output_selector=input_selector + ) + verify_edge_is_created_between_existing_nodes( + execution_graph=execution_graph, + start=input_selector, + end=step_selector, + ) + execution_graph.add_edge(input_selector, step_selector) + return execution_graph + + +def add_edges_for_outputs( + workflow_specification: WorkflowSpecificationV1, + execution_graph: DiGraph, +) -> DiGraph: + for output in workflow_specification.outputs: + output_selector = output.selector + if is_step_output_selector(selector_or_value=output_selector): + output_selector = get_step_selector_from_its_output( + step_output_selector=output_selector + ) + output_name = construct_output_name(name=output.name) + verify_edge_is_created_between_existing_nodes( + execution_graph=execution_graph, + start=output_selector, + end=output_name, + ) + execution_graph.add_edge(output_selector, output_name) + return execution_graph + + +def verify_edge_is_created_between_existing_nodes( + execution_graph: DiGraph, + start: str, + end: str, +) -> None: + if not execution_graph.has_node(start): + raise SelectorToUndefinedNodeError( + f"Graph definition contains selector {start} that points to not defined element." + ) + if not execution_graph.has_node(end): + raise SelectorToUndefinedNodeError( + f"Graph definition contains selector {end} that points to not defined element." + ) + + +def verify_each_node_reach_at_least_one_output( + execution_graph: DiGraph, +) -> None: + all_nodes = set(execution_graph.nodes()) + output_nodes = get_nodes_of_specific_kind( + execution_graph=execution_graph, kind=OUTPUT_NODE_KIND + ) + nodes_without_outputs = get_nodes_that_do_not_produce_outputs( + execution_graph=execution_graph + ) + nodes_that_must_be_reached = output_nodes.union(nodes_without_outputs) + nodes_reaching_output = ( + get_nodes_that_are_reachable_from_pointed_ones_in_reversed_graph( + execution_graph=execution_graph, + pointed_nodes=nodes_that_must_be_reached, + ) + ) + nodes_not_reaching_output = all_nodes.difference(nodes_reaching_output) + if len(nodes_not_reaching_output) > 0: + raise NodesNotReachingOutputError( + f"Detected {len(nodes_not_reaching_output)} nodes not reaching any of output node:" + f"{nodes_not_reaching_output}." + ) + + +def get_nodes_that_do_not_produce_outputs(execution_graph: DiGraph) -> Set[str]: + # assumption is that nodes without outputs will produce some side effect and shall be + # treated as output nodes while checking if there is no dangling steps in graph + step_nodes = get_nodes_of_specific_kind( + execution_graph=execution_graph, kind=STEP_NODE_KIND + ) + return { + step_node + for step_node in step_nodes + if len(execution_graph.nodes[step_node]["definition"].get_output_names()) == 0 + } + + +def get_nodes_that_are_reachable_from_pointed_ones_in_reversed_graph( + execution_graph: DiGraph, + pointed_nodes: Set[str], +) -> Set[str]: + result = set() + reversed_graph = execution_graph.reverse(copy=True) + for pointed_node in pointed_nodes: + nodes_reaching_pointed_one = list( + nx.dfs_postorder_nodes(reversed_graph, pointed_node) + ) + result.update(nodes_reaching_pointed_one) + return result + + +def verify_each_node_step_has_parent_in_the_same_branch( + execution_graph: DiGraph, +) -> None: + """ + Conditional branching creates a bit of mess, in terms of determining which + steps to execute. + Let's imagine graph: + / -> B -> C -> D + A -> IF < \ + \ -> E -> F -> G -> H + where node G requires node C even though IF branched the execution. In other + words - the problem emerges if a node of kind STEP has a parent (node from which + it can be achieved) of kind STEP and this parent is in a different branch (point out that + we allow for a single step to have multiple steps as input, but they must be at the same + execution path - for instance if D requires an output from C and B - this is allowed). + Additionally, we must prevent situation when outcomes of branches started by two or more + condition steps merge with each other, as condition eval may result in contradictory + execution (2). + + + We need to detect that situation upfront, such that we can raise error of ambiguous execution path + rather than run time-consuming computations that will end up in error. + + To detect problem, first we detect steps with more than one parent step. + From those steps we trace what sequence of steps would lead to execution of problematic one. + For each problematic node we take its parent nodes. Then, we analyse paths from + those parent nodes in reversed topological order (from those nodes towards entry nodes + of execution graph). While our analysis, on each path we denote `Condition` steps and + result of condition evaluation that must have been observed in runtime, to reach + the problematic node while graph execution in normal direction. If we detect that + for any `Condition` step we would need to output both True and False (more than one registered + next step of `Condition` step) - we raise error. + To detect problem (2) - we only let number of different condition steps considered be the number of + max condition steps in a single path from origin to parent of problematic step. + + Beware that the latter part of algorithm has quite bad time complexity in general case. + Worst part of algorithm runs at O(V^4) - at least taking coarse, worst-case estimations. + In fact, there is not so bad: + * The number of step nodes with multiple parents that we loop over in main loop, reduces the number of + steps we iterate through in inner loops, as we are dealing with DAG (with quite limited amount of edges) + and for each multi-parent node takes at least two other nodes (to construct a suspicious group) - + so expected number of iterations in main loop is low - let's say 1-3 for a real graph. + * for any reasonable execution graph, the complexity should be acceptable. + """ + steps_with_more_than_one_parent = detect_steps_with_more_than_one_parent_step( + execution_graph=execution_graph + ) # O(V+E) + if len(steps_with_more_than_one_parent) == 0: + return None + reversed_steps_graph = construct_reversed_steps_graph( + execution_graph=execution_graph + ) # O(V+E) + reversed_topological_order = list( + nx.topological_sort(reversed_steps_graph) + ) # O(V+E) + for step in steps_with_more_than_one_parent: # O(V) + verify_multi_parent_step_execution_paths( + reversed_steps_graph=reversed_steps_graph, + reversed_topological_order=reversed_topological_order, + step=step, + ) + + +def detect_steps_with_more_than_one_parent_step(execution_graph: DiGraph) -> Set[str]: + steps_nodes = get_nodes_of_specific_kind( + execution_graph=execution_graph, + kind=STEP_NODE_KIND, + ) + edges_of_steps_nodes = [edge for edge in execution_graph.edges()] + steps_parents = defaultdict(set) + for edge in edges_of_steps_nodes: + parent, child = edge + if parent not in steps_nodes or child not in steps_nodes: + continue + steps_parents[child].add(parent) + return {key for key, value in steps_parents.items() if len(value) > 1} + + +def construct_reversed_steps_graph(execution_graph: DiGraph) -> DiGraph: + reversed_steps_graph = execution_graph.reverse() + for node, node_data in list(reversed_steps_graph.nodes(data=True)): + if node_data.get("kind") != STEP_NODE_KIND: + reversed_steps_graph.remove_node(node) + return reversed_steps_graph + + +def verify_multi_parent_step_execution_paths( + reversed_steps_graph: nx.DiGraph, + reversed_topological_order: List[str], + step: str, +) -> None: + condition_steps_successors = defaultdict(set) + max_conditions_steps = 0 + for normal_flow_predecessor in reversed_steps_graph.successors(step): # O(V) + reversed_flow_path = ( + construct_reversed_path_to_multi_parent_step_parent( # O(E) -> O(V^2) + reversed_steps_graph=reversed_steps_graph, + reversed_topological_order=reversed_topological_order, + parent_step=normal_flow_predecessor, + step=step, + ) + ) + ( + condition_steps_successors, + condition_steps, + ) = denote_condition_steps_successors_in_normal_flow( # O(V) + reversed_steps_graph=reversed_steps_graph, + reversed_flow_path=reversed_flow_path, + condition_steps_successors=condition_steps_successors, + ) + max_conditions_steps = max(condition_steps, max_conditions_steps) + if len(condition_steps_successors) > max_conditions_steps: + raise AmbiguousPathDetected( + f"In execution graph, detected collision of branches that originate in different condition steps." + ) + for condition_step, potential_next_steps in condition_steps_successors.items(): + if len(potential_next_steps) > 1: + raise AmbiguousPathDetected( + f"In execution graph, condition step: {condition_step} creates ambiguous execution paths." + ) + + +def construct_reversed_path_to_multi_parent_step_parent( + reversed_steps_graph: nx.DiGraph, + reversed_topological_order: List[str], + parent_step: str, + step: str, +) -> List[str]: + normal_flow_path_nodes = nx.descendants(reversed_steps_graph, parent_step) + normal_flow_path_nodes.add(parent_step) + normal_flow_path_nodes.add(step) + return [n for n in reversed_topological_order if n in normal_flow_path_nodes] + + +def denote_condition_steps_successors_in_normal_flow( + reversed_steps_graph: nx.DiGraph, + reversed_flow_path: List[str], + condition_steps_successors: Dict[str, Set[str]], +) -> Tuple[Dict[str, Set[str]], int]: + conditions_steps = 0 + if len(reversed_flow_path) == 0: + return condition_steps_successors, conditions_steps + previous_node = reversed_flow_path[0] + for node in reversed_flow_path[1:]: + if is_condition_step(execution_graph=reversed_steps_graph, node=node): + condition_steps_successors[node].add(previous_node) + conditions_steps += 1 + previous_node = node + return condition_steps_successors, conditions_steps + + +def verify_that_steps_are_connected_with_compatible_inputs( + execution_graph: nx.DiGraph, +) -> None: + steps_nodes = get_nodes_of_specific_kind( + execution_graph=execution_graph, + kind=STEP_NODE_KIND, + ) + for step in steps_nodes: + verify_step_inputs_selectors(step=step, execution_graph=execution_graph) + + +def verify_step_inputs_selectors(step: str, execution_graph: nx.DiGraph) -> None: + step_definition = execution_graph.nodes[step]["definition"] + all_inputs = step_definition.get_input_names() + for input_step in all_inputs: + input_selector_or_value = getattr(step_definition, input_step) + if issubclass(type(input_selector_or_value), list): + for idx, single_selector_or_value in enumerate(input_selector_or_value): + validate_step_definition_input( + step_definition=step_definition, + input_name=input_step, + execution_graph=execution_graph, + input_selector_or_value=single_selector_or_value, + index=idx, + ) + else: + validate_step_definition_input( + step_definition=step_definition, + input_name=input_step, + execution_graph=execution_graph, + input_selector_or_value=input_selector_or_value, + ) + + +def validate_step_definition_input( + step_definition: StepInterface, + input_name: str, + execution_graph: nx.DiGraph, + input_selector_or_value: Any, + index: Optional[int] = None, +) -> None: + if not is_selector(selector_or_value=input_selector_or_value): + return None + if is_step_output_selector(selector_or_value=input_selector_or_value): + input_selector_or_value = get_step_selector_from_its_output( + step_output_selector=input_selector_or_value + ) + input_node_definition = execution_graph.nodes[input_selector_or_value]["definition"] + step_definition.validate_field_selector( + field_name=input_name, + input_step=input_node_definition, + index=index, + ) diff --git a/inference/enterprise/workflows/complier/runtime_input_validator.py b/inference/enterprise/workflows/complier/runtime_input_validator.py new file mode 100644 index 0000000000000000000000000000000000000000..ed79024ce1c19e6a4f49b44b157a87f314ff1364 --- /dev/null +++ b/inference/enterprise/workflows/complier/runtime_input_validator.py @@ -0,0 +1,188 @@ +from typing import Any, Dict, Optional, Set, Union + +import numpy as np +from networkx import DiGraph + +from inference.core.utils.image_utils import ImageType +from inference.enterprise.workflows.complier.steps_executors.constants import ( + IMAGE_TYPE_KEY, + IMAGE_VALUE_KEY, + PARENT_ID_KEY, +) +from inference.enterprise.workflows.complier.utils import ( + get_nodes_of_specific_kind, + is_input_selector, +) +from inference.enterprise.workflows.constants import INPUT_NODE_KIND, STEP_NODE_KIND +from inference.enterprise.workflows.entities.validators import get_last_selector_chunk +from inference.enterprise.workflows.errors import ( + InvalidStepInputDetected, + RuntimeParameterMissingError, +) + + +def prepare_runtime_parameters( + execution_graph: DiGraph, + runtime_parameters: Dict[str, Any], +) -> Dict[str, Any]: + ensure_all_parameters_filled( + execution_graph=execution_graph, + runtime_parameters=runtime_parameters, + ) + runtime_parameters = fill_runtime_parameters_with_defaults( + execution_graph=execution_graph, + runtime_parameters=runtime_parameters, + ) + runtime_parameters = assembly_input_images( + execution_graph=execution_graph, + runtime_parameters=runtime_parameters, + ) + validate_inputs_binding( + execution_graph=execution_graph, + runtime_parameters=runtime_parameters, + ) + return runtime_parameters + + +def ensure_all_parameters_filled( + execution_graph: DiGraph, + runtime_parameters: Dict[str, Any], +) -> None: + parameters_without_default_values = get_input_parameters_without_default_values( + execution_graph=execution_graph, + ) + missing_parameters = [] + for name in parameters_without_default_values: + if name not in runtime_parameters: + missing_parameters.append(name) + if len(missing_parameters) > 0: + raise RuntimeParameterMissingError( + f"Parameters passed to execution runtime do not define required inputs: {missing_parameters}" + ) + + +def get_input_parameters_without_default_values(execution_graph: DiGraph) -> Set[str]: + input_nodes = get_nodes_of_specific_kind( + execution_graph=execution_graph, + kind=INPUT_NODE_KIND, + ) + result = set() + for input_node in input_nodes: + definition = execution_graph.nodes[input_node]["definition"] + if definition.type == "InferenceImage": + result.add(definition.name) + continue + if definition.type == "InferenceParameter" and definition.default_value is None: + result.add(definition.name) + continue + return result + + +def fill_runtime_parameters_with_defaults( + execution_graph: DiGraph, + runtime_parameters: Dict[str, Any], +) -> Dict[str, Any]: + default_values_parameters = get_input_parameters_default_values( + execution_graph=execution_graph + ) + default_values_parameters.update(runtime_parameters) + return default_values_parameters + + +def get_input_parameters_default_values(execution_graph: DiGraph) -> Dict[str, Any]: + input_nodes = get_nodes_of_specific_kind( + execution_graph=execution_graph, + kind=INPUT_NODE_KIND, + ) + result = {} + for input_node in input_nodes: + definition = execution_graph.nodes[input_node]["definition"] + if ( + definition.type == "InferenceParameter" + and definition.default_value is not None + ): + result[definition.name] = definition.default_value + return result + + +def assembly_input_images( + execution_graph: DiGraph, + runtime_parameters: Dict[str, Any], +) -> Dict[str, Any]: + input_nodes = get_nodes_of_specific_kind( + execution_graph=execution_graph, + kind=INPUT_NODE_KIND, + ) + for input_node in input_nodes: + definition = execution_graph.nodes[input_node]["definition"] + if definition.type != "InferenceImage": + continue + if issubclass(type(runtime_parameters[definition.name]), list): + runtime_parameters[definition.name] = [ + assembly_input_image( + parameter=input_node, + image=image, + identifier=i, + ) + for i, image in enumerate(runtime_parameters[definition.name]) + ] + else: + runtime_parameters[definition.name] = [ + assembly_input_image( + parameter=input_node, image=runtime_parameters[definition.name] + ) + ] + return runtime_parameters + + +def assembly_input_image( + parameter: str, image: Any, identifier: Optional[int] = None +) -> Dict[str, Union[str, np.ndarray]]: + parent = parameter + if identifier is not None: + parent = f"{parent}.[{identifier}]" + if issubclass(type(image), dict): + image[PARENT_ID_KEY] = parent + return image + if issubclass(type(image), np.ndarray): + return { + IMAGE_TYPE_KEY: ImageType.NUMPY_OBJECT.value, + IMAGE_VALUE_KEY: image, + PARENT_ID_KEY: parent, + } + raise InvalidStepInputDetected( + f"Detected runtime parameter `{parameter}` defined as `InferenceImage` with type {type(image)} that is invalid." + ) + + +def validate_inputs_binding( + execution_graph: DiGraph, + runtime_parameters: Dict[str, Any], +) -> None: + step_nodes = get_nodes_of_specific_kind( + execution_graph=execution_graph, + kind=STEP_NODE_KIND, + ) + for step in step_nodes: + validate_step_input_bindings( + step=step, + execution_graph=execution_graph, + runtime_parameters=runtime_parameters, + ) + + +def validate_step_input_bindings( + step: str, + execution_graph: DiGraph, + runtime_parameters: Dict[str, Any], +) -> None: + step_definition = execution_graph.nodes[step]["definition"] + for input_name in step_definition.get_input_names(): + selector_or_value = getattr(step_definition, input_name) + if not is_input_selector(selector_or_value=selector_or_value): + continue + input_parameter_name = get_last_selector_chunk(selector=selector_or_value) + parameter_value = runtime_parameters[input_parameter_name] + step_definition.validate_field_binding( + field_name=input_name, value=parameter_value + ) diff --git a/inference/enterprise/workflows/complier/steps_executors/__init__.py b/inference/enterprise/workflows/complier/steps_executors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/enterprise/workflows/complier/steps_executors/__pycache__/__init__.cpython-310.pyc b/inference/enterprise/workflows/complier/steps_executors/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3287874192729ff4c7aee108d6b2d9a3491ad45b Binary files /dev/null and b/inference/enterprise/workflows/complier/steps_executors/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/steps_executors/__pycache__/active_learning_middlewares.cpython-310.pyc b/inference/enterprise/workflows/complier/steps_executors/__pycache__/active_learning_middlewares.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3db729bd135c77c0244e0f28480c4043727da4c Binary files /dev/null and b/inference/enterprise/workflows/complier/steps_executors/__pycache__/active_learning_middlewares.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/steps_executors/__pycache__/auxiliary.cpython-310.pyc b/inference/enterprise/workflows/complier/steps_executors/__pycache__/auxiliary.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..512d8a45f7025f9224168b7b8152362d6de889c3 Binary files /dev/null and b/inference/enterprise/workflows/complier/steps_executors/__pycache__/auxiliary.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/steps_executors/__pycache__/constants.cpython-310.pyc b/inference/enterprise/workflows/complier/steps_executors/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a7d2d2089efed367d224b3a6e7b6947c94153a7 Binary files /dev/null and b/inference/enterprise/workflows/complier/steps_executors/__pycache__/constants.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/steps_executors/__pycache__/models.cpython-310.pyc b/inference/enterprise/workflows/complier/steps_executors/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..108fcc43affe795b646c1a7710ad5bcc7f6dab3d Binary files /dev/null and b/inference/enterprise/workflows/complier/steps_executors/__pycache__/models.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/steps_executors/__pycache__/types.cpython-310.pyc b/inference/enterprise/workflows/complier/steps_executors/__pycache__/types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c2204455a5e943b67a36ae4e9fffdb4884d76b5 Binary files /dev/null and b/inference/enterprise/workflows/complier/steps_executors/__pycache__/types.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/steps_executors/__pycache__/utils.cpython-310.pyc b/inference/enterprise/workflows/complier/steps_executors/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb6820af84047c7232480e9ecb04e456bc5917b9 Binary files /dev/null and b/inference/enterprise/workflows/complier/steps_executors/__pycache__/utils.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/complier/steps_executors/active_learning_middlewares.py b/inference/enterprise/workflows/complier/steps_executors/active_learning_middlewares.py new file mode 100644 index 0000000000000000000000000000000000000000..200b7118dcdcac7c34ce0322c68ff32fe5b8f61a --- /dev/null +++ b/inference/enterprise/workflows/complier/steps_executors/active_learning_middlewares.py @@ -0,0 +1,119 @@ +from typing import Dict, List, Optional, Union + +from fastapi import BackgroundTasks + +from inference.core import logger +from inference.core.active_learning.middlewares import ActiveLearningMiddleware +from inference.core.cache.base import BaseCache +from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT +from inference.enterprise.workflows.entities.steps import ( + DisabledActiveLearningConfiguration, + EnabledActiveLearningConfiguration, +) + + +class WorkflowsActiveLearningMiddleware: + + def __init__( + self, + cache: BaseCache, + middlewares: Optional[Dict[str, ActiveLearningMiddleware]] = None, + ): + self._cache = cache + self._middlewares = middlewares if middlewares is not None else {} + + def register( + self, + dataset_name: str, + images: List[dict], + predictions: List[dict], + api_key: Optional[str], + prediction_type: str, + active_learning_disabled_for_request: bool, + background_tasks: Optional[BackgroundTasks] = None, + active_learning_configuration: Optional[ + Union[ + EnabledActiveLearningConfiguration, DisabledActiveLearningConfiguration + ] + ] = None, + ) -> None: + model_id = f"{dataset_name}/workflows" + if api_key is None or active_learning_disabled_for_request: + return None + if background_tasks is None: + self._register( + model_id=model_id, + images=images, + predictions=predictions, + api_key=api_key, + prediction_type=prediction_type, + active_learning_configuration=active_learning_configuration, + ) + return None + background_tasks.add_task( + self._register, + model_id=model_id, + images=images, + predictions=predictions, + api_key=api_key, + prediction_type=prediction_type, + active_learning_configuration=active_learning_configuration, + ) + + def _register( + self, + model_id: str, + images: List[dict], + predictions: List[dict], + api_key: str, + prediction_type: str, + active_learning_configuration: Optional[ + Union[ + EnabledActiveLearningConfiguration, DisabledActiveLearningConfiguration + ] + ], + ) -> None: + try: + self._ensure_middleware_initialised( + model_id=model_id, + api_key=api_key, + active_learning_configuration=active_learning_configuration, + ) + self._middlewares[model_id].register_batch( + inference_inputs=images, + predictions=predictions, + prediction_type=prediction_type, + disable_preproc_auto_orient=DISABLE_PREPROC_AUTO_ORIENT, + ) + except Exception as error: + # Error handling to be decided + logger.warning( + f"Error in datapoint registration for Active Learning. Details: {error}. " + f"Error is suppressed in favour of normal operations of API." + ) + + def _ensure_middleware_initialised( + self, + model_id: str, + api_key: str, + active_learning_configuration: Optional[ + Union[ + EnabledActiveLearningConfiguration, DisabledActiveLearningConfiguration + ] + ], + ) -> None: + if model_id in self._middlewares: + return None + if active_learning_configuration is not None: + self._middlewares[model_id] = ActiveLearningMiddleware.init_from_config( + api_key=api_key, + model_id=model_id, + cache=self._cache, + config=active_learning_configuration.dict(), + ) + else: + self._middlewares[model_id] = ActiveLearningMiddleware.init( + api_key=api_key, + model_id=model_id, + cache=self._cache, + ) diff --git a/inference/enterprise/workflows/complier/steps_executors/auxiliary.py b/inference/enterprise/workflows/complier/steps_executors/auxiliary.py new file mode 100644 index 0000000000000000000000000000000000000000..d6a791fa710dac5f1af19d5ab7e58d1b06018b7a --- /dev/null +++ b/inference/enterprise/workflows/complier/steps_executors/auxiliary.py @@ -0,0 +1,916 @@ +import itertools +import statistics +from collections import Counter, defaultdict +from copy import deepcopy +from functools import partial +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Union +from uuid import uuid4 + +import numpy as np +from fastapi import BackgroundTasks + +from inference.core.managers.base import ModelManager +from inference.core.utils.image_utils import ImageType, load_image +from inference.enterprise.workflows.complier.entities import StepExecutionMode +from inference.enterprise.workflows.complier.steps_executors.active_learning_middlewares import ( + WorkflowsActiveLearningMiddleware, +) +from inference.enterprise.workflows.complier.steps_executors.constants import ( + CENTER_X_KEY, + CENTER_Y_KEY, + DETECTION_ID_KEY, + HEIGHT_KEY, + IMAGE_TYPE_KEY, + IMAGE_VALUE_KEY, + ORIGIN_COORDINATES_KEY, + ORIGIN_SIZE_KEY, + PARENT_ID_KEY, + WIDTH_KEY, +) +from inference.enterprise.workflows.complier.steps_executors.types import ( + NextStepReference, + OutputsLookup, +) +from inference.enterprise.workflows.complier.steps_executors.utils import ( + get_image, + resolve_parameter, +) +from inference.enterprise.workflows.complier.utils import ( + construct_selector_pointing_step_output, + construct_step_selector, +) +from inference.enterprise.workflows.entities.steps import ( + AbsoluteStaticCrop, + ActiveLearningDataCollector, + AggregationMode, + BinaryOperator, + CompoundDetectionFilterDefinition, + Condition, + Crop, + DetectionFilter, + DetectionFilterDefinition, + DetectionOffset, + DetectionsConsensus, + Operator, + RelativeStaticCrop, +) +from inference.enterprise.workflows.entities.validators import get_last_selector_chunk +from inference.enterprise.workflows.errors import ExecutionGraphError + +OPERATORS = { + Operator.EQUAL: lambda a, b: a == b, + Operator.NOT_EQUAL: lambda a, b: a != b, + Operator.LOWER_THAN: lambda a, b: a < b, + Operator.GREATER_THAN: lambda a, b: a > b, + Operator.LOWER_OR_EQUAL_THAN: lambda a, b: a <= b, + Operator.GREATER_OR_EQUAL_THAN: lambda a, b: a >= b, + Operator.IN: lambda a, b: a in b, +} + +BINARY_OPERATORS = { + BinaryOperator.AND: lambda a, b: a and b, + BinaryOperator.OR: lambda a, b: a or b, +} + +AGGREGATION_MODE2FIELD_AGGREGATOR = { + AggregationMode.MAX: max, + AggregationMode.MIN: min, + AggregationMode.AVERAGE: statistics.mean, +} + + +async def run_crop_step( + step: Crop, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, +) -> Tuple[NextStepReference, OutputsLookup]: + image = get_image( + step=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + detections = resolve_parameter( + selector_or_value=step.detections, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + decoded_images = [load_image(e) for e in image] + decoded_images = [ + i[0] if i[1] is True else i[0][:, :, ::-1] for i in decoded_images + ] + origin_image_shape = extract_origin_size_from_images( + input_images=image, + decoded_images=decoded_images, + ) + crops = list( + itertools.chain.from_iterable( + crop_image(image=i, detections=d, origin_size=o) + for i, d, o in zip(decoded_images, detections, origin_image_shape) + ) + ) + parent_ids = [c[PARENT_ID_KEY] for c in crops] + outputs_lookup[construct_step_selector(step_name=step.name)] = { + "crops": crops, + PARENT_ID_KEY: parent_ids, + } + return None, outputs_lookup + + +def crop_image( + image: np.ndarray, + detections: List[dict], + origin_size: dict, +) -> List[Dict[str, Union[str, np.ndarray]]]: + crops = [] + for detection in detections: + x_min, y_min, x_max, y_max = detection_to_xyxy(detection=detection) + cropped_image = image[y_min:y_max, x_min:x_max] + crops.append( + { + IMAGE_TYPE_KEY: ImageType.NUMPY_OBJECT.value, + IMAGE_VALUE_KEY: cropped_image, + PARENT_ID_KEY: detection[DETECTION_ID_KEY], + ORIGIN_COORDINATES_KEY: { + CENTER_X_KEY: detection["x"], + CENTER_Y_KEY: detection["y"], + ORIGIN_SIZE_KEY: origin_size, + }, + } + ) + return crops + + +async def run_condition_step( + step: Condition, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, +) -> Tuple[NextStepReference, OutputsLookup]: + left_value = resolve_parameter( + selector_or_value=step.left, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + right_value = resolve_parameter( + selector_or_value=step.right, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + evaluation_result = OPERATORS[step.operator](left_value, right_value) + next_step = step.step_if_true if evaluation_result else step.step_if_false + return next_step, outputs_lookup + + +async def run_detection_filter( + step: DetectionFilter, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, +) -> Tuple[NextStepReference, OutputsLookup]: + predictions = resolve_parameter( + selector_or_value=step.predictions, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + images_meta_selector = construct_selector_pointing_step_output( + selector=step.predictions, + new_output="image", + ) + images_meta = resolve_parameter( + selector_or_value=images_meta_selector, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + prediction_type_selector = construct_selector_pointing_step_output( + selector=step.predictions, + new_output="prediction_type", + ) + predictions_type = resolve_parameter( + selector_or_value=prediction_type_selector, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + filter_callable = build_filter_callable(definition=step.filter_definition) + result_detections, result_parent_id = [], [] + for prediction in predictions: + filtered_predictions = [deepcopy(p) for p in prediction if filter_callable(p)] + result_detections.append(filtered_predictions) + result_parent_id.append([p[PARENT_ID_KEY] for p in filtered_predictions]) + step_selector = construct_step_selector(step_name=step.name) + outputs_lookup[step_selector] = [ + {"predictions": d, PARENT_ID_KEY: p, "image": i, "prediction_type": pt} + for d, p, i, pt in zip( + result_detections, result_parent_id, images_meta, predictions_type + ) + ] + return None, outputs_lookup + + +def build_filter_callable( + definition: Union[DetectionFilterDefinition, CompoundDetectionFilterDefinition], +) -> Callable[[dict], bool]: + if definition.type == "CompoundDetectionFilterDefinition": + left_callable = build_filter_callable(definition=definition.left) + right_callable = build_filter_callable(definition=definition.right) + binary_operator = BINARY_OPERATORS[definition.operator] + return lambda e: binary_operator(left_callable(e), right_callable(e)) + if definition.type == "DetectionFilterDefinition": + operator = OPERATORS[definition.operator] + return lambda e: operator(e[definition.field_name], definition.reference_value) + raise ExecutionGraphError( + f"Detected filter definition of type {definition.type} which is unknown" + ) + + +async def run_detection_offset_step( + step: DetectionOffset, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, +) -> Tuple[NextStepReference, OutputsLookup]: + detections = resolve_parameter( + selector_or_value=step.predictions, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + images_meta_selector = construct_selector_pointing_step_output( + selector=step.predictions, + new_output="image", + ) + images_meta = resolve_parameter( + selector_or_value=images_meta_selector, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + prediction_type_selector = construct_selector_pointing_step_output( + selector=step.predictions, + new_output="prediction_type", + ) + predictions_type = resolve_parameter( + selector_or_value=prediction_type_selector, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + offset_x = resolve_parameter( + selector_or_value=step.offset_x, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + offset_y = resolve_parameter( + selector_or_value=step.offset_y, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + result_detections, result_parent_id = [], [] + for detection in detections: + offset_detections = [ + offset_detection(detection=d, offset_x=offset_x, offset_y=offset_y) + for d in detection + ] + result_detections.append(offset_detections) + result_parent_id.append([d[PARENT_ID_KEY] for d in offset_detections]) + step_selector = construct_step_selector(step_name=step.name) + outputs_lookup[step_selector] = [ + {"predictions": d, PARENT_ID_KEY: p, "image": i, "prediction_type": pt} + for d, p, i, pt in zip( + result_detections, result_parent_id, images_meta, predictions_type + ) + ] + return None, outputs_lookup + + +def offset_detection( + detection: Dict[str, Any], offset_x: int, offset_y: int +) -> Dict[str, Any]: + detection_copy = deepcopy(detection) + detection_copy[WIDTH_KEY] += round(offset_x) + detection_copy[HEIGHT_KEY] += round(offset_y) + detection_copy[PARENT_ID_KEY] = detection_copy[DETECTION_ID_KEY] + detection_copy[DETECTION_ID_KEY] = str(uuid4()) + return detection_copy + + +async def run_static_crop_step( + step: Union[AbsoluteStaticCrop, RelativeStaticCrop], + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, +) -> Tuple[NextStepReference, OutputsLookup]: + image = get_image( + step=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + decoded_images = [load_image(e) for e in image] + decoded_images = [ + i[0] if i[1] is True else i[0][:, :, ::-1] for i in decoded_images + ] + origin_image_shape = extract_origin_size_from_images( + input_images=image, + decoded_images=decoded_images, + ) + crops = [ + take_static_crop( + image=i, + crop=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + origin_size=size, + ) + for i, size in zip(decoded_images, origin_image_shape) + ] + parent_ids = [c[PARENT_ID_KEY] for c in crops] + outputs_lookup[construct_step_selector(step_name=step.name)] = { + "crops": crops, + PARENT_ID_KEY: parent_ids, + } + return None, outputs_lookup + + +def extract_origin_size_from_images( + input_images: List[Union[dict, np.ndarray]], + decoded_images: List[np.ndarray], +) -> List[Dict[str, int]]: + result = [] + for input_image, decoded_image in zip(input_images, decoded_images): + if ( + issubclass(type(input_image), dict) + and ORIGIN_COORDINATES_KEY in input_image + ): + result.append(input_image[ORIGIN_COORDINATES_KEY][ORIGIN_SIZE_KEY]) + else: + result.append( + {HEIGHT_KEY: decoded_image.shape[0], WIDTH_KEY: decoded_image.shape[1]} + ) + return result + + +def take_static_crop( + image: np.ndarray, + crop: Union[AbsoluteStaticCrop, RelativeStaticCrop], + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + origin_size: dict, +) -> Dict[str, Union[str, np.ndarray]]: + resolve_parameter_closure = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + x_center = resolve_parameter_closure(crop.x_center) + y_center = resolve_parameter_closure(crop.y_center) + width = resolve_parameter_closure(crop.width) + height = resolve_parameter_closure(crop.height) + if crop.type == "RelativeStaticCrop": + x_center = round(image.shape[1] * x_center) + y_center = round(image.shape[0] * y_center) + width = round(image.shape[1] * width) + height = round(image.shape[0] * height) + x_min = round(x_center - width / 2) + y_min = round(y_center - height / 2) + x_max = round(x_min + width) + y_max = round(y_min + height) + cropped_image = image[y_min:y_max, x_min:x_max] + return { + IMAGE_TYPE_KEY: ImageType.NUMPY_OBJECT.value, + IMAGE_VALUE_KEY: cropped_image, + PARENT_ID_KEY: f"$steps.{crop.name}", + ORIGIN_COORDINATES_KEY: { + CENTER_X_KEY: x_center, + CENTER_Y_KEY: y_center, + ORIGIN_SIZE_KEY: origin_size, + }, + } + + +async def run_detections_consensus_step( + step: DetectionsConsensus, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, +) -> Tuple[NextStepReference, OutputsLookup]: + resolve_parameter_closure = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + all_predictions = [resolve_parameter_closure(p) for p in step.predictions] + # all_predictions has shape (n_consensus_input, bs, img_predictions) + if len(all_predictions) < 1: + raise ExecutionGraphError( + f"Consensus step requires at least one source of predictions." + ) + batch_sizes = get_and_validate_batch_sizes( + all_predictions=all_predictions, + step_name=step.name, + ) + images_meta_selector = construct_selector_pointing_step_output( + selector=step.predictions[0], + new_output="image", + ) + images_meta = resolve_parameter_closure(images_meta_selector) + batch_size = batch_sizes[0] + results = [] + for batch_index in range(batch_size): + batch_element_predictions = [e[batch_index] for e in all_predictions] + ( + parent_id, + object_present, + presence_confidence, + consensus_detections, + ) = resolve_batch_consensus( + predictions=batch_element_predictions, + required_votes=resolve_parameter_closure(step.required_votes), + class_aware=resolve_parameter_closure(step.class_aware), + iou_threshold=resolve_parameter_closure(step.iou_threshold), + confidence=resolve_parameter_closure(step.confidence), + classes_to_consider=resolve_parameter_closure(step.classes_to_consider), + required_objects=resolve_parameter_closure(step.required_objects), + presence_confidence_aggregation=step.presence_confidence_aggregation, + detections_merge_confidence_aggregation=step.detections_merge_confidence_aggregation, + detections_merge_coordinates_aggregation=step.detections_merge_coordinates_aggregation, + ) + results.append( + { + "predictions": consensus_detections, + "parent_id": parent_id, + "object_present": object_present, + "presence_confidence": presence_confidence, + "image": images_meta[batch_index], + "prediction_type": "object-detection", + } + ) + outputs_lookup[construct_step_selector(step_name=step.name)] = results + return None, outputs_lookup + + +def get_and_validate_batch_sizes( + all_predictions: List[List[List[dict]]], + step_name: str, +) -> List[int]: + batch_sizes = get_predictions_batch_sizes(all_predictions=all_predictions) + if not all_batch_sizes_equal(batch_sizes=batch_sizes): + raise ExecutionGraphError( + f"Detected missmatch of input dimensions in step: {step_name}" + ) + return batch_sizes + + +def get_predictions_batch_sizes(all_predictions: List[List[List[dict]]]) -> List[int]: + return [len(predictions) for predictions in all_predictions] + + +def all_batch_sizes_equal(batch_sizes: List[int]) -> bool: + if len(batch_sizes) == 0: + return True + reference = batch_sizes[0] + return all(e == reference for e in batch_sizes) + + +def resolve_batch_consensus( + predictions: List[List[dict]], + required_votes: int, + class_aware: bool, + iou_threshold: float, + confidence: float, + classes_to_consider: Optional[List[str]], + required_objects: Optional[Union[int, Dict[str, int]]], + presence_confidence_aggregation: AggregationMode, + detections_merge_confidence_aggregation: AggregationMode, + detections_merge_coordinates_aggregation: AggregationMode, +) -> Tuple[str, bool, Dict[str, float], List[dict]]: + if does_not_detected_objects_in_any_source(predictions=predictions): + return "undefined", False, {}, [] + parent_id = get_parent_id_of_predictions_from_different_sources( + predictions=predictions, + ) + predictions = filter_predictions( + predictions=predictions, + classes_to_consider=classes_to_consider, + ) + detections_already_considered = set() + consensus_detections = [] + for source_id, detection in enumerate_detections(predictions=predictions): + ( + consensus_detections_update, + detections_already_considered, + ) = get_consensus_for_single_detection( + detection=detection, + source_id=source_id, + predictions=predictions, + iou_threshold=iou_threshold, + class_aware=class_aware, + required_votes=required_votes, + confidence=confidence, + detections_merge_confidence_aggregation=detections_merge_confidence_aggregation, + detections_merge_coordinates_aggregation=detections_merge_coordinates_aggregation, + detections_already_considered=detections_already_considered, + ) + consensus_detections += consensus_detections_update + ( + object_present, + presence_confidence, + ) = check_objects_presence_in_consensus_predictions( + consensus_detections=consensus_detections, + aggregation_mode=presence_confidence_aggregation, + class_aware=class_aware, + required_objects=required_objects, + ) + return ( + parent_id, + object_present, + presence_confidence, + consensus_detections, + ) + + +def get_consensus_for_single_detection( + detection: dict, + source_id: int, + predictions: List[List[dict]], + iou_threshold: float, + class_aware: bool, + required_votes: int, + confidence: float, + detections_merge_confidence_aggregation: AggregationMode, + detections_merge_coordinates_aggregation: AggregationMode, + detections_already_considered: Set[str], +) -> Tuple[List[dict], Set[str]]: + if detection["detection_id"] in detections_already_considered: + return ([], detections_already_considered) + consensus_detections = [] + detections_with_max_overlap = ( + get_detections_from_different_sources_with_max_overlap( + detection=detection, + source=source_id, + predictions=predictions, + iou_threshold=iou_threshold, + class_aware=class_aware, + detections_already_considered=detections_already_considered, + ) + ) + if len(detections_with_max_overlap) < (required_votes - 1): + return consensus_detections, detections_already_considered + detections_to_merge = [detection] + [ + matched_value[0] for matched_value in detections_with_max_overlap.values() + ] + merged_detection = merge_detections( + detections=detections_to_merge, + confidence_aggregation_mode=detections_merge_confidence_aggregation, + boxes_aggregation_mode=detections_merge_coordinates_aggregation, + ) + if merged_detection["confidence"] < confidence: + return consensus_detections, detections_already_considered + consensus_detections.append(merged_detection) + detections_already_considered.add(detection[DETECTION_ID_KEY]) + for matched_value in detections_with_max_overlap.values(): + detections_already_considered.add(matched_value[0][DETECTION_ID_KEY]) + return consensus_detections, detections_already_considered + + +def check_objects_presence_in_consensus_predictions( + consensus_detections: List[dict], + class_aware: bool, + aggregation_mode: AggregationMode, + required_objects: Optional[Union[int, Dict[str, int]]], +) -> Tuple[bool, Dict[str, float]]: + if len(consensus_detections) == 0: + return False, {} + if required_objects is None: + required_objects = 0 + if issubclass(type(required_objects), dict) and not class_aware: + required_objects = sum(required_objects.values()) + if ( + issubclass(type(required_objects), int) + and len(consensus_detections) < required_objects + ): + return False, {} + if not class_aware: + aggregated_confidence = aggregate_field_values( + detections=consensus_detections, + field="confidence", + aggregation_mode=aggregation_mode, + ) + return True, {"any_object": aggregated_confidence} + class2detections = defaultdict(list) + for detection in consensus_detections: + class2detections[detection["class"]].append(detection) + if issubclass(type(required_objects), dict): + for requested_class, required_objects_count in required_objects.items(): + if len(class2detections[requested_class]) < required_objects_count: + return False, {} + class2confidence = { + class_name: aggregate_field_values( + detections=class_detections, + field="confidence", + aggregation_mode=aggregation_mode, + ) + for class_name, class_detections in class2detections.items() + } + return True, class2confidence + + +def does_not_detected_objects_in_any_source(predictions: List[List[dict]]) -> bool: + return all(len(p) == 0 for p in predictions) + + +def get_parent_id_of_predictions_from_different_sources( + predictions: List[List[dict]], +) -> str: + encountered_parent_ids = { + p[PARENT_ID_KEY] for prediction_source in predictions for p in prediction_source + } + if len(encountered_parent_ids) > 1: + raise ExecutionGraphError( + f"Missmatch in predictions - while executing consensus step, " + f"in equivalent batches, detections are assigned different parent " + f"identifiers, whereas consensus can only be applied for predictions " + f"made against the same input." + ) + return list(encountered_parent_ids)[0] + + +def filter_predictions( + predictions: List[List[dict]], + classes_to_consider: Optional[List[str]], +) -> List[List[dict]]: + if classes_to_consider is None: + return predictions + classes_to_consider = set(classes_to_consider) + return [ + [ + detection + for detection in detections + if detection["class"] in classes_to_consider + ] + for detections in predictions + ] + + +def get_detections_from_different_sources_with_max_overlap( + detection: dict, + source: int, + predictions: List[List[dict]], + iou_threshold: float, + class_aware: bool, + detections_already_considered: Set[str], +) -> Dict[int, Tuple[dict, float]]: + current_max_overlap = {} + for other_source, other_detection in enumerate_detections( + predictions=predictions, + excluded_source=source, + ): + if other_detection[DETECTION_ID_KEY] in detections_already_considered: + continue + if class_aware and detection["class"] != other_detection["class"]: + continue + iou_value = calculate_iou( + detection_a=detection, + detection_b=other_detection, + ) + if iou_value <= iou_threshold: + continue + if current_max_overlap.get(other_source) is None: + current_max_overlap[other_source] = (other_detection, iou_value) + if current_max_overlap[other_source][1] < iou_value: + current_max_overlap[other_source] = (other_detection, iou_value) + return current_max_overlap + + +def enumerate_detections( + predictions: List[List[dict]], + excluded_source: Optional[int] = None, +) -> Generator[Tuple[int, dict], None, None]: + for source_id, detections in enumerate(predictions): + if excluded_source is not None and excluded_source == source_id: + continue + for detection in detections: + yield source_id, detection + + +def calculate_iou(detection_a: dict, detection_b: dict) -> float: + box_a = detection_to_xyxy(detection=detection_a) + box_b = detection_to_xyxy(detection=detection_b) + x_a = max(box_a[0], box_b[0]) + y_a = max(box_a[1], box_b[1]) + x_b = min(box_a[2], box_b[2]) + y_b = min(box_a[3], box_b[3]) + intersection = max(0, x_b - x_a) * max(0, y_b - y_a) + bbox_a_area, bbox_b_area = get_detection_sizes( + detections=[detection_a, detection_b] + ) + union = float(bbox_a_area + bbox_b_area - intersection) + if union == 0.0: + return 0.0 + return intersection / float(bbox_a_area + bbox_b_area - intersection) + + +def detection_to_xyxy(detection: dict) -> Tuple[int, int, int, int]: + x_min = round(detection["x"] - detection[WIDTH_KEY] / 2) + y_min = round(detection["y"] - detection[HEIGHT_KEY] / 2) + x_max = round(x_min + detection[WIDTH_KEY]) + y_max = round(y_min + detection[HEIGHT_KEY]) + return x_min, y_min, x_max, y_max + + +def merge_detections( + detections: List[dict], + confidence_aggregation_mode: AggregationMode, + boxes_aggregation_mode: AggregationMode, +) -> dict: + class_name, class_id = AGGREGATION_MODE2CLASS_SELECTOR[confidence_aggregation_mode]( + detections + ) + x, y, width, height = AGGREGATION_MODE2BOXES_AGGREGATOR[boxes_aggregation_mode]( + detections + ) + return { + PARENT_ID_KEY: detections[0][PARENT_ID_KEY], + DETECTION_ID_KEY: f"{uuid4()}", + "class": class_name, + "class_id": class_id, + "confidence": aggregate_field_values( + detections=detections, + field="confidence", + aggregation_mode=confidence_aggregation_mode, + ), + "x": x, + "y": y, + "width": width, + "height": height, + } + + +def get_majority_class(detections: List[dict]) -> Tuple[str, int]: + class_counts = Counter(d["class"] for d in detections) + most_common_class_name = class_counts.most_common(1)[0][0] + class_id = [ + d["class_id"] for d in detections if d["class"] == most_common_class_name + ][0] + return most_common_class_name, class_id + + +def get_class_of_most_confident_detection(detections: List[dict]) -> Tuple[str, int]: + max_confidence = aggregate_field_values( + detections=detections, + field="confidence", + aggregation_mode=AggregationMode.MAX, + ) + most_confident_prediction = [ + d for d in detections if d["confidence"] == max_confidence + ][0] + return most_confident_prediction["class"], most_confident_prediction["class_id"] + + +def get_class_of_least_confident_detection(detections: List[dict]) -> Tuple[str, int]: + max_confidence = aggregate_field_values( + detections=detections, + field="confidence", + aggregation_mode=AggregationMode.MIN, + ) + most_confident_prediction = [ + d for d in detections if d["confidence"] == max_confidence + ][0] + return most_confident_prediction["class"], most_confident_prediction["class_id"] + + +AGGREGATION_MODE2CLASS_SELECTOR = { + AggregationMode.MAX: get_class_of_most_confident_detection, + AggregationMode.MIN: get_class_of_least_confident_detection, + AggregationMode.AVERAGE: get_majority_class, +} + + +def get_average_bounding_box(detections: List[dict]) -> Tuple[int, int, int, int]: + x = round(aggregate_field_values(detections=detections, field="x")) + y = round(aggregate_field_values(detections=detections, field="y")) + width = round(aggregate_field_values(detections=detections, field="width")) + height = round(aggregate_field_values(detections=detections, field="height")) + return x, y, width, height + + +def get_smallest_bounding_box(detections: List[dict]) -> Tuple[int, int, int, int]: + detection_sizes = get_detection_sizes(detections=detections) + smallest_size = min(detection_sizes) + matching_detection_id = [ + idx for idx, v in enumerate(detection_sizes) if v == smallest_size + ][0] + matching_detection = detections[matching_detection_id] + return ( + matching_detection["x"], + matching_detection["y"], + matching_detection["width"], + matching_detection["height"], + ) + + +def get_largest_bounding_box(detections: List[dict]) -> Tuple[int, int, int, int]: + detection_sizes = get_detection_sizes(detections=detections) + largest_size = max(detection_sizes) + matching_detection_id = [ + idx for idx, v in enumerate(detection_sizes) if v == largest_size + ][0] + matching_detection = detections[matching_detection_id] + return ( + matching_detection["x"], + matching_detection["y"], + matching_detection[WIDTH_KEY], + matching_detection[HEIGHT_KEY], + ) + + +AGGREGATION_MODE2BOXES_AGGREGATOR = { + AggregationMode.MAX: get_largest_bounding_box, + AggregationMode.MIN: get_smallest_bounding_box, + AggregationMode.AVERAGE: get_average_bounding_box, +} + + +def get_detection_sizes(detections: List[dict]) -> List[float]: + return [d[HEIGHT_KEY] * d[WIDTH_KEY] for d in detections] + + +def aggregate_field_values( + detections: List[dict], + field: str, + aggregation_mode: AggregationMode = AggregationMode.AVERAGE, +) -> float: + values = [d[field] for d in detections] + return AGGREGATION_MODE2FIELD_AGGREGATOR[aggregation_mode](values) + + +async def run_active_learning_data_collector( + step: ActiveLearningDataCollector, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, + active_learning_middleware: WorkflowsActiveLearningMiddleware, + background_tasks: Optional[BackgroundTasks], +) -> Tuple[NextStepReference, OutputsLookup]: + resolve_parameter_closure = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + image = get_image( + step=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + images_meta_selector = construct_selector_pointing_step_output( + selector=step.predictions, + new_output="image", + ) + images_meta = resolve_parameter_closure(images_meta_selector) + prediction_type_selector = construct_selector_pointing_step_output( + selector=step.predictions, + new_output="prediction_type", + ) + predictions_type = resolve_parameter( + selector_or_value=prediction_type_selector, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + prediction_type = set(predictions_type) + if len(prediction_type) > 1: + raise ExecutionGraphError( + f"Active Learning data collection step requires only single prediction " + f"type to be part of ingest. Detected: {prediction_type}." + ) + prediction_type = next(iter(prediction_type)) + predictions = resolve_parameter_closure(step.predictions) + predictions_output_name = get_last_selector_chunk(step.predictions) + target_dataset = resolve_parameter_closure(step.target_dataset) + target_dataset_api_key = resolve_parameter_closure(step.target_dataset_api_key) + disable_active_learning = resolve_parameter_closure(step.disable_active_learning) + active_learning_compatible_predictions = [ + {"image": image_meta, predictions_output_name: prediction} + for image_meta, prediction in zip(images_meta, predictions) + ] + active_learning_middleware.register( + # this should actually be asyncio, but that requires a lot of backend components redesign + dataset_name=target_dataset, + images=image, + predictions=active_learning_compatible_predictions, + api_key=target_dataset_api_key or api_key, + active_learning_disabled_for_request=disable_active_learning, + prediction_type=prediction_type, + background_tasks=background_tasks, + active_learning_configuration=step.active_learning_configuration, + ) + return None, outputs_lookup diff --git a/inference/enterprise/workflows/complier/steps_executors/constants.py b/inference/enterprise/workflows/complier/steps_executors/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..92edda1bf008e278c15fd7a1059803ad4721fc4f --- /dev/null +++ b/inference/enterprise/workflows/complier/steps_executors/constants.py @@ -0,0 +1,11 @@ +IMAGE_TYPE_KEY = "type" +IMAGE_VALUE_KEY = "value" +PARENT_ID_KEY = "parent_id" +ORIGIN_COORDINATES_KEY = "origin_coordinates" +CENTER_X_KEY = "center_x" +CENTER_Y_KEY = "center_y" +ORIGIN_SIZE_KEY = "origin_image_size" +WIDTH_KEY = "width" +HEIGHT_KEY = "height" +DETECTION_ID_KEY = "detection_id" +PARENT_COORDINATES_SUFFIX = "_parent_coordinates" diff --git a/inference/enterprise/workflows/complier/steps_executors/models.py b/inference/enterprise/workflows/complier/steps_executors/models.py new file mode 100644 index 0000000000000000000000000000000000000000..f51ab274d0ebee9c84431077973c6fb0eeeead52 --- /dev/null +++ b/inference/enterprise/workflows/complier/steps_executors/models.py @@ -0,0 +1,816 @@ +import asyncio +from copy import deepcopy +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Union +from uuid import uuid4 + +from inference.core.entities.requests.clip import ClipCompareRequest +from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest +from inference.core.entities.requests.inference import ( + ClassificationInferenceRequest, + InstanceSegmentationInferenceRequest, + KeypointsDetectionInferenceRequest, + ObjectDetectionInferenceRequest, +) +from inference.core.entities.requests.yolo_world import YOLOWorldInferenceRequest +from inference.core.env import ( + HOSTED_CLASSIFICATION_URL, + HOSTED_CORE_MODEL_URL, + HOSTED_DETECT_URL, + HOSTED_INSTANCE_SEGMENTATION_URL, + LOCAL_INFERENCE_API_URL, + WORKFLOWS_REMOTE_API_TARGET, + WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE, + WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, +) +from inference.core.managers.base import ModelManager +from inference.enterprise.workflows.complier.entities import StepExecutionMode +from inference.enterprise.workflows.complier.steps_executors.constants import ( + CENTER_X_KEY, + CENTER_Y_KEY, + ORIGIN_COORDINATES_KEY, + ORIGIN_SIZE_KEY, + PARENT_COORDINATES_SUFFIX, +) +from inference.enterprise.workflows.complier.steps_executors.types import ( + NextStepReference, + OutputsLookup, +) +from inference.enterprise.workflows.complier.steps_executors.utils import ( + get_image, + make_batches, + resolve_parameter, +) +from inference.enterprise.workflows.complier.utils import construct_step_selector +from inference.enterprise.workflows.entities.steps import ( + ClassificationModel, + ClipComparison, + InstanceSegmentationModel, + KeypointsDetectionModel, + MultiLabelClassificationModel, + ObjectDetectionModel, + OCRModel, + RoboflowModel, + StepInterface, + YoloWorld, +) +from inference_sdk import InferenceConfiguration, InferenceHTTPClient + +MODEL_TYPE2PREDICTION_TYPE = { + "ClassificationModel": "classification", + "MultiLabelClassificationModel": "classification", + "ObjectDetectionModel": "object-detection", + "InstanceSegmentationModel": "instance-segmentation", + "KeypointsDetectionModel": "keypoint-detection", +} + + +async def run_roboflow_model_step( + step: RoboflowModel, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, +) -> Tuple[NextStepReference, OutputsLookup]: + model_id = resolve_parameter( + selector_or_value=step.model_id, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + image = get_image( + step=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + if step_execution_mode is StepExecutionMode.LOCAL: + serialised_result = await get_roboflow_model_predictions_locally( + image=image, + model_id=model_id, + step=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + model_manager=model_manager, + api_key=api_key, + ) + else: + serialised_result = await get_roboflow_model_predictions_from_remote_api( + image=image, + model_id=model_id, + step=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + api_key=api_key, + ) + serialised_result = attach_prediction_type_info( + results=serialised_result, + prediction_type=MODEL_TYPE2PREDICTION_TYPE[step.get_type()], + ) + if step.type in {"ClassificationModel", "MultiLabelClassificationModel"}: + serialised_result = attach_parent_info( + image=image, results=serialised_result, nested_key=None + ) + else: + serialised_result = attach_parent_info(image=image, results=serialised_result) + serialised_result = anchor_detections_in_parent_coordinates( + image=image, + serialised_result=serialised_result, + ) + outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result + return None, outputs_lookup + + +async def get_roboflow_model_predictions_locally( + image: List[dict], + model_id: str, + step: RoboflowModel, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], +) -> List[dict]: + request_constructor = MODEL_TYPE2REQUEST_CONSTRUCTOR[step.type] + request = request_constructor( + step=step, + image=image, + api_key=api_key, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + model_manager.add_model( + model_id=model_id, + api_key=api_key, + ) + result = await model_manager.infer_from_request(model_id=model_id, request=request) + if issubclass(type(result), list): + serialised_result = [e.dict(by_alias=True, exclude_none=True) for e in result] + else: + serialised_result = [result.dict(by_alias=True, exclude_none=True)] + return serialised_result + + +def construct_classification_request( + step: Union[ClassificationModel, MultiLabelClassificationModel], + image: Any, + api_key: Optional[str], + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, +) -> ClassificationInferenceRequest: + resolve = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + return ClassificationInferenceRequest( + api_key=api_key, + model_id=resolve(step.model_id), + image=image, + confidence=resolve(step.confidence), + disable_active_learning=resolve(step.disable_active_learning), + ) + + +def construct_object_detection_request( + step: ObjectDetectionModel, + image: Any, + api_key: Optional[str], + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, +) -> ObjectDetectionInferenceRequest: + resolve = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + return ObjectDetectionInferenceRequest( + api_key=api_key, + model_id=resolve(step.model_id), + image=image, + disable_active_learning=resolve(step.disable_active_learning), + class_agnostic_nms=resolve(step.class_agnostic_nms), + class_filter=resolve(step.class_filter), + confidence=resolve(step.confidence), + iou_threshold=resolve(step.iou_threshold), + max_detections=resolve(step.max_detections), + max_candidates=resolve(step.max_candidates), + ) + + +def construct_instance_segmentation_request( + step: InstanceSegmentationModel, + image: Any, + api_key: Optional[str], + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, +) -> InstanceSegmentationInferenceRequest: + resolve = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + return InstanceSegmentationInferenceRequest( + api_key=api_key, + model_id=resolve(step.model_id), + image=image, + disable_active_learning=resolve(step.disable_active_learning), + class_agnostic_nms=resolve(step.class_agnostic_nms), + class_filter=resolve(step.class_filter), + confidence=resolve(step.confidence), + iou_threshold=resolve(step.iou_threshold), + max_detections=resolve(step.max_detections), + max_candidates=resolve(step.max_candidates), + mask_decode_mode=resolve(step.mask_decode_mode), + tradeoff_factor=resolve(step.tradeoff_factor), + ) + + +def construct_keypoints_detection_request( + step: KeypointsDetectionModel, + image: Any, + api_key: Optional[str], + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, +) -> KeypointsDetectionInferenceRequest: + resolve = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + return KeypointsDetectionInferenceRequest( + api_key=api_key, + model_id=resolve(step.model_id), + image=image, + disable_active_learning=resolve(step.disable_active_learning), + class_agnostic_nms=resolve(step.class_agnostic_nms), + class_filter=resolve(step.class_filter), + confidence=resolve(step.confidence), + iou_threshold=resolve(step.iou_threshold), + max_detections=resolve(step.max_detections), + max_candidates=resolve(step.max_candidates), + keypoint_confidence=resolve(step.keypoint_confidence), + ) + + +MODEL_TYPE2REQUEST_CONSTRUCTOR = { + "ClassificationModel": construct_classification_request, + "MultiLabelClassificationModel": construct_classification_request, + "ObjectDetectionModel": construct_object_detection_request, + "InstanceSegmentationModel": construct_instance_segmentation_request, + "KeypointsDetectionModel": construct_keypoints_detection_request, +} + + +async def get_roboflow_model_predictions_from_remote_api( + image: List[dict], + model_id: str, + step: RoboflowModel, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + api_key: Optional[str], +) -> List[dict]: + api_url = resolve_model_api_url(step=step) + client = InferenceHTTPClient( + api_url=api_url, + api_key=api_key, + ) + if WORKFLOWS_REMOTE_API_TARGET == "hosted": + client.select_api_v0() + configuration = MODEL_TYPE2HTTP_CLIENT_CONSTRUCTOR[step.type]( + step=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + client.configure(inference_configuration=configuration) + inference_input = [i["value"] for i in image] + results = await client.infer_async( + inference_input=inference_input, + model_id=model_id, + ) + if not issubclass(type(results), list): + return [results] + return results + + +def construct_http_client_configuration_for_classification_step( + step: Union[ClassificationModel, MultiLabelClassificationModel], + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, +) -> InferenceConfiguration: + resolve = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + return InferenceConfiguration( + confidence_threshold=resolve(step.confidence), + disable_active_learning=resolve(step.disable_active_learning), + max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE, + max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, + ) + + +def construct_http_client_configuration_for_detection_step( + step: ObjectDetectionModel, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, +) -> InferenceConfiguration: + resolve = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + return InferenceConfiguration( + disable_active_learning=resolve(step.disable_active_learning), + class_agnostic_nms=resolve(step.class_agnostic_nms), + class_filter=resolve(step.class_filter), + confidence_threshold=resolve(step.confidence), + iou_threshold=resolve(step.iou_threshold), + max_detections=resolve(step.max_detections), + max_candidates=resolve(step.max_candidates), + max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE, + max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, + ) + + +def construct_http_client_configuration_for_segmentation_step( + step: InstanceSegmentationModel, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, +) -> InferenceConfiguration: + resolve = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + return InferenceConfiguration( + disable_active_learning=resolve(step.disable_active_learning), + class_agnostic_nms=resolve(step.class_agnostic_nms), + class_filter=resolve(step.class_filter), + confidence_threshold=resolve(step.confidence), + iou_threshold=resolve(step.iou_threshold), + max_detections=resolve(step.max_detections), + max_candidates=resolve(step.max_candidates), + mask_decode_mode=resolve(step.mask_decode_mode), + tradeoff_factor=resolve(step.tradeoff_factor), + max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE, + max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, + ) + + +def construct_http_client_configuration_for_keypoints_detection_step( + step: KeypointsDetectionModel, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, +) -> InferenceConfiguration: + resolve = partial( + resolve_parameter, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + return InferenceConfiguration( + disable_active_learning=resolve(step.disable_active_learning), + class_agnostic_nms=resolve(step.class_agnostic_nms), + class_filter=resolve(step.class_filter), + confidence_threshold=resolve(step.confidence), + iou_threshold=resolve(step.iou_threshold), + max_detections=resolve(step.max_detections), + max_candidates=resolve(step.max_candidates), + keypoint_confidence_threshold=resolve(step.keypoint_confidence), + max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE, + max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, + ) + + +MODEL_TYPE2HTTP_CLIENT_CONSTRUCTOR = { + "ClassificationModel": construct_http_client_configuration_for_classification_step, + "MultiLabelClassificationModel": construct_http_client_configuration_for_classification_step, + "ObjectDetectionModel": construct_http_client_configuration_for_detection_step, + "InstanceSegmentationModel": construct_http_client_configuration_for_segmentation_step, + "KeypointsDetectionModel": construct_http_client_configuration_for_keypoints_detection_step, +} + + +async def run_yolo_world_model_step( + step: YoloWorld, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, +) -> Tuple[NextStepReference, OutputsLookup]: + image = get_image( + step=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + class_names = resolve_parameter( + selector_or_value=step.class_names, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + model_version = resolve_parameter( + selector_or_value=step.version, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + confidence = resolve_parameter( + selector_or_value=step.confidence, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + if step_execution_mode is StepExecutionMode.LOCAL: + serialised_result = await get_yolo_world_predictions_locally( + image=image, + class_names=class_names, + model_version=model_version, + confidence=confidence, + model_manager=model_manager, + api_key=api_key, + ) + else: + serialised_result = await get_yolo_world_predictions_from_remote_api( + image=image, + class_names=class_names, + model_version=model_version, + confidence=confidence, + step=step, + api_key=api_key, + ) + serialised_result = attach_prediction_type_info( + results=serialised_result, + prediction_type="object-detection", + ) + serialised_result = attach_parent_info(image=image, results=serialised_result) + serialised_result = anchor_detections_in_parent_coordinates( + image=image, + serialised_result=serialised_result, + ) + outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result + return None, outputs_lookup + + +async def get_yolo_world_predictions_locally( + image: List[dict], + class_names: List[str], + model_version: Optional[str], + confidence: Optional[float], + model_manager: ModelManager, + api_key: Optional[str], +) -> List[dict]: + serialised_result = [] + for single_image in image: + inference_request = YOLOWorldInferenceRequest( + image=single_image, + yolo_world_version_id=model_version, + confidence=confidence, + text=class_names, + ) + yolo_world_model_id = load_core_model( + model_manager=model_manager, + inference_request=inference_request, + core_model="yolo_world", + api_key=api_key, + ) + result = await model_manager.infer_from_request( + yolo_world_model_id, inference_request + ) + serialised_result.append(result.dict()) + return serialised_result + + +async def get_yolo_world_predictions_from_remote_api( + image: List[dict], + class_names: List[str], + model_version: Optional[str], + confidence: Optional[float], + step: YoloWorld, + api_key: Optional[str], +) -> List[dict]: + api_url = resolve_model_api_url(step=step) + client = InferenceHTTPClient( + api_url=api_url, + api_key=api_key, + ) + configuration = InferenceConfiguration( + max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, + ) + client.configure(inference_configuration=configuration) + if WORKFLOWS_REMOTE_API_TARGET == "hosted": + client.select_api_v0() + image_batches = list( + make_batches( + iterable=image, + batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, + ) + ) + serialised_result = [] + for single_batch in image_batches: + batch_results = await client.infer_from_yolo_world_async( + inference_input=[i["value"] for i in single_batch], + class_names=class_names, + model_version=model_version, + confidence=confidence, + ) + serialised_result.extend(batch_results) + return serialised_result + + +async def run_ocr_model_step( + step: OCRModel, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, +) -> Tuple[NextStepReference, OutputsLookup]: + image = get_image( + step=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + if step_execution_mode is StepExecutionMode.LOCAL: + serialised_result = await get_ocr_predictions_locally( + image=image, + model_manager=model_manager, + api_key=api_key, + ) + else: + serialised_result = await get_ocr_predictions_from_remote_api( + step=step, + image=image, + api_key=api_key, + ) + serialised_result = attach_parent_info( + image=image, + results=serialised_result, + nested_key=None, + ) + serialised_result = attach_prediction_type_info( + results=serialised_result, + prediction_type="ocr", + ) + outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result + return None, outputs_lookup + + +async def get_ocr_predictions_locally( + image: List[dict], + model_manager: ModelManager, + api_key: Optional[str], +) -> List[dict]: + serialised_result = [] + for single_image in image: + inference_request = DoctrOCRInferenceRequest( + image=single_image, + ) + doctr_model_id = load_core_model( + model_manager=model_manager, + inference_request=inference_request, + core_model="doctr", + api_key=api_key, + ) + result = await model_manager.infer_from_request( + doctr_model_id, inference_request + ) + serialised_result.append(result.dict()) + return serialised_result + + +async def get_ocr_predictions_from_remote_api( + step: OCRModel, + image: List[dict], + api_key: Optional[str], +) -> List[dict]: + api_url = resolve_model_api_url(step=step) + client = InferenceHTTPClient( + api_url=api_url, + api_key=api_key, + ) + if WORKFLOWS_REMOTE_API_TARGET == "hosted": + client.select_api_v0() + configuration = InferenceConfiguration( + max_batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_BATCH_SIZE, + max_concurrent_requests=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, + ) + client.configure(configuration) + result = await client.ocr_image_async( + inference_input=[i["value"] for i in image], + ) + if len(image) == 1: + return [result] + return result + + +async def run_clip_comparison_step( + step: ClipComparison, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, +) -> Tuple[NextStepReference, OutputsLookup]: + image = get_image( + step=step, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + text = resolve_parameter( + selector_or_value=step.text, + runtime_parameters=runtime_parameters, + outputs_lookup=outputs_lookup, + ) + if step_execution_mode is StepExecutionMode.LOCAL: + serialised_result = await get_clip_comparison_locally( + image=image, + text=text, + model_manager=model_manager, + api_key=api_key, + ) + else: + serialised_result = await get_clip_comparison_from_remote_api( + step=step, + image=image, + text=text, + api_key=api_key, + ) + serialised_result = attach_parent_info( + image=image, + results=serialised_result, + nested_key=None, + ) + serialised_result = attach_prediction_type_info( + results=serialised_result, + prediction_type="embeddings-comparison", + ) + outputs_lookup[construct_step_selector(step_name=step.name)] = serialised_result + return None, outputs_lookup + + +async def get_clip_comparison_locally( + image: List[dict], + text: str, + model_manager: ModelManager, + api_key: Optional[str], +) -> List[dict]: + serialised_result = [] + for single_image in image: + inference_request = ClipCompareRequest( + subject=single_image, subject_type="image", prompt=text, prompt_type="text" + ) + doctr_model_id = load_core_model( + model_manager=model_manager, + inference_request=inference_request, + core_model="clip", + api_key=api_key, + ) + result = await model_manager.infer_from_request( + doctr_model_id, inference_request + ) + serialised_result.append(result.dict()) + return serialised_result + + +async def get_clip_comparison_from_remote_api( + step: ClipComparison, + image: List[dict], + text: str, + api_key: Optional[str], +) -> List[dict]: + api_url = resolve_model_api_url(step=step) + client = InferenceHTTPClient( + api_url=api_url, + api_key=api_key, + ) + if WORKFLOWS_REMOTE_API_TARGET == "hosted": + client.select_api_v0() + image_batches = list( + make_batches( + iterable=image, + batch_size=WORKFLOWS_REMOTE_EXECUTION_MAX_STEP_CONCURRENT_REQUESTS, + ) + ) + serialised_result = [] + for single_batch in image_batches: + coroutines = [] + for single_image in single_batch: + coroutine = client.clip_compare_async( + subject=single_image["value"], + prompt=text, + ) + coroutines.append(coroutine) + batch_results = list(await asyncio.gather(*coroutines)) + serialised_result.extend(batch_results) + return serialised_result + + +def load_core_model( + model_manager: ModelManager, + inference_request: Union[DoctrOCRInferenceRequest, ClipCompareRequest], + core_model: str, + api_key: Optional[str] = None, +) -> str: + if api_key: + inference_request.api_key = api_key + version_id_field = f"{core_model}_version_id" + core_model_id = ( + f"{core_model}/{inference_request.__getattribute__(version_id_field)}" + ) + model_manager.add_model(core_model_id, inference_request.api_key) + return core_model_id + + +def attach_prediction_type_info( + results: List[Dict[str, Any]], + prediction_type: str, + key: str = "prediction_type", +) -> List[Dict[str, Any]]: + for result in results: + result[key] = prediction_type + return results + + +def attach_parent_info( + image: List[Dict[str, Any]], + results: List[Dict[str, Any]], + nested_key: Optional[str] = "predictions", +) -> List[Dict[str, Any]]: + return [ + attach_parent_info_to_image_detections( + image=i, predictions=p, nested_key=nested_key + ) + for i, p in zip(image, results) + ] + + +def attach_parent_info_to_image_detections( + image: Dict[str, Any], + predictions: Dict[str, Any], + nested_key: Optional[str], +) -> Dict[str, Any]: + predictions["parent_id"] = image["parent_id"] + if nested_key is None: + return predictions + for prediction in predictions[nested_key]: + prediction["parent_id"] = image["parent_id"] + return predictions + + +def anchor_detections_in_parent_coordinates( + image: List[Dict[str, Any]], + serialised_result: List[Dict[str, Any]], + image_metadata_key: str = "image", + detections_key: str = "predictions", +) -> List[Dict[str, Any]]: + return [ + anchor_image_detections_in_parent_coordinates( + image=i, + serialised_result=d, + image_metadata_key=image_metadata_key, + detections_key=detections_key, + ) + for i, d in zip(image, serialised_result) + ] + + +def anchor_image_detections_in_parent_coordinates( + image: Dict[str, Any], + serialised_result: Dict[str, Any], + image_metadata_key: str = "image", + detections_key: str = "predictions", +) -> Dict[str, Any]: + serialised_result[f"{detections_key}{PARENT_COORDINATES_SUFFIX}"] = deepcopy( + serialised_result[detections_key] + ) + serialised_result[f"{image_metadata_key}{PARENT_COORDINATES_SUFFIX}"] = deepcopy( + serialised_result[image_metadata_key] + ) + if ORIGIN_COORDINATES_KEY not in image: + return serialised_result + shift_x, shift_y = ( + image[ORIGIN_COORDINATES_KEY][CENTER_X_KEY], + image[ORIGIN_COORDINATES_KEY][CENTER_Y_KEY], + ) + for detection in serialised_result[f"{detections_key}{PARENT_COORDINATES_SUFFIX}"]: + detection["x"] += shift_x + detection["y"] += shift_y + serialised_result[f"{image_metadata_key}{PARENT_COORDINATES_SUFFIX}"] = image[ + ORIGIN_COORDINATES_KEY + ][ORIGIN_SIZE_KEY] + return serialised_result + + +ROBOFLOW_MODEL2HOSTED_ENDPOINT = { + "ClassificationModel": HOSTED_CLASSIFICATION_URL, + "MultiLabelClassificationModel": HOSTED_CLASSIFICATION_URL, + "ObjectDetectionModel": HOSTED_DETECT_URL, + "KeypointsDetectionModel": HOSTED_DETECT_URL, + "InstanceSegmentationModel": HOSTED_INSTANCE_SEGMENTATION_URL, + "OCRModel": HOSTED_CORE_MODEL_URL, + "ClipComparison": HOSTED_CORE_MODEL_URL, +} + + +def resolve_model_api_url(step: StepInterface) -> str: + if WORKFLOWS_REMOTE_API_TARGET != "hosted": + return LOCAL_INFERENCE_API_URL + return ROBOFLOW_MODEL2HOSTED_ENDPOINT[step.get_type()] diff --git a/inference/enterprise/workflows/complier/steps_executors/types.py b/inference/enterprise/workflows/complier/steps_executors/types.py new file mode 100644 index 0000000000000000000000000000000000000000..879d8c24978b31cb18ae6127352ee9158dae835b --- /dev/null +++ b/inference/enterprise/workflows/complier/steps_executors/types.py @@ -0,0 +1,4 @@ +from typing import Any, Dict, Optional + +NextStepReference = Optional[str] +OutputsLookup = Dict[str, Any] diff --git a/inference/enterprise/workflows/complier/steps_executors/utils.py b/inference/enterprise/workflows/complier/steps_executors/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5edbffc9977ed5d8ccf25a370f9af49a0e3bbe8a --- /dev/null +++ b/inference/enterprise/workflows/complier/steps_executors/utils.py @@ -0,0 +1,87 @@ +from typing import Any, Dict, Generator, Iterable, List, TypeVar, Union + +import numpy as np + +from inference.enterprise.workflows.complier.steps_executors.types import OutputsLookup +from inference.enterprise.workflows.complier.utils import ( + get_step_selector_from_its_output, + is_input_selector, + is_step_output_selector, +) +from inference.enterprise.workflows.entities.steps import ( + AbsoluteStaticCrop, + ActiveLearningDataCollector, + ClipComparison, + Crop, + OCRModel, + RelativeStaticCrop, + RoboflowModel, + YoloWorld, +) +from inference.enterprise.workflows.entities.validators import ( + get_last_selector_chunk, + is_selector, +) +from inference.enterprise.workflows.errors import ExecutionGraphError + +T = TypeVar("T") + + +def get_image( + step: Union[ + RoboflowModel, + OCRModel, + Crop, + AbsoluteStaticCrop, + RelativeStaticCrop, + ClipComparison, + ActiveLearningDataCollector, + YoloWorld, + ], + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, +) -> List[Dict[str, Union[str, np.ndarray]]]: + if is_input_selector(selector_or_value=step.image): + return runtime_parameters[get_last_selector_chunk(selector=step.image)] + if is_step_output_selector(selector_or_value=step.image): + step_selector = get_step_selector_from_its_output( + step_output_selector=step.image + ) + step_output = outputs_lookup[step_selector] + return step_output[get_last_selector_chunk(selector=step.image)] + raise ExecutionGraphError("Cannot find image") + + +def resolve_parameter( + selector_or_value: Any, + runtime_parameters: Dict[str, Any], + outputs_lookup: OutputsLookup, +) -> Any: + if not is_selector(selector_or_value=selector_or_value): + return selector_or_value + if is_step_output_selector(selector_or_value=selector_or_value): + step_selector = get_step_selector_from_its_output( + step_output_selector=selector_or_value + ) + step_output = outputs_lookup[step_selector] + if issubclass(type(step_output), list): + return [ + e[get_last_selector_chunk(selector=selector_or_value)] + for e in step_output + ] + return step_output[get_last_selector_chunk(selector=selector_or_value)] + return runtime_parameters[get_last_selector_chunk(selector=selector_or_value)] + + +def make_batches( + iterable: Iterable[T], batch_size: int +) -> Generator[List[T], None, None]: + batch_size = max(batch_size, 1) + batch = [] + for element in iterable: + batch.append(element) + if len(batch) >= batch_size: + yield batch + batch = [] + if len(batch) > 0: + yield batch diff --git a/inference/enterprise/workflows/complier/utils.py b/inference/enterprise/workflows/complier/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2f3eabd99dfe5e4cdf31ef9df9fa716c3d11691e --- /dev/null +++ b/inference/enterprise/workflows/complier/utils.py @@ -0,0 +1,106 @@ +from typing import Any, List, Set + +from networkx import DiGraph + +from inference.enterprise.workflows.entities.outputs import JsonField +from inference.enterprise.workflows.entities.validators import is_selector +from inference.enterprise.workflows.entities.workflows_specification import ( + InputType, + StepType, +) + + +def get_input_parameters_selectors(inputs: List[InputType]) -> Set[str]: + return { + construct_input_selector(input_name=input_definition.name) + for input_definition in inputs + } + + +def construct_input_selector(input_name: str) -> str: + return f"$inputs.{input_name}" + + +def get_steps_selectors(steps: List[StepType]) -> Set[str]: + return {construct_step_selector(step_name=step.name) for step in steps} + + +def construct_step_selector(step_name: str) -> str: + return f"$steps.{step_name}" + + +def get_steps_input_selectors(steps: List[StepType]) -> Set[str]: + result = set() + for step in steps: + result.update(get_step_input_selectors(step=step)) + return result + + +def get_step_input_selectors(step: StepType) -> Set[str]: + result = set() + for step_input_name in step.get_input_names(): + step_input = getattr(step, step_input_name) + if not issubclass(type(step_input), list): + step_input = [step_input] + for element in step_input: + if not is_selector(selector_or_value=element): + continue + result.add(element) + return result + + +def get_steps_output_selectors(steps: List[StepType]) -> Set[str]: + result = set() + for step in steps: + for output_name in step.get_output_names(): + result.add(f"$steps.{step.name}.{output_name}") + return result + + +def get_output_names(outputs: List[JsonField]) -> Set[str]: + return {construct_output_name(name=output.name) for output in outputs} + + +def construct_output_name(name: str) -> str: + return f"$outputs.{name}" + + +def get_output_selectors(outputs: List[JsonField]) -> Set[str]: + return {output.selector for output in outputs} + + +def is_input_selector(selector_or_value: Any) -> bool: + if not is_selector(selector_or_value=selector_or_value): + return False + return selector_or_value.startswith("$inputs") + + +def construct_selector_pointing_step_output(selector: str, new_output: str) -> str: + if is_step_output_selector(selector_or_value=selector): + selector = get_step_selector_from_its_output(step_output_selector=selector) + return f"{selector}.{new_output}" + + +def is_step_output_selector(selector_or_value: Any) -> bool: + if not is_selector(selector_or_value=selector_or_value): + return False + return ( + selector_or_value.startswith("$steps.") + and len(selector_or_value.split(".")) == 3 + ) + + +def get_step_selector_from_its_output(step_output_selector: str) -> str: + return ".".join(step_output_selector.split(".")[:2]) + + +def get_nodes_of_specific_kind(execution_graph: DiGraph, kind: str) -> Set[str]: + return { + node[0] + for node in execution_graph.nodes(data=True) + if node[1].get("kind") == kind + } + + +def is_condition_step(execution_graph: DiGraph, node: str) -> bool: + return execution_graph.nodes[node]["definition"].type == "Condition" diff --git a/inference/enterprise/workflows/complier/validator.py b/inference/enterprise/workflows/complier/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..4575928076a5f0e4f252ac7288924700f622ef41 --- /dev/null +++ b/inference/enterprise/workflows/complier/validator.py @@ -0,0 +1,75 @@ +from typing import List + +from inference.enterprise.workflows.complier.utils import ( + get_input_parameters_selectors, + get_output_names, + get_output_selectors, + get_steps_input_selectors, + get_steps_output_selectors, + get_steps_selectors, +) +from inference.enterprise.workflows.entities.outputs import JsonField +from inference.enterprise.workflows.entities.workflows_specification import ( + InputType, + StepType, + WorkflowSpecificationV1, +) +from inference.enterprise.workflows.errors import ( + DuplicatedSymbolError, + InvalidReferenceError, +) + + +def validate_workflow_specification( + workflow_specification: WorkflowSpecificationV1, +) -> None: + validate_inputs_names_are_unique(inputs=workflow_specification.inputs) + validate_steps_names_are_unique(steps=workflow_specification.steps) + validate_outputs_names_are_unique(outputs=workflow_specification.outputs) + validate_selectors_references_correctness( + workflow_specification=workflow_specification + ) + + +def validate_inputs_names_are_unique(inputs: List[InputType]) -> None: + input_parameters_selectors = get_input_parameters_selectors(inputs=inputs) + if len(input_parameters_selectors) != len(inputs): + raise DuplicatedSymbolError("Found duplicated input parameter names") + + +def validate_steps_names_are_unique(steps: List[StepType]) -> None: + steps_selectors = get_steps_selectors(steps=steps) + if len(steps_selectors) != len(steps): + raise DuplicatedSymbolError("Found duplicated steps names") + + +def validate_outputs_names_are_unique(outputs: List[JsonField]) -> None: + output_names = get_output_names(outputs=outputs) + if len(output_names) != len(outputs): + raise DuplicatedSymbolError("Found duplicated outputs names") + + +def validate_selectors_references_correctness( + workflow_specification: WorkflowSpecificationV1, +) -> None: + input_parameters_selectors = get_input_parameters_selectors( + inputs=workflow_specification.inputs + ) + steps_inputs_selectors = get_steps_input_selectors( + steps=workflow_specification.steps + ) + steps_output_selectors = get_steps_output_selectors( + steps=workflow_specification.steps + ) + output_selectors = get_output_selectors(outputs=workflow_specification.outputs) + all_possible_input_selectors = input_parameters_selectors | steps_output_selectors + for step_input_selector in steps_inputs_selectors: + if step_input_selector not in all_possible_input_selectors: + raise InvalidReferenceError( + f"Detected step input selector: {step_input_selector} that is not defined as valid input." + ) + for output_selector in output_selectors: + if output_selector not in steps_output_selectors: + raise InvalidReferenceError( + f"Detected output selector: {output_selector} that is not defined as valid output of any of the steps." + ) diff --git a/inference/enterprise/workflows/constants.py b/inference/enterprise/workflows/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..985635e9866dca5dc643e861231b04746dcf546a --- /dev/null +++ b/inference/enterprise/workflows/constants.py @@ -0,0 +1,3 @@ +INPUT_NODE_KIND = "INPUT_NODE" +STEP_NODE_KIND = "STEP_NODE" +OUTPUT_NODE_KIND = "OUTPUT_NODE" diff --git a/inference/enterprise/workflows/entities/__init__.py b/inference/enterprise/workflows/entities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/inference/enterprise/workflows/entities/__pycache__/__init__.cpython-310.pyc b/inference/enterprise/workflows/entities/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4e0fe97802992f4059c54cbeb021413494324e8 Binary files /dev/null and b/inference/enterprise/workflows/entities/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/entities/__pycache__/base.cpython-310.pyc b/inference/enterprise/workflows/entities/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fed4770435307e8a758abcfefb0e500097accac Binary files /dev/null and b/inference/enterprise/workflows/entities/__pycache__/base.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/entities/__pycache__/inputs.cpython-310.pyc b/inference/enterprise/workflows/entities/__pycache__/inputs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..923835b219460c53267f2424342f67838eb09f0e Binary files /dev/null and b/inference/enterprise/workflows/entities/__pycache__/inputs.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/entities/__pycache__/outputs.cpython-310.pyc b/inference/enterprise/workflows/entities/__pycache__/outputs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a1a69593e8c392a2577861ce4912ec798fc893b Binary files /dev/null and b/inference/enterprise/workflows/entities/__pycache__/outputs.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/entities/__pycache__/steps.cpython-310.pyc b/inference/enterprise/workflows/entities/__pycache__/steps.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6784371b23f4a2e543b8d7d7687272391087de7b Binary files /dev/null and b/inference/enterprise/workflows/entities/__pycache__/steps.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/entities/__pycache__/validators.cpython-310.pyc b/inference/enterprise/workflows/entities/__pycache__/validators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a270aec23d934f2f24aeb5538c75e39b2b7f2074 Binary files /dev/null and b/inference/enterprise/workflows/entities/__pycache__/validators.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/entities/__pycache__/workflows_specification.cpython-310.pyc b/inference/enterprise/workflows/entities/__pycache__/workflows_specification.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..622fcd26209ac04179bf717730636e7d489c4824 Binary files /dev/null and b/inference/enterprise/workflows/entities/__pycache__/workflows_specification.cpython-310.pyc differ diff --git a/inference/enterprise/workflows/entities/base.py b/inference/enterprise/workflows/entities/base.py new file mode 100644 index 0000000000000000000000000000000000000000..15454ebb4401ba15a8197123135c875ff7944002 --- /dev/null +++ b/inference/enterprise/workflows/entities/base.py @@ -0,0 +1,7 @@ +from abc import ABC, abstractmethod + + +class GraphNone(ABC): + @abstractmethod + def get_type(self) -> str: + pass diff --git a/inference/enterprise/workflows/entities/inputs.py b/inference/enterprise/workflows/entities/inputs.py new file mode 100644 index 0000000000000000000000000000000000000000..ffc39d93087d015a0679c38992d663efeb8080a7 --- /dev/null +++ b/inference/enterprise/workflows/entities/inputs.py @@ -0,0 +1,24 @@ +from typing import Literal, Optional, Union + +from pydantic import BaseModel, Field + +from inference.enterprise.workflows.entities.base import GraphNone + + +class InferenceImage(BaseModel, GraphNone): + type: Literal["InferenceImage"] + name: str + + def get_type(self) -> str: + return self.type + + +class InferenceParameter(BaseModel, GraphNone): + type: Literal["InferenceParameter"] + name: str + default_value: Optional[Union[float, int, str, bool, list, set]] = Field( + default=None + ) + + def get_type(self) -> str: + return self.type diff --git a/inference/enterprise/workflows/entities/outputs.py b/inference/enterprise/workflows/entities/outputs.py new file mode 100644 index 0000000000000000000000000000000000000000..15915638a1187dd577277f91267604dd0fe75cb0 --- /dev/null +++ b/inference/enterprise/workflows/entities/outputs.py @@ -0,0 +1,21 @@ +from enum import Enum +from typing import Literal + +from pydantic import BaseModel, Field + +from inference.enterprise.workflows.entities.base import GraphNone + + +class CoordinatesSystem(Enum): + OWN = "own" + PARENT = "parent" + + +class JsonField(BaseModel, GraphNone): + type: Literal["JsonField"] + name: str + selector: str + coordinates_system: CoordinatesSystem = Field(default=CoordinatesSystem.PARENT) + + def get_type(self) -> str: + return self.type diff --git a/inference/enterprise/workflows/entities/steps.py b/inference/enterprise/workflows/entities/steps.py new file mode 100644 index 0000000000000000000000000000000000000000..828a01f215409e3f2043121e34d2bdc85a1ae505 --- /dev/null +++ b/inference/enterprise/workflows/entities/steps.py @@ -0,0 +1,1481 @@ +from abc import ABCMeta, abstractmethod +from enum import Enum +from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Tuple, Union + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + NonNegativeInt, + PositiveInt, + confloat, + field_validator, +) + +from inference.enterprise.workflows.entities.base import GraphNone +from inference.enterprise.workflows.entities.validators import ( + get_last_selector_chunk, + is_selector, + validate_field_has_given_type, + validate_field_is_empty_or_selector_or_list_of_string, + validate_field_is_in_range_zero_one_or_empty_or_selector, + validate_field_is_list_of_selectors, + validate_field_is_list_of_string, + validate_field_is_one_of_selected_values, + validate_field_is_selector_or_has_given_type, + validate_field_is_selector_or_one_of_values, + validate_image_biding, + validate_image_is_valid_selector, + validate_selector_holds_detections, + validate_selector_holds_image, + validate_selector_is_inference_parameter, + validate_value_is_empty_or_number_in_range_zero_one, + validate_value_is_empty_or_positive_number, + validate_value_is_empty_or_selector_or_positive_number, +) +from inference.enterprise.workflows.errors import ( + ExecutionGraphError, + InvalidStepInputDetected, + VariableTypeError, +) + + +class StepInterface(GraphNone, metaclass=ABCMeta): + @abstractmethod + def get_input_names(self) -> Set[str]: + """ + Supposed to give the name of all fields expected to represent inputs + """ + pass + + @abstractmethod + def get_output_names(self) -> Set[str]: + """ + Supposed to give the name of all fields expected to represent outputs to be referred by other steps + """ + + @abstractmethod + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + """ + Supposed to validate the type of input is referred + """ + pass + + @abstractmethod + def validate_field_binding(self, field_name: str, value: Any) -> None: + """ + Supposed to validate the type of value that is to be bounded with field as a result of graph + execution (values passed by client to invocation, as well as constructed during graph execution) + """ + pass + + +class RoboflowModel(BaseModel, StepInterface, metaclass=ABCMeta): + model_config = ConfigDict(protected_namespaces=()) + type: Literal["RoboflowModel"] + name: str + image: Union[str, List[str]] + model_id: str + disable_active_learning: Union[Optional[bool], str] = Field(default=False) + + @field_validator("image") + @classmethod + def validate_image(cls, value: Any) -> Union[str, List[str]]: + validate_image_is_valid_selector(value=value) + return value + + @field_validator("model_id") + @classmethod + def model_id_must_be_selector_or_str(cls, value: Any) -> str: + validate_field_is_selector_or_has_given_type( + value=value, field_name="model_id", allowed_types=[str] + ) + return value + + @field_validator("disable_active_learning") + @classmethod + def disable_active_learning_must_be_selector_or_bool( + cls, value: Any + ) -> Union[Optional[bool], str]: + validate_field_is_selector_or_has_given_type( + field_name="disable_active_learning", + allowed_types=[type(None), bool], + value=value, + ) + return value + + def get_type(self) -> str: + return self.type + + def get_input_names(self) -> Set[str]: + return {"image", "model_id", "disable_active_learning"} + + def get_output_names(self) -> Set[str]: + return {"prediction_type"} + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + if not is_selector(selector_or_value=getattr(self, field_name)): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + validate_selector_holds_image( + step_type=self.type, + field_name=field_name, + input_step=input_step, + ) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={"model_id", "disable_active_learning"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + if field_name == "image": + validate_image_biding(value=value) + elif field_name == "model_id": + validate_field_has_given_type( + field_name=field_name, + allowed_types=[str], + value=value, + error=VariableTypeError, + ) + elif field_name == "disable_active_learning": + validate_field_has_given_type( + field_name=field_name, + allowed_types=[bool], + value=value, + error=VariableTypeError, + ) + + +class ClassificationModel(RoboflowModel): + type: Literal["ClassificationModel"] + confidence: Union[Optional[float], str] = Field(default=0.4) + + @field_validator("confidence") + @classmethod + def confidence_must_be_selector_or_number( + cls, value: Any + ) -> Union[Optional[float], str]: + validate_field_is_in_range_zero_one_or_empty_or_selector(value=value) + return value + + def get_input_names(self) -> Set[str]: + inputs = super().get_input_names() + inputs.add("confidence") + return inputs + + def get_output_names(self) -> Set[str]: + outputs = super().get_output_names() + outputs.update(["predictions", "top", "confidence", "parent_id"]) + return outputs + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + super().validate_field_selector(field_name=field_name, input_step=input_step) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={"confidence"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + super().validate_field_binding(field_name=field_name, value=value) + if field_name == "confidence": + if value is None: + raise VariableTypeError("Parameter `confidence` cannot be None") + validate_value_is_empty_or_number_in_range_zero_one( + value=value, error=VariableTypeError + ) + + +class MultiLabelClassificationModel(RoboflowModel): + type: Literal["MultiLabelClassificationModel"] + confidence: Union[Optional[float], str] = Field(default=0.4) + + @field_validator("confidence") + @classmethod + def confidence_must_be_selector_or_number( + cls, value: Any + ) -> Union[Optional[float], str]: + validate_field_is_in_range_zero_one_or_empty_or_selector(value=value) + return value + + def get_input_names(self) -> Set[str]: + inputs = super().get_input_names() + inputs.add("confidence") + return inputs + + def get_output_names(self) -> Set[str]: + outputs = super().get_output_names() + outputs.update(["predictions", "predicted_classes", "parent_id"]) + return outputs + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + super().validate_field_selector(field_name=field_name, input_step=input_step) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={"confidence"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + super().validate_field_binding(field_name=field_name, value=value) + if field_name == "confidence": + if value is None: + raise VariableTypeError("Parameter `confidence` cannot be None") + validate_value_is_empty_or_number_in_range_zero_one( + value=value, error=VariableTypeError + ) + + +class ObjectDetectionModel(RoboflowModel): + type: Literal["ObjectDetectionModel"] + class_agnostic_nms: Union[Optional[bool], str] = Field(default=False) + class_filter: Union[Optional[List[str]], str] = Field(default=None) + confidence: Union[Optional[float], str] = Field(default=0.4) + iou_threshold: Union[Optional[float], str] = Field(default=0.3) + max_detections: Union[Optional[int], str] = Field(default=300) + max_candidates: Union[Optional[int], str] = Field(default=3000) + + @field_validator("class_agnostic_nms") + @classmethod + def class_agnostic_nms_must_be_selector_or_bool( + cls, value: Any + ) -> Union[Optional[bool], str]: + validate_field_is_selector_or_has_given_type( + field_name="class_agnostic_nms", + allowed_types=[type(None), bool], + value=value, + ) + return value + + @field_validator("class_filter") + @classmethod + def class_filter_must_be_selector_or_list_of_string( + cls, value: Any + ) -> Union[Optional[List[str]], str]: + validate_field_is_empty_or_selector_or_list_of_string( + value=value, field_name="class_filter" + ) + return value + + @field_validator("confidence", "iou_threshold") + @classmethod + def field_must_be_selector_or_number_from_zero_to_one( + cls, value: Any + ) -> Union[Optional[float], str]: + validate_field_is_in_range_zero_one_or_empty_or_selector( + value=value, field_name="confidence | iou_threshold" + ) + return value + + @field_validator("max_detections", "max_candidates") + @classmethod + def field_must_be_selector_or_positive_number( + cls, value: Any + ) -> Union[Optional[int], str]: + validate_value_is_empty_or_selector_or_positive_number( + value=value, + field_name="max_detections | max_candidates", + ) + return value + + def get_input_names(self) -> Set[str]: + inputs = super().get_input_names() + inputs.update( + [ + "class_agnostic_nms", + "class_filter", + "confidence", + "iou_threshold", + "max_detections", + "max_candidates", + ] + ) + return inputs + + def get_output_names(self) -> Set[str]: + outputs = super().get_output_names() + outputs.update(["predictions", "parent_id", "image"]) + return outputs + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + super().validate_field_selector(field_name=field_name, input_step=input_step) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={ + "class_agnostic_nms", + "class_filter", + "confidence", + "iou_threshold", + "max_detections", + "max_candidates", + }, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + super().validate_field_binding(field_name=field_name, value=value) + if value is None: + raise VariableTypeError(f"Parameter `{field_name}` cannot be None") + if field_name == "class_agnostic_nms": + validate_field_has_given_type( + field_name=field_name, + allowed_types=[bool], + value=value, + error=VariableTypeError, + ) + elif field_name == "class_filter": + if value is None: + return None + validate_field_is_list_of_string( + value=value, field_name=field_name, error=VariableTypeError + ) + elif field_name == "confidence" or field_name == "iou_threshold": + validate_value_is_empty_or_number_in_range_zero_one( + value=value, + field_name=field_name, + error=VariableTypeError, + ) + elif field_name == "max_detections" or field_name == "max_candidates": + validate_value_is_empty_or_positive_number( + value=value, + field_name=field_name, + error=VariableTypeError, + ) + + +class KeypointsDetectionModel(ObjectDetectionModel): + type: Literal["KeypointsDetectionModel"] + keypoint_confidence: Union[Optional[float], str] = Field(default=0.0) + + @field_validator("keypoint_confidence") + @classmethod + def keypoint_confidence_field_must_be_selector_or_number_from_zero_to_one( + cls, value: Any + ) -> Union[Optional[float], str]: + validate_field_is_in_range_zero_one_or_empty_or_selector( + value=value, field_name="keypoint_confidence" + ) + return value + + def get_input_names(self) -> Set[str]: + inputs = super().get_input_names() + inputs.add("keypoint_confidence") + return inputs + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + super().validate_field_selector(field_name=field_name, input_step=input_step) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={"keypoint_confidence"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + super().validate_field_binding(field_name=field_name, value=value) + if field_name == "keypoint_confidence": + validate_value_is_empty_or_number_in_range_zero_one( + value=value, + field_name=field_name, + error=VariableTypeError, + ) + + +DECODE_MODES = {"accurate", "tradeoff", "fast"} + + +class InstanceSegmentationModel(ObjectDetectionModel): + type: Literal["InstanceSegmentationModel"] + mask_decode_mode: Optional[str] = Field(default="accurate") + tradeoff_factor: Union[Optional[float], str] = Field(default=0.0) + + @field_validator("mask_decode_mode") + @classmethod + def mask_decode_mode_must_be_selector_or_one_of_allowed_values( + cls, value: Any + ) -> Optional[str]: + validate_field_is_selector_or_one_of_values( + value=value, + field_name="mask_decode_mode", + selected_values=DECODE_MODES, + ) + return value + + @field_validator("tradeoff_factor") + @classmethod + def field_must_be_selector_or_number_from_zero_to_one( + cls, value: Any + ) -> Union[Optional[float], str]: + validate_field_is_in_range_zero_one_or_empty_or_selector( + value=value, field_name="tradeoff_factor" + ) + return value + + def get_input_names(self) -> Set[str]: + inputs = super().get_input_names() + inputs.update(["mask_decode_mode", "tradeoff_factor"]) + return inputs + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + super().validate_field_selector(field_name=field_name, input_step=input_step) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={"mask_decode_mode", "tradeoff_factor"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + super().validate_field_binding(field_name=field_name, value=value) + if field_name == "mask_decode_mode": + validate_field_is_one_of_selected_values( + value=value, + field_name=field_name, + selected_values=DECODE_MODES, + error=VariableTypeError, + ) + elif field_name == "tradeoff_factor": + validate_value_is_empty_or_number_in_range_zero_one( + value=value, + field_name=field_name, + error=VariableTypeError, + ) + + +class OCRModel(BaseModel, StepInterface): + type: Literal["OCRModel"] + name: str + image: Union[str, List[str]] + + @field_validator("image") + @classmethod + def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: + validate_image_is_valid_selector(value=value) + return value + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + if not is_selector(selector_or_value=getattr(self, field_name)): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + validate_selector_holds_image( + step_type=self.type, + field_name=field_name, + input_step=input_step, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + if field_name == "image": + validate_image_biding(value=value) + + def get_type(self) -> str: + return self.type + + def get_input_names(self) -> Set[str]: + return {"image"} + + def get_output_names(self) -> Set[str]: + return {"result", "parent_id", "prediction_type"} + + +class Crop(BaseModel, StepInterface): + type: Literal["Crop"] + name: str + image: Union[str, List[str]] + detections: str + + @field_validator("image") + @classmethod + def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: + validate_image_is_valid_selector(value=value) + return value + + @field_validator("detections") + @classmethod + def detections_must_hold_selector(cls, value: Any) -> str: + if not is_selector(selector_or_value=value): + raise ValueError("`detections` field can only contain selector values") + return value + + def get_type(self) -> str: + return self.type + + def get_input_names(self) -> Set[str]: + return {"image", "detections"} + + def get_output_names(self) -> Set[str]: + return {"crops", "parent_id"} + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + if not is_selector(selector_or_value=getattr(self, field_name)): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + validate_selector_holds_image( + step_type=self.type, + field_name=field_name, + input_step=input_step, + ) + validate_selector_holds_detections( + step_name=self.name, + image_selector=self.image, + detections_selector=self.detections, + field_name=field_name, + input_step=input_step, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + if field_name == "image": + validate_image_biding(value=value) + + +class Operator(Enum): + EQUAL = "equal" + NOT_EQUAL = "not_equal" + LOWER_THAN = "lower_than" + GREATER_THAN = "greater_than" + LOWER_OR_EQUAL_THAN = "lower_or_equal_than" + GREATER_OR_EQUAL_THAN = "greater_or_equal_than" + IN = "in" + + +class Condition(BaseModel, StepInterface): + type: Literal["Condition"] + name: str + left: Union[float, int, bool, str, list, set] + operator: Operator + right: Union[float, int, bool, str, list, set] + step_if_true: str + step_if_false: str + + def get_type(self) -> str: + return self.type + + def get_input_names(self) -> Set[str]: + return {"left", "right"} + + def get_output_names(self) -> Set[str]: + return set() + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + if not is_selector(selector_or_value=getattr(self, field_name)): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + input_type = input_step.get_type() + if field_name in {"left", "right"}: + if input_type == "InferenceImage": + raise InvalidStepInputDetected( + f"Field {field_name} of step {self.type} comes from invalid input type: {input_type}. " + f"Expected: anything else than `InferenceImage`" + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + pass + + +class BinaryOperator(Enum): + OR = "or" + AND = "and" + + +class DetectionFilterDefinition(BaseModel): + type: Literal["DetectionFilterDefinition"] + field_name: str + operator: Operator + reference_value: Union[float, int, bool, str, list, set] + + +class CompoundDetectionFilterDefinition(BaseModel): + type: Literal["CompoundDetectionFilterDefinition"] + left: DetectionFilterDefinition + operator: BinaryOperator + right: DetectionFilterDefinition + + +class DetectionFilter(BaseModel, StepInterface): + type: Literal["DetectionFilter"] + name: str + predictions: str + filter_definition: Annotated[ + Union[DetectionFilterDefinition, CompoundDetectionFilterDefinition], + Field(discriminator="type"), + ] + + def get_input_names(self) -> Set[str]: + return {"predictions"} + + def get_output_names(self) -> Set[str]: + return {"predictions", "parent_id", "image", "prediction_type"} + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + if not is_selector(selector_or_value=getattr(self, field_name)): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + validate_selector_holds_detections( + step_name=self.name, + image_selector=None, + detections_selector=self.predictions, + field_name=field_name, + input_step=input_step, + applicable_fields={"predictions"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + pass + + def get_type(self) -> str: + return self.type + + +class DetectionOffset(BaseModel, StepInterface): + type: Literal["DetectionOffset"] + name: str + predictions: str + offset_x: Union[int, str] + offset_y: Union[int, str] + + def get_input_names(self) -> Set[str]: + return {"predictions", "offset_x", "offset_y"} + + def get_output_names(self) -> Set[str]: + return {"predictions", "parent_id", "image", "prediction_type"} + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + if not is_selector(selector_or_value=getattr(self, field_name)): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + validate_selector_holds_detections( + step_name=self.name, + image_selector=None, + detections_selector=self.predictions, + field_name=field_name, + input_step=input_step, + applicable_fields={"predictions"}, + ) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={"offset_x", "offset_y"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + if field_name in {"offset_x", "offset_y"}: + validate_field_has_given_type( + field_name=field_name, + value=value, + allowed_types=[int], + error=VariableTypeError, + ) + + def get_type(self) -> str: + return self.type + + +class AbsoluteStaticCrop(BaseModel, StepInterface): + type: Literal["AbsoluteStaticCrop"] + name: str + image: Union[str, List[str]] + x_center: Union[int, str] + y_center: Union[int, str] + width: Union[int, str] + height: Union[int, str] + + @field_validator("image") + @classmethod + def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: + validate_image_is_valid_selector(value=value) + return value + + @field_validator("x_center", "y_center", "width", "height") + @classmethod + def validate_crops_coordinates(cls, value: Any) -> str: + validate_value_is_empty_or_selector_or_positive_number( + value=value, field_name="x_center | y_center | width | height" + ) + return value + + def get_type(self) -> str: + return self.type + + def get_input_names(self) -> Set[str]: + return {"image", "x_center", "y_center", "width", "height"} + + def get_output_names(self) -> Set[str]: + return {"crops", "parent_id"} + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + if not is_selector(selector_or_value=getattr(self, field_name)): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + validate_selector_holds_image( + step_type=self.type, + field_name=field_name, + input_step=input_step, + ) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={"x_center", "y_center", "width", "height"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + if field_name == "image": + validate_image_biding(value=value) + if field_name in {"x_center", "y_center", "width", "height"}: + if ( + not issubclass(type(value), int) and not issubclass(type(value), float) + ) or value != round(value): + raise VariableTypeError( + f"Field {field_name} of step {self.type} must be integer" + ) + + +class RelativeStaticCrop(BaseModel, StepInterface): + type: Literal["RelativeStaticCrop"] + name: str + image: Union[str, List[str]] + x_center: Union[float, str] + y_center: Union[float, str] + width: Union[float, str] + height: Union[float, str] + + @field_validator("image") + @classmethod + def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: + validate_image_is_valid_selector(value=value) + return value + + @field_validator("x_center", "y_center", "width", "height") + @classmethod + def detections_must_hold_selector(cls, value: Any) -> str: + if issubclass(type(value), str): + if not is_selector(selector_or_value=value): + raise ValueError("Field must be either float of valid selector") + elif not issubclass(type(value), float): + raise ValueError("Field must be either float of valid selector") + return value + + def get_type(self) -> str: + return self.type + + def get_input_names(self) -> Set[str]: + return {"image", "x_center", "y_center", "width", "height"} + + def get_output_names(self) -> Set[str]: + return {"crops", "parent_id"} + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + if not is_selector(selector_or_value=getattr(self, field_name)): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + validate_selector_holds_image( + step_type=self.type, + field_name=field_name, + input_step=input_step, + ) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={"x_center", "y_center", "width", "height"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + if field_name == "image": + validate_image_biding(value=value) + if field_name in {"x_center", "y_center", "width", "height"}: + validate_field_has_given_type( + field_name=field_name, + value=value, + allowed_types=[float], + error=VariableTypeError, + ) + + +class ClipComparison(BaseModel, StepInterface): + type: Literal["ClipComparison"] + name: str + image: Union[str, List[str]] + text: Union[str, List[str]] + + @field_validator("image") + @classmethod + def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: + validate_image_is_valid_selector(value=value) + return value + + @field_validator("text") + @classmethod + def text_must_be_valid(cls, value: Any) -> Union[str, List[str]]: + if is_selector(selector_or_value=value): + return value + if issubclass(type(value), list): + validate_field_is_list_of_string(value=value, field_name="text") + elif not issubclass(type(value), str): + raise ValueError("`text` field given must be string or list of strings") + return value + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + if not is_selector(selector_or_value=getattr(self, field_name)): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + validate_selector_holds_image( + step_type=self.type, + field_name=field_name, + input_step=input_step, + ) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={"text"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + if field_name == "image": + validate_image_biding(value=value) + if field_name == "text": + if issubclass(type(value), list): + validate_field_is_list_of_string( + value=value, field_name=field_name, error=VariableTypeError + ) + elif not issubclass(type(value), str): + validate_field_has_given_type( + value=value, + field_name=field_name, + allowed_types=[str], + error=VariableTypeError, + ) + + def get_type(self) -> str: + return self.type + + def get_input_names(self) -> Set[str]: + return {"image", "text"} + + def get_output_names(self) -> Set[str]: + return {"similarity", "parent_id", "predictions_type"} + + +class AggregationMode(Enum): + AVERAGE = "average" + MAX = "max" + MIN = "min" + + +class DetectionsConsensus(BaseModel, StepInterface): + type: Literal["DetectionsConsensus"] + name: str + predictions: List[str] + required_votes: Union[int, str] + class_aware: Union[bool, str] = Field(default=True) + iou_threshold: Union[float, str] = Field(default=0.3) + confidence: Union[float, str] = Field(default=0.0) + classes_to_consider: Optional[Union[List[str], str]] = Field(default=None) + required_objects: Optional[Union[int, Dict[str, int], str]] = Field(default=None) + presence_confidence_aggregation: AggregationMode = Field( + default=AggregationMode.MAX + ) + detections_merge_confidence_aggregation: AggregationMode = Field( + default=AggregationMode.AVERAGE + ) + detections_merge_coordinates_aggregation: AggregationMode = Field( + default=AggregationMode.AVERAGE + ) + + @field_validator("predictions") + @classmethod + def predictions_must_be_list_of_selectors(cls, value: Any) -> List[str]: + validate_field_is_list_of_selectors(value=value, field_name="predictions") + if len(value) < 1: + raise ValueError( + "There must be at least 1 `predictions` selectors in consensus step" + ) + return value + + @field_validator("required_votes") + @classmethod + def required_votes_must_be_selector_or_positive_integer( + cls, value: Any + ) -> Union[str, int]: + if value is None: + raise ValueError("Field `required_votes` is required.") + validate_value_is_empty_or_selector_or_positive_number( + value=value, field_name="required_votes" + ) + return value + + @field_validator("class_aware") + @classmethod + def class_aware_must_be_selector_or_boolean(cls, value: Any) -> Union[str, bool]: + validate_field_is_selector_or_has_given_type( + value=value, field_name="class_aware", allowed_types=[bool] + ) + return value + + @field_validator("iou_threshold", "confidence") + @classmethod + def field_must_be_selector_or_number_from_zero_to_one( + cls, value: Any + ) -> Union[str, float]: + if value is None: + raise ValueError("Fields `iou_threshold` and `confidence` cannot be None") + validate_field_is_in_range_zero_one_or_empty_or_selector( + value=value, field_name="iou_threshold | confidence" + ) + return value + + @field_validator("classes_to_consider") + @classmethod + def classes_to_consider_must_be_empty_or_selector_or_list_of_strings( + cls, value: Any + ) -> Optional[Union[str, List[str]]]: + validate_field_is_empty_or_selector_or_list_of_string( + value=value, field_name="classes_to_consider" + ) + return value + + @field_validator("required_objects") + @classmethod + def required_objects_field_must_be_valid( + cls, value: Any + ) -> Optional[Union[str, int, Dict[str, int]]]: + if value is None: + return value + validate_field_is_selector_or_has_given_type( + value=value, field_name="required_objects", allowed_types=[int, dict] + ) + if issubclass(type(value), int): + validate_value_is_empty_or_positive_number( + value=value, field_name="required_objects" + ) + return value + elif issubclass(type(value), dict): + for k, v in value.items(): + if v is None: + raise ValueError(f"Field `required_objects[{k}]` must not be None.") + validate_value_is_empty_or_positive_number( + value=v, field_name=f"required_objects[{k}]" + ) + return value + + def get_input_names(self) -> Set[str]: + return { + "predictions", + "required_votes", + "class_aware", + "iou_threshold", + "confidence", + "classes_to_consider", + "required_objects", + } + + def get_output_names(self) -> Set[str]: + return { + "parent_id", + "predictions", + "image", + "object_present", + "presence_confidence", + "predictions_type", + } + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + if field_name != "predictions" and not is_selector( + selector_or_value=getattr(self, field_name) + ): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + if field_name == "predictions": + if index is None or index > len(self.predictions): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, which requires multiple inputs, " + f"but `index` not provided." + ) + if not is_selector( + selector_or_value=self.predictions[index], + ): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}[{index}], but field is not selector." + ) + validate_selector_holds_detections( + step_name=self.name, + image_selector=None, + detections_selector=self.predictions[index], + field_name=field_name, + input_step=input_step, + applicable_fields={"predictions"}, + ) + return None + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={ + "required_votes", + "class_aware", + "iou_threshold", + "confidence", + "classes_to_consider", + "required_objects", + }, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + if field_name == "required_votes": + if value is None: + raise VariableTypeError("Field `required_votes` cannot be None.") + validate_value_is_empty_or_positive_number( + value=value, field_name="required_votes", error=VariableTypeError + ) + elif field_name == "class_aware": + validate_field_has_given_type( + field_name=field_name, + allowed_types=[bool], + value=value, + error=VariableTypeError, + ) + elif field_name in {"iou_threshold", "confidence"}: + if value is None: + raise VariableTypeError(f"Fields `{field_name}` cannot be None.") + validate_value_is_empty_or_number_in_range_zero_one( + value=value, + field_name=field_name, + error=VariableTypeError, + ) + elif field_name == "classes_to_consider": + if value is None: + return None + validate_field_is_list_of_string( + value=value, + field_name=field_name, + error=VariableTypeError, + ) + elif field_name == "required_objects": + self._validate_required_objects_binding(value=value) + return None + + def get_type(self) -> str: + return self.type + + def _validate_required_objects_binding(self, value: Any) -> None: + if value is None: + return value + validate_field_has_given_type( + value=value, + field_name="required_objects", + allowed_types=[int, dict], + error=VariableTypeError, + ) + if issubclass(type(value), int): + validate_value_is_empty_or_positive_number( + value=value, + field_name="required_objects", + error=VariableTypeError, + ) + return None + for k, v in value.items(): + if v is None: + raise VariableTypeError( + f"Field `required_objects[{k}]` must not be None." + ) + validate_value_is_empty_or_positive_number( + value=v, + field_name=f"required_objects[{k}]", + error=VariableTypeError, + ) + + +ACTIVE_LEARNING_DATA_COLLECTOR_ELIGIBLE_SELECTORS = { + "ObjectDetectionModel": "predictions", + "KeypointsDetectionModel": "predictions", + "InstanceSegmentationModel": "predictions", + "DetectionFilter": "predictions", + "DetectionsConsensus": "predictions", + "DetectionOffset": "predictions", + "YoloWorld": "predictions", + "ClassificationModel": "top", +} + + +class DisabledActiveLearningConfiguration(BaseModel): + enabled: bool + + @field_validator("enabled") + @classmethod + def ensure_only_false_is_valid(cls, value: Any) -> bool: + if value is not False: + raise ValueError( + "One can only specify enabled=False in `DisabledActiveLearningConfiguration`" + ) + return value + + +class LimitDefinition(BaseModel): + type: Literal["minutely", "hourly", "daily"] + value: PositiveInt + + +class RandomSamplingConfig(BaseModel): + type: Literal["random"] + name: str + traffic_percentage: confloat(ge=0.0, le=1.0) + tags: List[str] = Field(default_factory=lambda: []) + limits: List[LimitDefinition] = Field(default_factory=lambda: []) + + +class CloseToThresholdSampling(BaseModel): + type: Literal["close_to_threshold"] + name: str + probability: confloat(ge=0.0, le=1.0) + threshold: confloat(ge=0.0, le=1.0) + epsilon: confloat(ge=0.0, le=1.0) + max_batch_images: Optional[int] = Field(default=None) + only_top_classes: bool = Field(default=True) + minimum_objects_close_to_threshold: int = Field(default=1) + selected_class_names: Optional[List[str]] = Field(default=None) + tags: List[str] = Field(default_factory=lambda: []) + limits: List[LimitDefinition] = Field(default_factory=lambda: []) + + +class ClassesBasedSampling(BaseModel): + type: Literal["classes_based"] + name: str + probability: confloat(ge=0.0, le=1.0) + selected_class_names: List[str] + tags: List[str] = Field(default_factory=lambda: []) + limits: List[LimitDefinition] = Field(default_factory=lambda: []) + + +class DetectionsBasedSampling(BaseModel): + type: Literal["detections_number_based"] + name: str + probability: confloat(ge=0.0, le=1.0) + more_than: Optional[NonNegativeInt] + less_than: Optional[NonNegativeInt] + selected_class_names: Optional[List[str]] = Field(default=None) + tags: List[str] = Field(default_factory=lambda: []) + limits: List[LimitDefinition] = Field(default_factory=lambda: []) + + +class ActiveLearningBatchingStrategy(BaseModel): + batches_name_prefix: str + recreation_interval: Literal["never", "daily", "weekly", "monthly"] + max_batch_images: Optional[int] = Field(default=None) + + +ActiveLearningStrategyType = Annotated[ + Union[ + RandomSamplingConfig, + CloseToThresholdSampling, + ClassesBasedSampling, + DetectionsBasedSampling, + ], + Field(discriminator="type"), +] + + +class EnabledActiveLearningConfiguration(BaseModel): + enabled: bool + persist_predictions: bool + sampling_strategies: List[ActiveLearningStrategyType] + batching_strategy: ActiveLearningBatchingStrategy + tags: List[str] = Field(default_factory=lambda: []) + max_image_size: Optional[Tuple[PositiveInt, PositiveInt]] = Field(default=None) + jpeg_compression_level: int = Field(default=95) + + @field_validator("jpeg_compression_level") + @classmethod + def validate_json_compression_level(cls, value: Any) -> int: + validate_field_has_given_type( + field_name="jpeg_compression_level", allowed_types=[int], value=value + ) + if value <= 0 or value > 100: + raise ValueError("`jpeg_compression_level` must be in range [1, 100]") + return value + + +class ActiveLearningDataCollector(BaseModel, StepInterface): + type: Literal["ActiveLearningDataCollector"] + name: str + image: str + predictions: str + target_dataset: str + target_dataset_api_key: Optional[str] = Field(default=None) + disable_active_learning: Union[bool, str] = Field(default=False) + active_learning_configuration: Optional[ + Union[EnabledActiveLearningConfiguration, DisabledActiveLearningConfiguration] + ] = Field(default=None) + + @field_validator("image") + @classmethod + def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: + validate_image_is_valid_selector(value=value) + return value + + @field_validator("predictions") + @classmethod + def predictions_must_hold_selector(cls, value: Any) -> str: + if not is_selector(selector_or_value=value): + raise ValueError("`predictions` field can only contain selector values") + return value + + @field_validator("target_dataset") + @classmethod + def validate_target_dataset_field(cls, value: Any) -> str: + validate_field_is_selector_or_has_given_type( + value=value, field_name="target_dataset", allowed_types=[str] + ) + return value + + @field_validator("target_dataset_api_key") + @classmethod + def validate_target_dataset_api_key_field(cls, value: Any) -> Union[str, bool]: + validate_field_is_selector_or_has_given_type( + value=value, + field_name="target_dataset_api_key", + allowed_types=[bool, type(None)], + ) + return value + + @field_validator("disable_active_learning") + @classmethod + def validate_boolean_flags_or_selectors(cls, value: Any) -> Union[str, bool]: + validate_field_is_selector_or_has_given_type( + value=value, field_name="disable_active_learning", allowed_types=[bool] + ) + return value + + def get_type(self) -> str: + return self.type + + def get_input_names(self) -> Set[str]: + return { + "image", + "predictions", + "target_dataset", + "target_dataset_api_key", + "disable_active_learning", + } + + def get_output_names(self) -> Set[str]: + return set() + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + selector = getattr(self, field_name) + if not is_selector(selector_or_value=selector): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + if field_name == "predictions": + input_step_type = input_step.get_type() + expected_last_selector_chunk = ( + ACTIVE_LEARNING_DATA_COLLECTOR_ELIGIBLE_SELECTORS.get(input_step_type) + ) + if expected_last_selector_chunk is None: + raise ExecutionGraphError( + f"Attempted to validate predictions selector of {self.name} step, but input step of type: " + f"{input_step_type} does match by type." + ) + if get_last_selector_chunk(selector) != expected_last_selector_chunk: + raise ExecutionGraphError( + f"It is only allowed to refer to {input_step_type} step output named {expected_last_selector_chunk}. " + f"Reference that was found: {selector}" + ) + input_step_image = getattr(input_step, "image", self.image) + if input_step_image != self.image: + raise ExecutionGraphError( + f"ActiveLearningDataCollector step refers to input step that uses reference to different image. " + f"ActiveLearningDataCollector step image: {self.image}. Input step (of type {input_step_image}) " + f"uses {input_step_image}." + ) + validate_selector_holds_image( + step_type=self.type, + field_name=field_name, + input_step=input_step, + ) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={ + "target_dataset", + "target_dataset_api_key", + "disable_active_learning", + }, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + if field_name == "image": + validate_image_biding(value=value) + elif field_name in {"disable_active_learning"}: + validate_field_has_given_type( + field_name=field_name, + allowed_types=[bool], + value=value, + error=VariableTypeError, + ) + elif field_name in {"target_dataset"}: + validate_field_has_given_type( + field_name=field_name, + allowed_types=[str], + value=value, + error=VariableTypeError, + ) + elif field_name in {"target_dataset_api_key"}: + validate_field_has_given_type( + field_name=field_name, + allowed_types=[str], + value=value, + error=VariableTypeError, + ) + + +class YoloWorld(BaseModel, StepInterface): + type: Literal["YoloWorld"] + name: str + image: str + class_names: Union[str, List[str]] + version: Optional[str] = Field(default="l") + confidence: Union[Optional[float], str] = Field(default=0.4) + + @field_validator("image") + @classmethod + def image_must_only_hold_selectors(cls, value: Any) -> Union[str, List[str]]: + validate_image_is_valid_selector(value=value) + return value + + @field_validator("class_names") + @classmethod + def validate_class_names(cls, value: Any) -> Union[str, List[str]]: + if is_selector(selector_or_value=value): + return value + if issubclass(type(value), list): + validate_field_is_list_of_string(value=value, field_name="class_names") + return value + raise ValueError( + "`class_names` field given must be selector or list of strings" + ) + + @field_validator("version") + @classmethod + def validate_model_version(cls, value: Any) -> Optional[str]: + validate_field_is_selector_or_one_of_values( + value=value, + selected_values={None, "s", "m", "l"}, + field_name="version", + ) + return value + + @field_validator("confidence") + @classmethod + def field_must_be_selector_or_number_from_zero_to_one( + cls, value: Any + ) -> Union[Optional[float], str]: + if value is None: + return None + validate_field_is_in_range_zero_one_or_empty_or_selector( + value=value, field_name="confidence" + ) + return value + + def get_input_names(self) -> Set[str]: + return {"image", "class_names", "version", "confidence"} + + def get_output_names(self) -> Set[str]: + return {"predictions", "parent_id", "image", "prediction_type"} + + def validate_field_selector( + self, field_name: str, input_step: GraphNone, index: Optional[int] = None + ) -> None: + selector = getattr(self, field_name) + if not is_selector(selector_or_value=selector): + raise ExecutionGraphError( + f"Attempted to validate selector value for field {field_name}, but field is not selector." + ) + validate_selector_holds_image( + step_type=self.type, + field_name=field_name, + input_step=input_step, + ) + validate_selector_is_inference_parameter( + step_type=self.type, + field_name=field_name, + input_step=input_step, + applicable_fields={"class_names", "version", "confidence"}, + ) + + def validate_field_binding(self, field_name: str, value: Any) -> None: + if field_name == "image": + validate_image_biding(value=value) + elif field_name == "class_names": + validate_field_is_list_of_string( + value=value, + field_name=field_name, + error=VariableTypeError, + ) + elif field_name == "version": + validate_field_is_one_of_selected_values( + value=value, + field_name=field_name, + selected_values={None, "s", "m", "l"}, + error=VariableTypeError, + ) + elif field_name == "confidence": + validate_value_is_empty_or_number_in_range_zero_one( + value=value, + field_name=field_name, + error=VariableTypeError, + ) + + def get_type(self) -> str: + return self.type diff --git a/inference/enterprise/workflows/entities/validators.py b/inference/enterprise/workflows/entities/validators.py new file mode 100644 index 0000000000000000000000000000000000000000..c080f6b7bb8f6d84a33ed2ba29a73e11ce559f82 --- /dev/null +++ b/inference/enterprise/workflows/entities/validators.py @@ -0,0 +1,239 @@ +from typing import Any, List, Optional, Set, Type + +from pydantic import ValidationError + +from inference.core.entities.requests.inference import InferenceRequestImage +from inference.enterprise.workflows.entities.base import GraphNone +from inference.enterprise.workflows.errors import ( + InvalidStepInputDetected, + VariableTypeError, +) + +STEPS_WITH_IMAGE = { + "InferenceImage", + "Crop", + "AbsoluteStaticCrop", + "RelativeStaticCrop", +} + + +def validate_image_is_valid_selector(value: Any, field_name: str = "image") -> None: + if issubclass(type(value), list): + if any(not is_selector(selector_or_value=e) for e in value): + raise ValueError(f"`{field_name}` field can only contain selector values") + elif not is_selector(selector_or_value=value): + raise ValueError(f"`{field_name}` field can only contain selector values") + + +def validate_field_is_in_range_zero_one_or_empty_or_selector( + value: Any, field_name: str = "confidence" +) -> None: + if is_selector(selector_or_value=value) or value is None: + return None + validate_value_is_empty_or_number_in_range_zero_one( + value=value, field_name=field_name + ) + + +def validate_value_is_empty_or_number_in_range_zero_one( + value: Any, field_name: str = "confidence", error: Type[Exception] = ValueError +) -> None: + validate_field_has_given_type( + field_name=field_name, + allowed_types=[type(None), int, float], + value=value, + error=error, + ) + if value is None: + return None + if not (0 <= value <= 1): + raise error(f"Parameter `{field_name}` must be in range [0.0, 1.0]") + + +def validate_value_is_empty_or_selector_or_positive_number( + value: Any, field_name: str +) -> None: + if is_selector(selector_or_value=value): + return None + validate_value_is_empty_or_positive_number(value=value, field_name=field_name) + + +def validate_value_is_empty_or_positive_number( + value: Any, field_name: str, error: Type[Exception] = ValueError +) -> None: + validate_field_has_given_type( + field_name=field_name, + allowed_types=[type(None), int, float], + value=value, + error=error, + ) + if value is None: + return None + if value <= 0: + raise error(f"Parameter `{field_name}` must be positive (> 0)") + + +def validate_field_is_list_of_selectors( + value: Any, field_name: str, error: Type[Exception] = ValueError +) -> None: + if not issubclass(type(value), list): + raise error(f"`{field_name}` field must be list") + if any(not is_selector(selector_or_value=e) for e in value): + raise error(f"Parameter `{field_name}` must be a list of selectors") + + +def validate_field_is_empty_or_selector_or_list_of_string( + value: Any, field_name: str +) -> None: + if is_selector(selector_or_value=value) or value is None: + return value + validate_field_is_list_of_string(value=value, field_name=field_name) + + +def validate_field_is_list_of_string( + value: Any, field_name: str, error: Type[Exception] = ValueError +) -> None: + if not issubclass(type(value), list): + raise error(f"`{field_name}` field must be list") + if any(not issubclass(type(e), str) for e in value): + raise error(f"Parameter `{field_name}` must be a list of string") + + +def validate_field_is_selector_or_one_of_values( + value: Any, field_name: str, selected_values: set +) -> None: + if is_selector(selector_or_value=value) or value is None: + return value + validate_field_is_one_of_selected_values( + value=value, field_name=field_name, selected_values=selected_values + ) + + +def validate_field_is_one_of_selected_values( + value: Any, + field_name: str, + selected_values: set, + error: Type[Exception] = ValueError, +) -> None: + if value not in selected_values: + raise error( + f"Value of field `{field_name}` must be in {selected_values}. Found: {value}" + ) + + +def validate_field_is_selector_or_has_given_type( + value: Any, field_name: str, allowed_types: List[type] +) -> None: + if is_selector(selector_or_value=value): + return None + validate_field_has_given_type( + field_name=field_name, allowed_types=allowed_types, value=value + ) + return None + + +def validate_field_has_given_type( + value: Any, + field_name: str, + allowed_types: List[type], + error: Type[Exception] = ValueError, +) -> None: + if all(not issubclass(type(value), allowed_type) for allowed_type in allowed_types): + raise error( + f"`{field_name}` field type must be one of {allowed_types}. Detected: {value}" + ) + + +def validate_image_biding(value: Any, field_name: str = "image") -> None: + try: + if not issubclass(type(value), list): + value = [value] + for e in value: + InferenceRequestImage.model_validate(e) + except (ValueError, ValidationError) as error: + raise VariableTypeError( + f"Parameter `{field_name}` must be compatible with `InferenceRequestImage`" + ) from error + + +def validate_selector_is_inference_parameter( + step_type: str, + field_name: str, + input_step: GraphNone, + applicable_fields: Set[str], +) -> None: + if field_name not in applicable_fields: + return None + input_step_type = input_step.get_type() + if input_step_type not in {"InferenceParameter"}: + raise InvalidStepInputDetected( + f"Field {field_name} of step {step_type} comes from invalid input type: {input_step_type}. " + f"Expected: `InferenceParameter`" + ) + + +def validate_selector_holds_image( + step_type: str, + field_name: str, + input_step: GraphNone, + applicable_fields: Optional[Set[str]] = None, +) -> None: + if applicable_fields is None: + applicable_fields = {"image"} + if field_name not in applicable_fields: + return None + if input_step.get_type() not in STEPS_WITH_IMAGE: + raise InvalidStepInputDetected( + f"Field {field_name} of step {step_type} comes from invalid input type: {input_step.get_type()}. " + f"Expected: {STEPS_WITH_IMAGE}" + ) + + +def validate_selector_holds_detections( + step_name: str, + image_selector: Optional[str], + detections_selector: str, + field_name: str, + input_step: GraphNone, + applicable_fields: Optional[Set[str]] = None, +) -> None: + if applicable_fields is None: + applicable_fields = {"detections"} + if field_name not in applicable_fields: + return None + if input_step.get_type() not in { + "ObjectDetectionModel", + "KeypointsDetectionModel", + "InstanceSegmentationModel", + "DetectionFilter", + "DetectionsConsensus", + "DetectionOffset", + "YoloWorld", + }: + raise InvalidStepInputDetected( + f"Step step with name {step_name} cannot take as an input predictions from {input_step.get_type()}. " + f"Step requires detection-based output." + ) + if get_last_selector_chunk(detections_selector) != "predictions": + raise InvalidStepInputDetected( + f"Step with name {step_name} must take as input step output of name `predictions`" + ) + if not hasattr(input_step, "image") or image_selector is None: + # Here, filter do not hold the reference to image, we skip the check in this case + return None + input_step_image_reference = input_step.image + if image_selector != input_step_image_reference: + raise InvalidStepInputDetected( + f"Step step with name {step_name} was given detections reference that is bound to different image: " + f"step.image: {image_selector}, detections step image: {input_step_image_reference}" + ) + + +def is_selector(selector_or_value: Any) -> bool: + if not issubclass(type(selector_or_value), str): + return False + return selector_or_value.startswith("$") + + +def get_last_selector_chunk(selector: str) -> str: + return selector.split(".")[-1] diff --git a/inference/enterprise/workflows/entities/workflows_specification.py b/inference/enterprise/workflows/entities/workflows_specification.py new file mode 100644 index 0000000000000000000000000000000000000000..2dd6d894d1b4c44ea6c4a981963fed429bb383e9 --- /dev/null +++ b/inference/enterprise/workflows/entities/workflows_specification.py @@ -0,0 +1,65 @@ +from typing import Annotated, List, Literal, Union + +from pydantic import BaseModel, Field + +from inference.enterprise.workflows.entities.inputs import ( + InferenceImage, + InferenceParameter, +) +from inference.enterprise.workflows.entities.outputs import JsonField +from inference.enterprise.workflows.entities.steps import ( + AbsoluteStaticCrop, + ActiveLearningDataCollector, + ClassificationModel, + ClipComparison, + Condition, + Crop, + DetectionFilter, + DetectionOffset, + DetectionsConsensus, + InstanceSegmentationModel, + KeypointsDetectionModel, + MultiLabelClassificationModel, + ObjectDetectionModel, + OCRModel, + RelativeStaticCrop, + YoloWorld, +) + +InputType = Annotated[ + Union[InferenceImage, InferenceParameter], Field(discriminator="type") +] +StepType = Annotated[ + Union[ + ClassificationModel, + MultiLabelClassificationModel, + ObjectDetectionModel, + KeypointsDetectionModel, + InstanceSegmentationModel, + OCRModel, + Crop, + Condition, + DetectionFilter, + DetectionOffset, + ClipComparison, + RelativeStaticCrop, + AbsoluteStaticCrop, + DetectionsConsensus, + ActiveLearningDataCollector, + YoloWorld, + ], + Field(discriminator="type"), +] + + +class WorkflowSpecificationV1(BaseModel): + version: Literal["1.0"] + inputs: List[InputType] + steps: List[StepType] + outputs: List[JsonField] + + +class WorkflowSpecification(BaseModel): + specification: ( + WorkflowSpecificationV1 # in the future - union with discriminator can be used + ) diff --git a/inference/enterprise/workflows/errors.py b/inference/enterprise/workflows/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..d1654445c9947caf8a1cd335601f9c18327644db --- /dev/null +++ b/inference/enterprise/workflows/errors.py @@ -0,0 +1,62 @@ +class WorkflowsCompilerError(Exception): + pass + + +class ValidationError(WorkflowsCompilerError): + pass + + +class InvalidSpecificationVersionError(ValidationError): + pass + + +class DuplicatedSymbolError(ValidationError): + pass + + +class InvalidReferenceError(ValidationError): + pass + + +class ExecutionGraphError(WorkflowsCompilerError): + pass + + +class SelectorToUndefinedNodeError(ExecutionGraphError): + pass + + +class NotAcyclicGraphError(ExecutionGraphError): + pass + + +class NodesNotReachingOutputError(ExecutionGraphError): + pass + + +class AmbiguousPathDetected(ExecutionGraphError): + pass + + +class InvalidStepInputDetected(ExecutionGraphError): + pass + + +class WorkflowsCompilerRuntimeError(WorkflowsCompilerError): + pass + + +class RuntimePayloadError(WorkflowsCompilerRuntimeError): + pass + + +class RuntimeParameterMissingError(RuntimePayloadError): + pass + + +class VariableTypeError(RuntimePayloadError): + pass + + +class ExecutionEngineError(WorkflowsCompilerRuntimeError): + pass diff --git a/inference/models/__init__.py b/inference/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..68001cf10c80fd363c4edca12dc37902f0af2f6f --- /dev/null +++ b/inference/models/__init__.py @@ -0,0 +1,46 @@ +try: + from inference.models.clip import Clip +except: + pass + +try: + from inference.models.gaze import Gaze +except: + pass + +try: + from inference.models.sam import SegmentAnything +except: + pass + +try: + from inference.models.doctr import DocTR +except: + pass + +try: + from inference.models.grounding_dino import GroundingDINO +except: + pass + +try: + from inference.models.cogvlm import CogVLM +except: + pass + +try: + from inference.models.yolo_world import YOLOWorld +except: + pass + +from inference.models.vit import VitClassification +from inference.models.yolact import YOLACT +from inference.models.yolonas import YOLONASObjectDetection +from inference.models.yolov5 import YOLOv5InstanceSegmentation, YOLOv5ObjectDetection +from inference.models.yolov7 import YOLOv7InstanceSegmentation +from inference.models.yolov8 import ( + YOLOv8Classification, + YOLOv8InstanceSegmentation, + YOLOv8KeypointsDetection, + YOLOv8ObjectDetection, +) diff --git a/inference/models/__pycache__/__init__.cpython-310.pyc b/inference/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d31e59730ec60734f41bbb88a627fb84c105518 Binary files /dev/null and b/inference/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/__pycache__/aliases.cpython-310.pyc b/inference/models/__pycache__/aliases.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40be9271c3753af5750d0d241f8121a1647028fc Binary files /dev/null and b/inference/models/__pycache__/aliases.cpython-310.pyc differ diff --git a/inference/models/__pycache__/utils.cpython-310.pyc b/inference/models/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7e6e490318c15a3092221d606fc0cadd166a3ac Binary files /dev/null and b/inference/models/__pycache__/utils.cpython-310.pyc differ diff --git a/inference/models/aliases.py b/inference/models/aliases.py new file mode 100644 index 0000000000000000000000000000000000000000..742ea3586a77d977afedc04a32eb6ce8a9d412fc --- /dev/null +++ b/inference/models/aliases.py @@ -0,0 +1,33 @@ +# We have a duplicate in inference_sdk.http.utils.aliases - please maintain both +# (to have aliasing work in both libraries) + + +REGISTERED_ALIASES = { + "yolov8n-640": "coco/3", + "yolov8n-1280": "coco/9", + "yolov8s-640": "coco/6", + "yolov8s-1280": "coco/10", + "yolov8m-640": "coco/8", + "yolov8m-1280": "coco/11", + "yolov8l-640": "coco/7", + "yolov8l-1280": "coco/12", + "yolov8x-640": "coco/5", + "yolov8x-1280": "coco/13", + "yolo-nas-s-640": "coco/14", + "yolo-nas-m-640": "coco/15", + "yolo-nas-l-640": "coco/16", + "yolov8n-seg-640": "coco-dataset-vdnr1/2", + "yolov8n-seg-1280": "coco-dataset-vdnr1/7", + "yolov8s-seg-640": "coco-dataset-vdnr1/4", + "yolov8s-seg-1280": "coco-dataset-vdnr1/8", + "yolov8m-seg-640": "coco-dataset-vdnr1/5", + "yolov8m-seg-1280": "coco-dataset-vdnr1/9", + "yolov8l-seg-640": "coco-dataset-vdnr1/6", + "yolov8l-seg-1280": "coco-dataset-vdnr1/10", + "yolov8x-seg-640": "coco-dataset-vdnr1/3", + "yolov8x-seg-1280": "coco-dataset-vdnr1/11", +} + + +def resolve_roboflow_model_alias(model_id: str) -> str: + return REGISTERED_ALIASES.get(model_id, model_id) diff --git a/inference/models/clip/__init__.py b/inference/models/clip/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..87e1f4adce984d8b30bbea749cc2875a1670c651 --- /dev/null +++ b/inference/models/clip/__init__.py @@ -0,0 +1 @@ +from inference.models.clip.clip_model import Clip diff --git a/inference/models/clip/__pycache__/__init__.cpython-310.pyc b/inference/models/clip/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38ebdc8d1a696e0f8041fb5aae5f32ba1f04fc7f Binary files /dev/null and b/inference/models/clip/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/clip/__pycache__/clip_model.cpython-310.pyc b/inference/models/clip/__pycache__/clip_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6c537ee8f96523efc257818e416021502013ae4 Binary files /dev/null and b/inference/models/clip/__pycache__/clip_model.cpython-310.pyc differ diff --git a/inference/models/clip/clip_model.py b/inference/models/clip/clip_model.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b42624875f3fd5e18124441773473f24089660 --- /dev/null +++ b/inference/models/clip/clip_model.py @@ -0,0 +1,371 @@ +from time import perf_counter +from typing import Any, Dict, List, Tuple, Union + +import clip +import numpy as np +import onnxruntime +from PIL import Image + +from inference.core.entities.requests.clip import ( + ClipCompareRequest, + ClipImageEmbeddingRequest, + ClipInferenceRequest, + ClipTextEmbeddingRequest, +) +from inference.core.entities.requests.inference import InferenceRequestImage +from inference.core.entities.responses.clip import ( + ClipCompareResponse, + ClipEmbeddingResponse, +) +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.env import ( + CLIP_MAX_BATCH_SIZE, + CLIP_MODEL_ID, + ONNXRUNTIME_EXECUTION_PROVIDERS, + REQUIRED_ONNX_PROVIDERS, + TENSORRT_CACHE_PATH, +) +from inference.core.exceptions import OnnxProviderNotAvailable +from inference.core.models.roboflow import OnnxRoboflowCoreModel +from inference.core.models.types import PreprocessReturnMetadata +from inference.core.utils.image_utils import load_image_rgb +from inference.core.utils.onnx import get_onnxruntime_execution_providers +from inference.core.utils.postprocess import cosine_similarity + + +class Clip(OnnxRoboflowCoreModel): + """Roboflow ONNX ClipModel model. + + This class is responsible for handling the ONNX ClipModel model, including + loading the model, preprocessing the input, and performing inference. + + Attributes: + visual_onnx_session (onnxruntime.InferenceSession): ONNX Runtime session for visual inference. + textual_onnx_session (onnxruntime.InferenceSession): ONNX Runtime session for textual inference. + resolution (int): The resolution of the input image. + clip_preprocess (function): Function to preprocess the image. + """ + + def __init__( + self, + *args, + model_id: str = CLIP_MODEL_ID, + onnxruntime_execution_providers: List[ + str + ] = get_onnxruntime_execution_providers(ONNXRUNTIME_EXECUTION_PROVIDERS), + **kwargs, + ): + """Initializes the Clip with the given arguments and keyword arguments.""" + self.onnxruntime_execution_providers = onnxruntime_execution_providers + t1 = perf_counter() + super().__init__(*args, model_id=model_id, **kwargs) + # Create an ONNX Runtime Session with a list of execution providers in priority order. ORT attempts to load providers until one is successful. This keeps the code across devices identical. + self.log("Creating inference sessions") + self.visual_onnx_session = onnxruntime.InferenceSession( + self.cache_file("visual.onnx"), + providers=self.onnxruntime_execution_providers, + ) + + self.textual_onnx_session = onnxruntime.InferenceSession( + self.cache_file("textual.onnx"), + providers=self.onnxruntime_execution_providers, + ) + + if REQUIRED_ONNX_PROVIDERS: + available_providers = onnxruntime.get_available_providers() + for provider in REQUIRED_ONNX_PROVIDERS: + if provider not in available_providers: + raise OnnxProviderNotAvailable( + f"Required ONNX Execution Provider {provider} is not availble. Check that you are using the correct docker image on a supported device." + ) + + self.resolution = self.visual_onnx_session.get_inputs()[0].shape[2] + + self.clip_preprocess = clip.clip._transform(self.resolution) + self.log(f"CLIP model loaded in {perf_counter() - t1:.2f} seconds") + self.task_type = "embedding" + + def compare( + self, + subject: Any, + prompt: Any, + subject_type: str = "image", + prompt_type: Union[str, List[str], Dict[str, Any]] = "text", + **kwargs, + ) -> Union[List[float], Dict[str, float]]: + """ + Compares the subject with the prompt to calculate similarity scores. + + Args: + subject (Any): The subject data to be compared. Can be either an image or text. + prompt (Any): The prompt data to be compared against the subject. Can be a single value (image/text), list of values, or dictionary of values. + subject_type (str, optional): Specifies the type of the subject data. Must be either "image" or "text". Defaults to "image". + prompt_type (Union[str, List[str], Dict[str, Any]], optional): Specifies the type of the prompt data. Can be "image", "text", list of these types, or a dictionary containing these types. Defaults to "text". + **kwargs: Additional keyword arguments. + + Returns: + Union[List[float], Dict[str, float]]: A list or dictionary containing cosine similarity scores between the subject and prompt(s). If prompt is a dictionary, returns a dictionary with keys corresponding to the original prompt dictionary's keys. + + Raises: + ValueError: If subject_type or prompt_type is neither "image" nor "text". + ValueError: If the number of prompts exceeds the maximum batch size. + """ + + if subject_type == "image": + subject_embeddings = self.embed_image(subject) + elif subject_type == "text": + subject_embeddings = self.embed_text(subject) + else: + raise ValueError( + "subject_type must be either 'image' or 'text', but got {request.subject_type}" + ) + + if isinstance(prompt, dict) and not ("type" in prompt and "value" in prompt): + prompt_keys = prompt.keys() + prompt = [prompt[k] for k in prompt_keys] + prompt_obj = "dict" + else: + prompt = prompt + if not isinstance(prompt, list): + prompt = [prompt] + prompt_obj = "list" + + if len(prompt) > CLIP_MAX_BATCH_SIZE: + raise ValueError( + f"The maximum number of prompts that can be compared at once is {CLIP_MAX_BATCH_SIZE}" + ) + + if prompt_type == "image": + prompt_embeddings = self.embed_image(prompt) + elif prompt_type == "text": + prompt_embeddings = self.embed_text(prompt) + else: + raise ValueError( + "prompt_type must be either 'image' or 'text', but got {request.prompt_type}" + ) + + similarities = [ + cosine_similarity(subject_embeddings, p) for p in prompt_embeddings + ] + + if prompt_obj == "dict": + similarities = dict(zip(prompt_keys, similarities)) + + return similarities + + def make_compare_response( + self, similarities: Union[List[float], Dict[str, float]] + ) -> ClipCompareResponse: + """ + Creates a ClipCompareResponse object from the provided similarity data. + + Args: + similarities (Union[List[float], Dict[str, float]]): A list or dictionary containing similarity scores. + + Returns: + ClipCompareResponse: An instance of the ClipCompareResponse with the given similarity scores. + + Example: + Assuming `ClipCompareResponse` expects a dictionary of string-float pairs: + + >>> make_compare_response({"image1": 0.98, "image2": 0.76}) + ClipCompareResponse(similarity={"image1": 0.98, "image2": 0.76}) + """ + response = ClipCompareResponse(similarity=similarities) + return response + + def embed_image( + self, + image: Any, + **kwargs, + ) -> np.ndarray: + """ + Embeds an image or a list of images using the Clip model. + + Args: + image (Any): The image or list of images to be embedded. Image can be in any format that is acceptable by the preproc_image method. + **kwargs: Additional keyword arguments. + + Returns: + np.ndarray: The embeddings of the image(s) as a numpy array. + + Raises: + ValueError: If the number of images in the list exceeds the maximum batch size. + + Notes: + The function measures performance using perf_counter and also has support for ONNX session to get embeddings. + """ + t1 = perf_counter() + + if isinstance(image, list): + if len(image) > CLIP_MAX_BATCH_SIZE: + raise ValueError( + f"The maximum number of images that can be embedded at once is {CLIP_MAX_BATCH_SIZE}" + ) + imgs = [self.preproc_image(i) for i in image] + img_in = np.concatenate(imgs, axis=0) + else: + img_in = self.preproc_image(image) + + onnx_input_image = {self.visual_onnx_session.get_inputs()[0].name: img_in} + embeddings = self.visual_onnx_session.run(None, onnx_input_image)[0] + + return embeddings + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]: + onnx_input_image = {self.visual_onnx_session.get_inputs()[0].name: img_in} + embeddings = self.visual_onnx_session.run(None, onnx_input_image)[0] + return (embeddings,) + + def make_embed_image_response( + self, embeddings: np.ndarray + ) -> ClipEmbeddingResponse: + """ + Converts the given embeddings into a ClipEmbeddingResponse object. + + Args: + embeddings (np.ndarray): A numpy array containing the embeddings for an image or images. + + Returns: + ClipEmbeddingResponse: An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list. + + Example: + >>> embeddings_array = np.array([[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]]) + >>> make_embed_image_response(embeddings_array) + ClipEmbeddingResponse(embeddings=[[0.5, 0.3, 0.2], [0.1, 0.9, 0.0]]) + """ + response = ClipEmbeddingResponse(embeddings=embeddings.tolist()) + + return response + + def embed_text( + self, + text: Union[str, List[str]], + **kwargs, + ) -> np.ndarray: + """ + Embeds a text or a list of texts using the Clip model. + + Args: + text (Union[str, List[str]]): The text string or list of text strings to be embedded. + **kwargs: Additional keyword arguments. + + Returns: + np.ndarray: The embeddings of the text or texts as a numpy array. + + Raises: + ValueError: If the number of text strings in the list exceeds the maximum batch size. + + Notes: + The function utilizes an ONNX session to compute embeddings and measures the embedding time with perf_counter. + """ + t1 = perf_counter() + + if isinstance(text, list): + if len(text) > CLIP_MAX_BATCH_SIZE: + raise ValueError( + f"The maximum number of text strings that can be embedded at once is {CLIP_MAX_BATCH_SIZE}" + ) + + texts = text + else: + texts = [text] + + texts = clip.tokenize(texts).numpy().astype(np.int32) + + onnx_input_text = {self.textual_onnx_session.get_inputs()[0].name: texts} + embeddings = self.textual_onnx_session.run(None, onnx_input_text)[0] + + return embeddings + + def make_embed_text_response(self, embeddings: np.ndarray) -> ClipEmbeddingResponse: + """ + Converts the given text embeddings into a ClipEmbeddingResponse object. + + Args: + embeddings (np.ndarray): A numpy array containing the embeddings for a text or texts. + + Returns: + ClipEmbeddingResponse: An instance of the ClipEmbeddingResponse with the provided embeddings converted to a list. + + Example: + >>> embeddings_array = np.array([[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]]) + >>> make_embed_text_response(embeddings_array) + ClipEmbeddingResponse(embeddings=[[0.8, 0.1, 0.1], [0.4, 0.5, 0.1]]) + """ + response = ClipEmbeddingResponse(embeddings=embeddings.tolist()) + return response + + def get_infer_bucket_file_list(self) -> List[str]: + """Gets the list of files required for inference. + + Returns: + List[str]: The list of file names. + """ + return ["textual.onnx", "visual.onnx"] + + def infer_from_request( + self, request: ClipInferenceRequest + ) -> ClipEmbeddingResponse: + """Routes the request to the appropriate inference function. + + Args: + request (ClipInferenceRequest): The request object containing the inference details. + + Returns: + ClipEmbeddingResponse: The response object containing the embeddings. + """ + t1 = perf_counter() + if isinstance(request, ClipImageEmbeddingRequest): + infer_func = self.embed_image + make_response_func = self.make_embed_image_response + elif isinstance(request, ClipTextEmbeddingRequest): + infer_func = self.embed_text + make_response_func = self.make_embed_text_response + elif isinstance(request, ClipCompareRequest): + infer_func = self.compare + make_response_func = self.make_compare_response + else: + raise ValueError( + f"Request type {type(request)} is not a valid ClipInferenceRequest" + ) + data = infer_func(**request.dict()) + response = make_response_func(data) + response.time = perf_counter() - t1 + return response + + def make_response(self, embeddings, *args, **kwargs) -> InferenceResponse: + return [self.make_embed_image_response(embeddings)] + + def postprocess( + self, + predictions: Tuple[np.ndarray], + preprocess_return_metadata: PreprocessReturnMetadata, + **kwargs, + ) -> Any: + return [self.make_embed_image_response(predictions[0])] + + def infer(self, image: Any, **kwargs) -> Any: + """Embeds an image""" + return super().infer(image, **kwargs) + + def preproc_image(self, image: InferenceRequestImage) -> np.ndarray: + """Preprocesses an inference request image. + + Args: + image (InferenceRequestImage): The object containing information necessary to load the image for inference. + + Returns: + np.ndarray: A numpy array of the preprocessed image pixel data. + """ + pil_image = Image.fromarray(load_image_rgb(image)) + preprocessed_image = self.clip_preprocess(pil_image) + + img_in = np.expand_dims(preprocessed_image, axis=0) + + return img_in.astype(np.float32) + + def preprocess( + self, image: Any, **kwargs + ) -> Tuple[np.ndarray, PreprocessReturnMetadata]: + return self.preproc_image(image), PreprocessReturnMetadata({}) diff --git a/inference/models/cogvlm/__init__.py b/inference/models/cogvlm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c340e28888ef83462a786aa47020bf262bfeaf09 --- /dev/null +++ b/inference/models/cogvlm/__init__.py @@ -0,0 +1 @@ +from inference.models.cogvlm.cogvlm import CogVLM diff --git a/inference/models/cogvlm/__pycache__/__init__.cpython-310.pyc b/inference/models/cogvlm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7248f99d40a3a1d7783262f9ac10ddf13b4dc852 Binary files /dev/null and b/inference/models/cogvlm/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/cogvlm/__pycache__/cogvlm.cpython-310.pyc b/inference/models/cogvlm/__pycache__/cogvlm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c85b41719a8a17a7fa65ee2658fb9de87a8c8f5 Binary files /dev/null and b/inference/models/cogvlm/__pycache__/cogvlm.cpython-310.pyc differ diff --git a/inference/models/cogvlm/cogvlm.py b/inference/models/cogvlm/cogvlm.py new file mode 100644 index 0000000000000000000000000000000000000000..5adbf1b0e7974846a5fb9b9d52e925417940dffc --- /dev/null +++ b/inference/models/cogvlm/cogvlm.py @@ -0,0 +1,98 @@ +import os +from time import perf_counter +from typing import Any, List, Tuple, Union + +import numpy as np +import requests +import torch +from PIL import Image +from transformers import AutoModelForCausalLM, LlamaTokenizer + +from inference.core.entities.requests.cogvlm import CogVLMInferenceRequest +from inference.core.entities.responses.cogvlm import CogVLMResponse +from inference.core.env import ( + API_KEY, + COGVLM_LOAD_4BIT, + COGVLM_LOAD_8BIT, + COGVLM_VERSION_ID, + MODEL_CACHE_DIR, +) +from inference.core.models.base import Model, PreprocessReturnMetadata +from inference.core.utils.image_utils import load_image_rgb + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +class CogVLM(Model): + def __init__(self, model_id=f"cogvlm/{COGVLM_VERSION_ID}", **kwargs): + self.model_id = model_id + self.endpoint = model_id + self.api_key = API_KEY + self.dataset_id, self.version_id = model_id.split("/") + if COGVLM_LOAD_4BIT and COGVLM_LOAD_8BIT: + raise ValueError( + "Only one of environment variable `COGVLM_LOAD_4BIT` or `COGVLM_LOAD_8BIT` can be true" + ) + self.cache_dir = os.path.join(MODEL_CACHE_DIR, self.endpoint) + with torch.inference_mode(): + self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5") + self.model = AutoModelForCausalLM.from_pretrained( + f"THUDM/{self.version_id}", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + trust_remote_code=True, + load_in_4bit=COGVLM_LOAD_4BIT, + load_in_8bit=COGVLM_LOAD_8BIT, + cache_dir=self.cache_dir, + ).eval() + self.task_type = "lmm" + + def preprocess( + self, image: Any, **kwargs + ) -> Tuple[Image.Image, PreprocessReturnMetadata]: + pil_image = Image.fromarray(load_image_rgb(image)) + + return pil_image, PreprocessReturnMetadata({}) + + def postprocess( + self, + predictions: Tuple[str], + preprocess_return_metadata: PreprocessReturnMetadata, + **kwargs, + ) -> Any: + return predictions[0] + + def predict(self, image_in: Image.Image, prompt="", history=None, **kwargs): + images = [image_in] + if history is None: + history = [] + built_inputs = self.model.build_conversation_input_ids( + self.tokenizer, query=prompt, history=history, images=images + ) # chat mode + inputs = { + "input_ids": built_inputs["input_ids"].unsqueeze(0).to(DEVICE), + "token_type_ids": built_inputs["token_type_ids"].unsqueeze(0).to(DEVICE), + "attention_mask": built_inputs["attention_mask"].unsqueeze(0).to(DEVICE), + "images": [[built_inputs["images"][0].to(DEVICE).to(torch.float16)]], + } + gen_kwargs = {"max_length": 2048, "do_sample": False} + + with torch.inference_mode(): + outputs = self.model.generate(**inputs, **gen_kwargs) + outputs = outputs[:, inputs["input_ids"].shape[1] :] + text = self.tokenizer.decode(outputs[0]) + if text.endswith(""): + text = text[:-4] + return (text,) + + def infer_from_request(self, request: CogVLMInferenceRequest) -> CogVLMResponse: + t1 = perf_counter() + text = self.infer(**request.dict()) + response = CogVLMResponse(response=text) + response.time = perf_counter() - t1 + return response + + +if __name__ == "__main__": + m = CogVLM() + m.infer() diff --git a/inference/models/doctr/__init__.py b/inference/models/doctr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64ed3603bb21a5b63e5c646fcf8075fe6db4acb0 --- /dev/null +++ b/inference/models/doctr/__init__.py @@ -0,0 +1 @@ +from inference.models.doctr.doctr_model import DocTR diff --git a/inference/models/doctr/__pycache__/__init__.cpython-310.pyc b/inference/models/doctr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3180109c61af36f2650664d2cecc07708412835b Binary files /dev/null and b/inference/models/doctr/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/doctr/__pycache__/doctr_model.cpython-310.pyc b/inference/models/doctr/__pycache__/doctr_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2007f5c3ef941880e3f3da2b6852422ae235aae3 Binary files /dev/null and b/inference/models/doctr/__pycache__/doctr_model.cpython-310.pyc differ diff --git a/inference/models/doctr/doctr_model.py b/inference/models/doctr/doctr_model.py new file mode 100644 index 0000000000000000000000000000000000000000..446c95d69714bfb416ac66ffce0fa30dd655ff57 --- /dev/null +++ b/inference/models/doctr/doctr_model.py @@ -0,0 +1,173 @@ +import os +import shutil +import tempfile +from time import perf_counter +from typing import Any, List, Union + +from doctr import models as models +from doctr.io import DocumentFile +from doctr.models import ocr_predictor +from PIL import Image + +from inference.core.entities.requests.doctr import DoctrOCRInferenceRequest +from inference.core.entities.requests.inference import InferenceRequest +from inference.core.entities.responses.doctr import DoctrOCRInferenceResponse +from inference.core.entities.responses.inference import InferenceResponse +from inference.core.env import MODEL_CACHE_DIR +from inference.core.models.roboflow import RoboflowCoreModel +from inference.core.utils.image_utils import load_image + + +class DocTR(RoboflowCoreModel): + def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs): + """Initializes the DocTR model. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + self.api_key = kwargs.get("api_key") + self.dataset_id = "doctr" + self.version_id = "default" + self.endpoint = model_id + model_id = model_id.lower() + + os.environ["DOCTR_CACHE_DIR"] = os.path.join(MODEL_CACHE_DIR, "doctr_rec") + + self.det_model = DocTRDet(api_key=kwargs.get("api_key")) + self.rec_model = DocTRRec(api_key=kwargs.get("api_key")) + + os.makedirs(f"{MODEL_CACHE_DIR}/doctr_rec/models/", exist_ok=True) + os.makedirs(f"{MODEL_CACHE_DIR}/doctr_det/models/", exist_ok=True) + + shutil.copyfile( + f"{MODEL_CACHE_DIR}/doctr_det/db_resnet50/model.pt", + f"{MODEL_CACHE_DIR}/doctr_det/models/db_resnet50-ac60cadc.pt", + ) + shutil.copyfile( + f"{MODEL_CACHE_DIR}/doctr_rec/crnn_vgg16_bn/model.pt", + f"{MODEL_CACHE_DIR}/doctr_rec/models/crnn_vgg16_bn-9762b0b0.pt", + ) + + self.model = ocr_predictor( + det_arch=self.det_model.version_id, + reco_arch=self.rec_model.version_id, + pretrained=True, + ) + self.task_type = "ocr" + + def clear_cache(self) -> None: + self.det_model.clear_cache() + self.rec_model.clear_cache() + + def preprocess_image(self, image: Image.Image) -> Image.Image: + """ + DocTR pre-processes images as part of its inference pipeline. + + Thus, no preprocessing is required here. + """ + pass + + def infer_from_request( + self, request: DoctrOCRInferenceRequest + ) -> DoctrOCRInferenceResponse: + t1 = perf_counter() + result = self.infer(**request.dict()) + return DoctrOCRInferenceResponse( + result=result, + time=perf_counter() - t1, + ) + + def infer(self, image: Any, **kwargs): + """ + Run inference on a provided image. + + Args: + request (DoctrOCRInferenceRequest): The inference request. + + Returns: + DoctrOCRInferenceResponse: The inference response. + """ + + img = load_image(image) + + with tempfile.NamedTemporaryFile(suffix=".jpg") as f: + image = Image.fromarray(img[0]) + + image.save(f.name) + + doc = DocumentFile.from_images([f.name]) + + result = self.model(doc).export() + + result = result["pages"][0]["blocks"] + + result = [ + " ".join([word["value"] for word in line["words"]]) + for block in result + for line in block["lines"] + ] + + result = " ".join(result) + + return result + + def get_infer_bucket_file_list(self) -> list: + """Get the list of required files for inference. + + Returns: + list: A list of required files for inference, e.g., ["model.pt"]. + """ + return ["model.pt"] + + +class DocTRRec(RoboflowCoreModel): + def __init__(self, *args, model_id: str = "doctr_rec/crnn_vgg16_bn", **kwargs): + """Initializes the DocTR model. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + pass + + self.get_infer_bucket_file_list() + + super().__init__(*args, model_id=model_id, **kwargs) + + def get_infer_bucket_file_list(self) -> list: + """Get the list of required files for inference. + + Returns: + list: A list of required files for inference, e.g., ["model.pt"]. + """ + return ["model.pt"] + + +class DocTRDet(RoboflowCoreModel): + """DocTR class for document Optical Character Recognition (OCR). + + Attributes: + doctr: The DocTR model. + ort_session: ONNX runtime inference session. + """ + + def __init__(self, *args, model_id: str = "doctr_det/db_resnet50", **kwargs): + """Initializes the DocTR model. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + + self.get_infer_bucket_file_list() + + super().__init__(*args, model_id=model_id, **kwargs) + + def get_infer_bucket_file_list(self) -> list: + """Get the list of required files for inference. + + Returns: + list: A list of required files for inference, e.g., ["model.pt"]. + """ + return ["model.pt"] diff --git a/inference/models/gaze/__init__.py b/inference/models/gaze/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6028ddab9303321234b56e57d6681a8575328290 --- /dev/null +++ b/inference/models/gaze/__init__.py @@ -0,0 +1 @@ +from inference.models.gaze.gaze import Gaze diff --git a/inference/models/gaze/__pycache__/__init__.cpython-310.pyc b/inference/models/gaze/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bc6b4e5a891e74d1c9e30eebe4ccdae01ca2979 Binary files /dev/null and b/inference/models/gaze/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/gaze/__pycache__/gaze.cpython-310.pyc b/inference/models/gaze/__pycache__/gaze.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ece8bf00f8db828e09014852b29270dc528b620 Binary files /dev/null and b/inference/models/gaze/__pycache__/gaze.cpython-310.pyc differ diff --git a/inference/models/gaze/__pycache__/l2cs.cpython-310.pyc b/inference/models/gaze/__pycache__/l2cs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a0dc9aead3b0c612b492ccdf3f3b6a4c3d40733 Binary files /dev/null and b/inference/models/gaze/__pycache__/l2cs.cpython-310.pyc differ diff --git a/inference/models/gaze/gaze.py b/inference/models/gaze/gaze.py new file mode 100644 index 0000000000000000000000000000000000000000..28cc074febaddd2e48eb93fc2f2ff833edffd5b9 --- /dev/null +++ b/inference/models/gaze/gaze.py @@ -0,0 +1,366 @@ +import math +from time import perf_counter +from typing import List, Optional, Tuple, Union + +import cv2 +import mediapipe as mp +import numpy as np +import onnxruntime +import torch +import torch.nn as nn +import torchvision +from mediapipe.tasks.python.components.containers.bounding_box import BoundingBox +from mediapipe.tasks.python.components.containers.category import Category +from mediapipe.tasks.python.components.containers.detections import Detection +from torchvision import transforms + +from inference.core.entities.requests.gaze import GazeDetectionInferenceRequest +from inference.core.entities.responses.gaze import ( + GazeDetectionInferenceResponse, + GazeDetectionPrediction, +) +from inference.core.entities.responses.inference import FaceDetectionPrediction, Point +from inference.core.env import ( + GAZE_MAX_BATCH_SIZE, + MODEL_CACHE_DIR, + REQUIRED_ONNX_PROVIDERS, + TENSORRT_CACHE_PATH, +) +from inference.core.exceptions import OnnxProviderNotAvailable +from inference.core.models.roboflow import OnnxRoboflowCoreModel +from inference.core.utils.image_utils import load_image_rgb +from inference.models.gaze.l2cs import L2CS + + +class Gaze(OnnxRoboflowCoreModel): + """Roboflow ONNX Gaze model. + + This class is responsible for handling the ONNX Gaze model, including + loading the model, preprocessing the input, and performing inference. + + Attributes: + gaze_onnx_session (onnxruntime.InferenceSession): ONNX Runtime session for gaze detection inference. + """ + + def __init__(self, *args, **kwargs): + """Initializes the Gaze with the given arguments and keyword arguments.""" + + t1 = perf_counter() + super().__init__(*args, **kwargs) + # Create an ONNX Runtime Session with a list of execution providers in priority order. ORT attempts to load providers until one is successful. This keeps the code across devices identical. + self.log("Creating inference sessions") + + # TODO: convert face detector (TensorflowLite) to ONNX model + + self.gaze_onnx_session = onnxruntime.InferenceSession( + self.cache_file("L2CSNet_gaze360_resnet50_90bins.onnx"), + providers=[ + ( + "TensorrtExecutionProvider", + { + "trt_engine_cache_enable": True, + "trt_engine_cache_path": TENSORRT_CACHE_PATH, + }, + ), + "CUDAExecutionProvider", + "CPUExecutionProvider", + ], + ) + + if REQUIRED_ONNX_PROVIDERS: + available_providers = onnxruntime.get_available_providers() + for provider in REQUIRED_ONNX_PROVIDERS: + if provider not in available_providers: + raise OnnxProviderNotAvailable( + f"Required ONNX Execution Provider {provider} is not availble. Check that you are using the correct docker image on a supported device." + ) + + # init face detector + self.face_detector = mp.tasks.vision.FaceDetector.create_from_options( + mp.tasks.vision.FaceDetectorOptions( + base_options=mp.tasks.BaseOptions( + model_asset_path=self.cache_file("mediapipe_face_detector.tflite") + ), + running_mode=mp.tasks.vision.RunningMode.IMAGE, + ) + ) + + # additional settings for gaze detection + self._gaze_transformations = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Resize(448), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + self.task_type = "gaze-detection" + self.log(f"GAZE model loaded in {perf_counter() - t1:.2f} seconds") + + def _crop_face_img(self, np_img: np.ndarray, face: Detection) -> np.ndarray: + """Extract facial area in an image. + + Args: + np_img (np.ndarray): The numpy image. + face (mediapipe.tasks.python.components.containers.detections.Detection): The detected face. + + Returns: + np.ndarray: Cropped face image. + """ + # extract face area + bbox = face.bounding_box + x_min = bbox.origin_x + y_min = bbox.origin_y + x_max = bbox.origin_x + bbox.width + y_max = bbox.origin_y + bbox.height + face_img = np_img[y_min:y_max, x_min:x_max, :] + face_img = cv2.resize(face_img, (224, 224)) + return face_img + + def _detect_gaze(self, np_imgs: List[np.ndarray]) -> List[Tuple[float, float]]: + """Detect faces and gazes in an image. + + Args: + pil_imgs (List[np.ndarray]): The numpy image list, each image is a cropped facial image. + + Returns: + List[Tuple[float, float]]: Yaw (radian) and Pitch (radian). + """ + ret = [] + for i in range(0, len(np_imgs), GAZE_MAX_BATCH_SIZE): + img_batch = [] + for j in range(i, min(len(np_imgs), i + GAZE_MAX_BATCH_SIZE)): + img = self._gaze_transformations(np_imgs[j]) + img = np.expand_dims(img, axis=0).astype(np.float32) + img_batch.append(img) + + img_batch = np.concatenate(img_batch, axis=0) + onnx_input_image = {self.gaze_onnx_session.get_inputs()[0].name: img_batch} + yaw, pitch = self.gaze_onnx_session.run(None, onnx_input_image) + + for j in range(len(img_batch)): + ret.append((yaw[j], pitch[j])) + + return ret + + def _make_response( + self, + faces: List[Detection], + gazes: List[Tuple[float, float]], + imgW: int, + imgH: int, + time_total: float, + time_face_det: float = None, + time_gaze_det: float = None, + ) -> GazeDetectionInferenceResponse: + """Prepare response object from detected faces and corresponding gazes. + + Args: + faces (List[Detection]): The detected faces. + gazes (List[tuple(float, float)]): The detected gazes (yaw, pitch). + imgW (int): The width (px) of original image. + imgH (int): The height (px) of original image. + time_total (float): The processing time. + time_face_det (float): The processing time. + time_gaze_det (float): The processing time. + + Returns: + GazeDetectionInferenceResponse: The response object including the detected faces and gazes info. + """ + predictions = [] + for face, gaze in zip(faces, gazes): + landmarks = [] + for keypoint in face.keypoints: + x = min(max(int(keypoint.x * imgW), 0), imgW - 1) + y = min(max(int(keypoint.y * imgH), 0), imgH - 1) + landmarks.append(Point(x=x, y=y)) + + bbox = face.bounding_box + x_center = bbox.origin_x + bbox.width / 2 + y_center = bbox.origin_y + bbox.height / 2 + score = face.categories[0].score + + prediction = GazeDetectionPrediction( + face=FaceDetectionPrediction( + x=x_center, + y=y_center, + width=bbox.width, + height=bbox.height, + confidence=score, + class_name="face", + landmarks=landmarks, + ), + yaw=gaze[0], + pitch=gaze[1], + ) + predictions.append(prediction) + + response = GazeDetectionInferenceResponse( + predictions=predictions, + time=time_total, + time_face_det=time_face_det, + time_gaze_det=time_gaze_det, + ) + return response + + def get_infer_bucket_file_list(self) -> List[str]: + """Gets the list of files required for inference. + + Returns: + List[str]: The list of file names. + """ + return [ + "mediapipe_face_detector.tflite", + "L2CSNet_gaze360_resnet50_90bins.onnx", + ] + + def infer_from_request( + self, request: GazeDetectionInferenceRequest + ) -> List[GazeDetectionInferenceResponse]: + """Detect faces and gazes in image(s). + + Args: + request (GazeDetectionInferenceRequest): The request object containing the image. + + Returns: + List[GazeDetectionInferenceResponse]: The list of response objects containing the faces and corresponding gazes. + """ + if isinstance(request.image, list): + if len(request.image) > GAZE_MAX_BATCH_SIZE: + raise ValueError( + f"The maximum number of images that can be inferred with gaze detection at one time is {GAZE_MAX_BATCH_SIZE}" + ) + imgs = request.image + else: + imgs = [request.image] + + time_total = perf_counter() + + # load pil images + num_img = len(imgs) + np_imgs = [load_image_rgb(img) for img in imgs] + + # face detection + # TODO: face detection for batch + time_face_det = perf_counter() + faces = [] + for np_img in np_imgs: + if request.do_run_face_detection: + mp_img = mp.Image( + image_format=mp.ImageFormat.SRGB, data=np_img.astype(np.uint8) + ) + faces_per_img = self.face_detector.detect(mp_img).detections + else: + faces_per_img = [ + Detection( + bounding_box=BoundingBox( + origin_x=0, + origin_y=0, + width=np_img.shape[1], + height=np_img.shape[0], + ), + categories=[Category(score=1.0, category_name="face")], + keypoints=[], + ) + ] + faces.append(faces_per_img) + time_face_det = (perf_counter() - time_face_det) / num_img + + # gaze detection + time_gaze_det = perf_counter() + face_imgs = [] + for i, np_img in enumerate(np_imgs): + if request.do_run_face_detection: + face_imgs.extend( + [self._crop_face_img(np_img, face) for face in faces[i]] + ) + else: + face_imgs.append(cv2.resize(np_img, (224, 224))) + gazes = self._detect_gaze(face_imgs) + time_gaze_det = (perf_counter() - time_gaze_det) / num_img + + time_total = (perf_counter() - time_total) / num_img + + # prepare response + response = [] + idx_gaze = 0 + for i in range(len(np_imgs)): + imgH, imgW, _ = np_imgs[i].shape + faces_per_img = faces[i] + gazes_per_img = gazes[idx_gaze : idx_gaze + len(faces_per_img)] + response.append( + self._make_response( + faces_per_img, gazes_per_img, imgW, imgH, time_total + ) + ) + + return response + + +class L2C2Wrapper(L2CS): + """Roboflow L2CS Gaze detection model. + + This class is responsible for converting L2CS model to ONNX model. + It is ONLY intended for internal usage. + + Workflow: + After training a L2CS model, create an instance of this wrapper class. + Load the trained weights file, and save it as ONNX model. + """ + + def __init__(self): + self.device = torch.device("cpu") + self.num_bins = 90 + super().__init__( + torchvision.models.resnet.Bottleneck, [3, 4, 6, 3], self.num_bins + ) + self._gaze_softmax = nn.Softmax(dim=1) + self._gaze_idx_tensor = torch.FloatTensor([i for i in range(90)]).to( + self.device + ) + + def forward(self, x): + idx_tensor = torch.stack( + [self._gaze_idx_tensor for i in range(x.shape[0])], dim=0 + ) + gaze_yaw, gaze_pitch = super().forward(x) + + yaw_predicted = self._gaze_softmax(gaze_yaw) + yaw_radian = ( + (torch.sum(yaw_predicted * idx_tensor, dim=1) * 4 - 180) * np.pi / 180 + ) + + pitch_predicted = self._gaze_softmax(gaze_pitch) + pitch_radian = ( + (torch.sum(pitch_predicted * idx_tensor, dim=1) * 4 - 180) * np.pi / 180 + ) + + return yaw_radian, pitch_radian + + def load_L2CS_model( + self, + file_path=f"{MODEL_CACHE_DIR}/gaze/L2CS/L2CSNet_gaze360_resnet50_90bins.pkl", + ): + super().load_state_dict(torch.load(file_path, map_location=self.device)) + super().to(self.device) + + def saveas_ONNX_model( + self, + file_path=f"{MODEL_CACHE_DIR}/gaze/L2CS/L2CSNet_gaze360_resnet50_90bins.onnx", + ): + dummy_input = torch.randn(1, 3, 448, 448) + dynamic_axes = { + "input": {0: "batch_size"}, + "output_yaw": {0: "batch_size"}, + "output_pitch": {0: "batch_size"}, + } + torch.onnx.export( + self, + dummy_input, + file_path, + input_names=["input"], + output_names=["output_yaw", "output_pitch"], + dynamic_axes=dynamic_axes, + verbose=False, + ) diff --git a/inference/models/gaze/l2cs.py b/inference/models/gaze/l2cs.py new file mode 100644 index 0000000000000000000000000000000000000000..2a814b92de4e93d3729b47ee93b533be5d5ec293 --- /dev/null +++ b/inference/models/gaze/l2cs.py @@ -0,0 +1,84 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + + +class L2CS(nn.Module): + """L2CS Gaze Detection Model. + + This class is responsible for performing gaze detection using the L2CS-Net model. + Ref: https://github.com/Ahmednull/L2CS-Net + + Methods: + forward: Performs inference on the given image. + """ + + def __init__(self, block, layers, num_bins): + self.inplanes = 64 + super(L2CS, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + self.fc_yaw_gaze = nn.Linear(512 * block.expansion, num_bins) + self.fc_pitch_gaze = nn.Linear(512 * block.expansion, num_bins) + + # Vestigial layer from previous experiments + self.fc_finetune = nn.Linear(512 * block.expansion + 3, 3) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + + # gaze + pre_yaw_gaze = self.fc_yaw_gaze(x) + pre_pitch_gaze = self.fc_pitch_gaze(x) + return pre_yaw_gaze, pre_pitch_gaze diff --git a/inference/models/grounding_dino/__init__.py b/inference/models/grounding_dino/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d5c932a9bf4bbf2fdaaedad4ac6f74053673b1 --- /dev/null +++ b/inference/models/grounding_dino/__init__.py @@ -0,0 +1 @@ +from inference.models.grounding_dino.grounding_dino import GroundingDINO diff --git a/inference/models/grounding_dino/__pycache__/__init__.cpython-310.pyc b/inference/models/grounding_dino/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da1e347ebda6baf4468b8d30cd20dadf1b73065d Binary files /dev/null and b/inference/models/grounding_dino/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/grounding_dino/__pycache__/grounding_dino.cpython-310.pyc b/inference/models/grounding_dino/__pycache__/grounding_dino.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..571d8a60ac19b2785efa13517f9c7ffbc9abbf75 Binary files /dev/null and b/inference/models/grounding_dino/__pycache__/grounding_dino.cpython-310.pyc differ diff --git a/inference/models/grounding_dino/grounding_dino.py b/inference/models/grounding_dino/grounding_dino.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d94cdc8d2388b442e125cd7897ba5c0459b92b --- /dev/null +++ b/inference/models/grounding_dino/grounding_dino.py @@ -0,0 +1,147 @@ +import os +import urllib.request +from time import perf_counter +from typing import Any + +import torch +from groundingdino.util.inference import Model + +from inference.core.entities.requests.groundingdino import GroundingDINOInferenceRequest +from inference.core.entities.requests.inference import InferenceRequestImage +from inference.core.entities.responses.inference import ( + InferenceResponseImage, + ObjectDetectionInferenceResponse, + ObjectDetectionPrediction, +) +from inference.core.env import MODEL_CACHE_DIR +from inference.core.models.roboflow import RoboflowCoreModel +from inference.core.utils.image_utils import load_image_rgb, xyxy_to_xywh + + +class GroundingDINO(RoboflowCoreModel): + """GroundingDINO class for zero-shot object detection. + + Attributes: + model: The GroundingDINO model. + """ + + def __init__( + self, *args, model_id="grounding_dino/groundingdino_swint_ogc", **kwargs + ): + """Initializes the GroundingDINO model. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + + super().__init__(*args, model_id=model_id, **kwargs) + + GROUDNING_DINO_CACHE_DIR = os.path.join(MODEL_CACHE_DIR, model_id) + + GROUNDING_DINO_CONFIG_PATH = os.path.join( + GROUDNING_DINO_CACHE_DIR, "GroundingDINO_SwinT_OGC.py" + ) + # GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( + # GROUDNING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth" + # ) + + if not os.path.exists(GROUDNING_DINO_CACHE_DIR): + os.makedirs(GROUDNING_DINO_CACHE_DIR) + + if not os.path.exists(GROUNDING_DINO_CONFIG_PATH): + url = "https://raw.githubusercontent.com/roboflow/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py" + urllib.request.urlretrieve(url, GROUNDING_DINO_CONFIG_PATH) + + # if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH): + # url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth" + # urllib.request.urlretrieve(url, GROUNDING_DINO_CHECKPOINT_PATH) + + self.model = Model( + model_config_path=GROUNDING_DINO_CONFIG_PATH, + model_checkpoint_path=os.path.join( + GROUDNING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth" + ), + device="cuda" if torch.cuda.is_available() else "cpu", + ) + + def preproc_image(self, image: Any): + """Preprocesses an image. + + Args: + image (InferenceRequestImage): The image to preprocess. + + Returns: + np.array: The preprocessed image. + """ + np_image = load_image_rgb(image) + return np_image + + def infer_from_request( + self, + request: GroundingDINOInferenceRequest, + ) -> ObjectDetectionInferenceResponse: + """ + Perform inference based on the details provided in the request, and return the associated responses. + """ + result = self.infer(**request.dict()) + return result + + def infer( + self, image: Any = None, text: list = None, class_filter: list = None, **kwargs + ): + """ + Run inference on a provided image. + + Args: + request (CVInferenceRequest): The inference request. + class_filter (Optional[List[str]]): A list of class names to filter, if provided. + + Returns: + GroundingDINOInferenceRequest: The inference response. + """ + t1 = perf_counter() + image = self.preproc_image(image) + img_dims = image.shape + + detections = self.model.predict_with_classes( + image=image, + classes=text, + box_threshold=0.5, + text_threshold=0.5, + ) + + self.class_names = text + + xywh_bboxes = [xyxy_to_xywh(detection) for detection in detections.xyxy] + + t2 = perf_counter() - t1 + + responses = ObjectDetectionInferenceResponse( + predictions=[ + ObjectDetectionPrediction( + **{ + "x": xywh_bboxes[i][0], + "y": xywh_bboxes[i][1], + "width": xywh_bboxes[i][2], + "height": xywh_bboxes[i][3], + "confidence": detections.confidence[i], + "class": self.class_names[int(detections.class_id[i])], + "class_id": int(detections.class_id[i]), + } + ) + for i, pred in enumerate(detections.xyxy) + if not class_filter or self.class_names[int(pred[6])] in class_filter + ], + image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]), + time=t2, + ) + return responses + + def get_infer_bucket_file_list(self) -> list: + """Get the list of required files for inference. + + Returns: + list: A list of required files for inference, e.g., ["model.pt"]. + """ + return ["groundingdino_swint_ogc.pth"] diff --git a/inference/models/sam/__init__.py b/inference/models/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9dbd3f78ae808f06281b58759e36bbccca9065 --- /dev/null +++ b/inference/models/sam/__init__.py @@ -0,0 +1 @@ +from inference.models.sam.segment_anything import SegmentAnything diff --git a/inference/models/sam/__pycache__/__init__.cpython-310.pyc b/inference/models/sam/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d43fa8487429661041db2f376cdbe161281d7b56 Binary files /dev/null and b/inference/models/sam/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/sam/__pycache__/segment_anything.cpython-310.pyc b/inference/models/sam/__pycache__/segment_anything.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6d5a0898d97e3f2a24f1016d3effca9b76bbe15 Binary files /dev/null and b/inference/models/sam/__pycache__/segment_anything.cpython-310.pyc differ diff --git a/inference/models/sam/segment_anything.py b/inference/models/sam/segment_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..043eca756311afe1c9e399b8933350983039fe38 --- /dev/null +++ b/inference/models/sam/segment_anything.py @@ -0,0 +1,317 @@ +import base64 +from io import BytesIO +from time import perf_counter +from typing import Any, List, Optional, Union + +import numpy as np +import onnxruntime +import rasterio.features +import torch +from segment_anything import SamPredictor, sam_model_registry +from shapely.geometry import Polygon as ShapelyPolygon + +from inference.core.entities.requests.inference import InferenceRequestImage +from inference.core.entities.requests.sam import ( + SamEmbeddingRequest, + SamInferenceRequest, + SamSegmentationRequest, +) +from inference.core.entities.responses.sam import ( + SamEmbeddingResponse, + SamSegmentationResponse, +) +from inference.core.env import SAM_MAX_EMBEDDING_CACHE_SIZE, SAM_VERSION_ID +from inference.core.models.roboflow import RoboflowCoreModel +from inference.core.utils.image_utils import load_image_rgb +from inference.core.utils.postprocess import masks2poly + + +class SegmentAnything(RoboflowCoreModel): + """SegmentAnything class for handling segmentation tasks. + + Attributes: + sam: The segmentation model. + predictor: The predictor for the segmentation model. + ort_session: ONNX runtime inference session. + embedding_cache: Cache for embeddings. + image_size_cache: Cache for image sizes. + embedding_cache_keys: Keys for the embedding cache. + low_res_logits_cache: Cache for low resolution logits. + segmentation_cache_keys: Keys for the segmentation cache. + """ + + def __init__(self, *args, model_id: str = f"sam/{SAM_VERSION_ID}", **kwargs): + """Initializes the SegmentAnything. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + super().__init__(*args, model_id=model_id, **kwargs) + self.sam = sam_model_registry[self.version_id]( + checkpoint=self.cache_file("encoder.pth") + ) + self.sam.to(device="cuda" if torch.cuda.is_available() else "cpu") + self.predictor = SamPredictor(self.sam) + self.ort_session = onnxruntime.InferenceSession( + self.cache_file("decoder.onnx"), + providers=[ + "CUDAExecutionProvider", + "CPUExecutionProvider", + ], + ) + self.embedding_cache = {} + self.image_size_cache = {} + self.embedding_cache_keys = [] + + self.low_res_logits_cache = {} + self.segmentation_cache_keys = [] + self.task_type = "unsupervised-segmentation" + + def get_infer_bucket_file_list(self) -> List[str]: + """Gets the list of files required for inference. + + Returns: + List[str]: List of file names. + """ + return ["encoder.pth", "decoder.onnx"] + + def embed_image(self, image: Any, image_id: Optional[str] = None, **kwargs): + """ + Embeds an image and caches the result if an image_id is provided. If the image has been embedded before and cached, + the cached result will be returned. + + Args: + image (Any): The image to be embedded. The format should be compatible with the preproc_image method. + image_id (Optional[str]): An identifier for the image. If provided, the embedding result will be cached + with this ID. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + Tuple[np.ndarray, Tuple[int, int]]: A tuple where the first element is the embedding of the image + and the second element is the shape (height, width) of the processed image. + + Notes: + - Embeddings and image sizes are cached to improve performance on repeated requests for the same image. + - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size, + the oldest entries are removed. + + Example: + >>> img_array = ... # some image array + >>> embed_image(img_array, image_id="sample123") + (array([...]), (224, 224)) + """ + if image_id and image_id in self.embedding_cache: + return ( + self.embedding_cache[image_id], + self.image_size_cache[image_id], + ) + img_in = self.preproc_image(image) + self.predictor.set_image(img_in) + embedding = self.predictor.get_image_embedding().cpu().numpy() + if image_id: + self.embedding_cache[image_id] = embedding + self.image_size_cache[image_id] = img_in.shape[:2] + self.embedding_cache_keys.append(image_id) + if len(self.embedding_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE: + cache_key = self.embedding_cache_keys.pop(0) + del self.embedding_cache[cache_key] + del self.image_size_cache[cache_key] + return (embedding, img_in.shape[:2]) + + def infer_from_request(self, request: SamInferenceRequest): + """Performs inference based on the request type. + + Args: + request (SamInferenceRequest): The inference request. + + Returns: + Union[SamEmbeddingResponse, SamSegmentationResponse]: The inference response. + """ + t1 = perf_counter() + if isinstance(request, SamEmbeddingRequest): + embedding, _ = self.embed_image(**request.dict()) + inference_time = perf_counter() - t1 + if request.format == "json": + return SamEmbeddingResponse( + embeddings=embedding.tolist(), time=inference_time + ) + elif request.format == "binary": + binary_vector = BytesIO() + np.save(binary_vector, embedding) + binary_vector.seek(0) + return SamEmbeddingResponse( + embeddings=binary_vector.getvalue(), time=inference_time + ) + elif isinstance(request, SamSegmentationRequest): + masks, low_res_masks = self.segment_image(**request.dict()) + if request.format == "json": + masks = masks > self.predictor.model.mask_threshold + masks = masks2poly(masks) + low_res_masks = low_res_masks > self.predictor.model.mask_threshold + low_res_masks = masks2poly(low_res_masks) + elif request.format == "binary": + binary_vector = BytesIO() + np.savez_compressed( + binary_vector, masks=masks, low_res_masks=low_res_masks + ) + binary_vector.seek(0) + binary_data = binary_vector.getvalue() + return binary_data + else: + raise ValueError(f"Invalid format {request.format}") + + response = SamSegmentationResponse( + masks=[m.tolist() for m in masks], + low_res_masks=[m.tolist() for m in low_res_masks], + time=perf_counter() - t1, + ) + return response + + def preproc_image(self, image: InferenceRequestImage): + """Preprocesses an image. + + Args: + image (InferenceRequestImage): The image to preprocess. + + Returns: + np.array: The preprocessed image. + """ + np_image = load_image_rgb(image) + return np_image + + def segment_image( + self, + image: Any, + embeddings: Optional[Union[np.ndarray, List[List[float]]]] = None, + embeddings_format: Optional[str] = "json", + has_mask_input: Optional[bool] = False, + image_id: Optional[str] = None, + mask_input: Optional[Union[np.ndarray, List[List[List[float]]]]] = None, + mask_input_format: Optional[str] = "json", + orig_im_size: Optional[List[int]] = None, + point_coords: Optional[List[List[float]]] = [], + point_labels: Optional[List[int]] = [], + use_mask_input_cache: Optional[bool] = True, + **kwargs, + ): + """ + Segments an image based on provided embeddings, points, masks, or cached results. + If embeddings are not directly provided, the function can derive them from the input image or cache. + + Args: + image (Any): The image to be segmented. + embeddings (Optional[Union[np.ndarray, List[List[float]]]]): The embeddings of the image. + Defaults to None, in which case the image is used to compute embeddings. + embeddings_format (Optional[str]): Format of the provided embeddings; either 'json' or 'binary'. Defaults to 'json'. + has_mask_input (Optional[bool]): Specifies whether mask input is provided. Defaults to False. + image_id (Optional[str]): A cached identifier for the image. Useful for accessing cached embeddings or masks. + mask_input (Optional[Union[np.ndarray, List[List[List[float]]]]]): Input mask for the image. + mask_input_format (Optional[str]): Format of the provided mask input; either 'json' or 'binary'. Defaults to 'json'. + orig_im_size (Optional[List[int]]): Original size of the image when providing embeddings directly. + point_coords (Optional[List[List[float]]]): Coordinates of points in the image. Defaults to an empty list. + point_labels (Optional[List[int]]): Labels associated with the provided points. Defaults to an empty list. + use_mask_input_cache (Optional[bool]): Flag to determine if cached mask input should be used. Defaults to True. + **kwargs: Additional keyword arguments. + + Returns: + Tuple[np.ndarray, np.ndarray]: A tuple where the first element is the segmentation masks of the image + and the second element is the low resolution segmentation masks. + + Raises: + ValueError: If necessary inputs are missing or inconsistent. + + Notes: + - Embeddings, segmentations, and low-resolution logits can be cached to improve performance + on repeated requests for the same image. + - The cache has a maximum size defined by SAM_MAX_EMBEDDING_CACHE_SIZE. When the cache exceeds this size, + the oldest entries are removed. + """ + if not embeddings: + if not image and not image_id: + raise ValueError( + "Must provide either image, cached image_id, or embeddings" + ) + elif image_id and not image and image_id not in self.embedding_cache: + raise ValueError( + f"Image ID {image_id} not in embedding cache, must provide the image or embeddings" + ) + embedding, original_image_size = self.embed_image( + image=image, image_id=image_id + ) + else: + if not orig_im_size: + raise ValueError( + "Must provide original image size if providing embeddings" + ) + original_image_size = orig_im_size + if embeddings_format == "json": + embedding = np.array(embeddings) + elif embeddings_format == "binary": + embedding = np.load(BytesIO(embeddings)) + + point_coords = point_coords + point_coords.append([0, 0]) + point_coords = np.array(point_coords, dtype=np.float32) + point_coords = np.expand_dims(point_coords, axis=0) + point_coords = self.predictor.transform.apply_coords( + point_coords, + original_image_size, + ) + + point_labels = point_labels + point_labels.append(-1) + point_labels = np.array(point_labels, dtype=np.float32) + point_labels = np.expand_dims(point_labels, axis=0) + + if has_mask_input: + if ( + image_id + and image_id in self.low_res_logits_cache + and use_mask_input_cache + ): + mask_input = self.low_res_logits_cache[image_id] + elif not mask_input and ( + not image_id or image_id not in self.low_res_logits_cache + ): + raise ValueError("Must provide either mask_input or cached image_id") + else: + if mask_input_format == "json": + polys = mask_input + mask_input = np.zeros((1, len(polys), 256, 256), dtype=np.uint8) + for i, poly in enumerate(polys): + poly = ShapelyPolygon(poly) + raster = rasterio.features.rasterize( + [poly], out_shape=(256, 256) + ) + mask_input[0, i, :, :] = raster + elif mask_input_format == "binary": + binary_data = base64.b64decode(mask_input) + mask_input = np.load(BytesIO(binary_data)) + else: + mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) + + ort_inputs = { + "image_embeddings": embedding.astype(np.float32), + "point_coords": point_coords.astype(np.float32), + "point_labels": point_labels, + "mask_input": mask_input.astype(np.float32), + "has_mask_input": ( + np.zeros(1, dtype=np.float32) + if not has_mask_input + else np.ones(1, dtype=np.float32) + ), + "orig_im_size": np.array(original_image_size, dtype=np.float32), + } + masks, _, low_res_logits = self.ort_session.run(None, ort_inputs) + if image_id: + self.low_res_logits_cache[image_id] = low_res_logits + if image_id not in self.segmentation_cache_keys: + self.segmentation_cache_keys.append(image_id) + if len(self.segmentation_cache_keys) > SAM_MAX_EMBEDDING_CACHE_SIZE: + cache_key = self.segmentation_cache_keys.pop(0) + del self.low_res_logits_cache[cache_key] + masks = masks[0] + low_res_masks = low_res_logits[0] + + return masks, low_res_masks diff --git a/inference/models/utils.py b/inference/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ea4dbf05c868806d18ba5e0640ce4fb08c7faf3f --- /dev/null +++ b/inference/models/utils.py @@ -0,0 +1,191 @@ +from inference.core.env import API_KEY, API_KEY_ENV_NAMES +from inference.core.exceptions import MissingApiKeyError +from inference.core.models.stubs import ( + ClassificationModelStub, + InstanceSegmentationModelStub, + KeypointsDetectionModelStub, + ObjectDetectionModelStub, +) +from inference.core.registries.roboflow import get_model_type +from inference.models import ( + YOLACT, + VitClassification, + YOLONASObjectDetection, + YOLOv5InstanceSegmentation, + YOLOv5ObjectDetection, + YOLOv7InstanceSegmentation, + YOLOv8Classification, + YOLOv8InstanceSegmentation, + YOLOv8ObjectDetection, +) +from inference.models.yolov8.yolov8_keypoints_detection import YOLOv8KeypointsDetection + +ROBOFLOW_MODEL_TYPES = { + ("classification", "stub"): ClassificationModelStub, + ("classification", "vit"): VitClassification, + ("classification", "yolov8n"): YOLOv8Classification, + ("classification", "yolov8s"): YOLOv8Classification, + ("classification", "yolov8m"): YOLOv8Classification, + ("classification", "yolov8l"): YOLOv8Classification, + ("classification", "yolov8x"): YOLOv8Classification, + ("object-detection", "stub"): ObjectDetectionModelStub, + ("object-detection", "yolov5"): YOLOv5ObjectDetection, + ("object-detection", "yolov5v2s"): YOLOv5ObjectDetection, + ("object-detection", "yolov5v6n"): YOLOv5ObjectDetection, + ("object-detection", "yolov5v6s"): YOLOv5ObjectDetection, + ("object-detection", "yolov5v6m"): YOLOv5ObjectDetection, + ("object-detection", "yolov5v6l"): YOLOv5ObjectDetection, + ("object-detection", "yolov5v6x"): YOLOv5ObjectDetection, + ("object-detection", "yolov8"): YOLOv8ObjectDetection, + ("object-detection", "yolov8s"): YOLOv8ObjectDetection, + ("object-detection", "yolov8n"): YOLOv8ObjectDetection, + ("object-detection", "yolov8s"): YOLOv8ObjectDetection, + ("object-detection", "yolov8m"): YOLOv8ObjectDetection, + ("object-detection", "yolov8l"): YOLOv8ObjectDetection, + ("object-detection", "yolov8x"): YOLOv8ObjectDetection, + ("object-detection", "yolo_nas_s"): YOLONASObjectDetection, + ("object-detection", "yolo_nas_m"): YOLONASObjectDetection, + ("object-detection", "yolo_nas_l"): YOLONASObjectDetection, + ("instance-segmentation", "stub"): InstanceSegmentationModelStub, + ( + "instance-segmentation", + "yolov5-seg", + ): YOLOv5InstanceSegmentation, + ( + "instance-segmentation", + "yolov5n-seg", + ): YOLOv5InstanceSegmentation, + ( + "instance-segmentation", + "yolov5s-seg", + ): YOLOv5InstanceSegmentation, + ( + "instance-segmentation", + "yolov5m-seg", + ): YOLOv5InstanceSegmentation, + ( + "instance-segmentation", + "yolov5l-seg", + ): YOLOv5InstanceSegmentation, + ( + "instance-segmentation", + "yolov5x-seg", + ): YOLOv5InstanceSegmentation, + ( + "instance-segmentation", + "yolact", + ): YOLACT, + ( + "instance-segmentation", + "yolov7-seg", + ): YOLOv7InstanceSegmentation, + ( + "instance-segmentation", + "yolov8n", + ): YOLOv8InstanceSegmentation, + ( + "instance-segmentation", + "yolov8s", + ): YOLOv8InstanceSegmentation, + ( + "instance-segmentation", + "yolov8m", + ): YOLOv8InstanceSegmentation, + ( + "instance-segmentation", + "yolov8l", + ): YOLOv8InstanceSegmentation, + ( + "instance-segmentation", + "yolov8x", + ): YOLOv8InstanceSegmentation, + ( + "instance-segmentation", + "yolov8n-seg", + ): YOLOv8InstanceSegmentation, + ( + "instance-segmentation", + "yolov8s-seg", + ): YOLOv8InstanceSegmentation, + ( + "instance-segmentation", + "yolov8m-seg", + ): YOLOv8InstanceSegmentation, + ( + "instance-segmentation", + "yolov8l-seg", + ): YOLOv8InstanceSegmentation, + ( + "instance-segmentation", + "yolov8x-seg", + ): YOLOv8InstanceSegmentation, + ( + "instance-segmentation", + "yolov8-seg", + ): YOLOv8InstanceSegmentation, + ("keypoint-detection", "stub"): KeypointsDetectionModelStub, + ("keypoint-detection", "yolov8n"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8s"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8m"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8l"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8x"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8n-pose"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8s-pose"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8m-pose"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8l-pose"): YOLOv8KeypointsDetection, + ("keypoint-detection", "yolov8x-pose"): YOLOv8KeypointsDetection, +} + +try: + from inference.models import SegmentAnything + + ROBOFLOW_MODEL_TYPES[("embed", "sam")] = SegmentAnything +except: + pass + +try: + from inference.models import Clip + + ROBOFLOW_MODEL_TYPES[("embed", "clip")] = Clip +except: + pass + +try: + from inference.models import Gaze + + ROBOFLOW_MODEL_TYPES[("gaze", "l2cs")] = Gaze +except: + pass + +try: + from inference.models import DocTR + + ROBOFLOW_MODEL_TYPES[("ocr", "doctr")] = DocTR +except: + pass + +try: + from inference.models import GroundingDINO + + ROBOFLOW_MODEL_TYPES[("object-detection", "grounding-dino")] = GroundingDINO +except: + pass + +try: + from inference.models import CogVLM + + ROBOFLOW_MODEL_TYPES[("llm", "cogvlm")] = CogVLM +except: + pass + +try: + from inference.models import YOLOWorld + + ROBOFLOW_MODEL_TYPES[("object-detection", "yolo-world")] = YOLOWorld +except: + pass + + +def get_roboflow_model(model_id, api_key=API_KEY, **kwargs): + task, model = get_model_type(model_id, api_key=api_key) + return ROBOFLOW_MODEL_TYPES[(task, model)](model_id, api_key=api_key, **kwargs) diff --git a/inference/models/vit/__init__.py b/inference/models/vit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3b64ac3fa2e7b481018e4c613d79a55b989f47d7 --- /dev/null +++ b/inference/models/vit/__init__.py @@ -0,0 +1 @@ +from inference.models.vit.vit_classification import VitClassification diff --git a/inference/models/vit/__pycache__/__init__.cpython-310.pyc b/inference/models/vit/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4c21dc73bf43c2af71cb91fd370f0efce6e8f86 Binary files /dev/null and b/inference/models/vit/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/vit/__pycache__/vit_classification.cpython-310.pyc b/inference/models/vit/__pycache__/vit_classification.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e9974e488f22e9e2b984c44afcdac157e806ddf Binary files /dev/null and b/inference/models/vit/__pycache__/vit_classification.cpython-310.pyc differ diff --git a/inference/models/vit/vit_classification.py b/inference/models/vit/vit_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..0806b3eff34bc78eef8e586b0753dd3827ff048f --- /dev/null +++ b/inference/models/vit/vit_classification.py @@ -0,0 +1,42 @@ +from inference.core.env import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, LAMBDA +from inference.core.models.classification_base import ( + ClassificationBaseOnnxRoboflowInferenceModel, +) + + +class VitClassification(ClassificationBaseOnnxRoboflowInferenceModel): + """VitClassification handles classification inference + for Vision Transformer (ViT) models using ONNX. + + Inherits: + ClassificationBaseOnnxRoboflowInferenceModel: Base class for ONNX Roboflow Inference. + ClassificationMixin: Mixin class providing classification-specific methods. + + Attributes: + multiclass (bool): A flag that specifies if the model should handle multiclass classification. + """ + + def __init__(self, *args, **kwargs): + """Initializes the VitClassification instance. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + super().__init__(*args, **kwargs) + self.multiclass = self.environment.get("MULTICLASS", False) + + @property + def weights_file(self) -> str: + """Determines the weights file to be used based on the availability of AWS keys. + + If AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are set, it returns the path to 'weights.onnx'. + Otherwise, it returns the path to 'best.onnx'. + + Returns: + str: Path to the weights file. + """ + if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY and LAMBDA: + return "weights.onnx" + else: + return "best.onnx" diff --git a/inference/models/yolact/__init__.py b/inference/models/yolact/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f147d63b384b2984040138c49bc4a668e8df10b7 --- /dev/null +++ b/inference/models/yolact/__init__.py @@ -0,0 +1 @@ +from inference.models.yolact.yolact_instance_segmentation import YOLACT diff --git a/inference/models/yolact/__pycache__/__init__.cpython-310.pyc b/inference/models/yolact/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39b51b376db985f09d45435bc7c025dc395a7f16 Binary files /dev/null and b/inference/models/yolact/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/yolact/__pycache__/yolact_instance_segmentation.cpython-310.pyc b/inference/models/yolact/__pycache__/yolact_instance_segmentation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e3eaa156547443c05146e068afdb884ac62df27 Binary files /dev/null and b/inference/models/yolact/__pycache__/yolact_instance_segmentation.cpython-310.pyc differ diff --git a/inference/models/yolact/yolact_instance_segmentation.py b/inference/models/yolact/yolact_instance_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..7c687df1c825a16d0404bf230a8d1a27980d474f --- /dev/null +++ b/inference/models/yolact/yolact_instance_segmentation.py @@ -0,0 +1,321 @@ +from time import perf_counter +from typing import Any, List, Tuple + +import cv2 +import numpy as np + +from inference.core.entities.responses.inference import ( + InferenceResponseImage, + InstanceSegmentationInferenceResponse, + InstanceSegmentationPrediction, +) +from inference.core.models.roboflow import OnnxRoboflowInferenceModel +from inference.core.models.types import PreprocessReturnMetadata +from inference.core.nms import w_np_non_max_suppression +from inference.core.utils.postprocess import ( + crop_mask, + masks2poly, + post_process_bboxes, + post_process_polygons, +) + + +class YOLACT(OnnxRoboflowInferenceModel): + """Roboflow ONNX Object detection model (Implements an object detection specific infer method)""" + + task_type = "instance-segmentation" + + @property + def weights_file(self) -> str: + """Gets the weights file. + + Returns: + str: Path to the weights file. + """ + return "weights.onnx" + + def infer( + self, + image: Any, + class_agnostic_nms: bool = False, + confidence: float = 0.5, + iou_threshold: float = 0.5, + max_candidates: int = 3000, + max_detections: int = 300, + return_image_dims: bool = False, + **kwargs, + ) -> List[List[dict]]: + """ + Performs instance segmentation inference on a given image, post-processes the results, + and returns the segmented instances as dictionaries containing their properties. + + Args: + image (Any): The image or list of images to segment. Can be in various formats (e.g., raw array, PIL image). + class_agnostic_nms (bool, optional): Whether to perform class-agnostic non-max suppression. Defaults to False. + confidence (float, optional): Confidence threshold for filtering weak detections. Defaults to 0.5. + iou_threshold (float, optional): Intersection-over-union threshold for non-max suppression. Defaults to 0.5. + max_candidates (int, optional): Maximum number of candidate detections to consider. Defaults to 3000. + max_detections (int, optional): Maximum number of detections to return after non-max suppression. Defaults to 300. + return_image_dims (bool, optional): Whether to return the dimensions of the input image(s). Defaults to False. + **kwargs: Additional keyword arguments. + + Returns: + List[List[dict]]: Each list contains dictionaries of segmented instances for a given image. Each dictionary contains: + - x, y: Center coordinates of the instance. + - width, height: Width and height of the bounding box around the instance. + - class: Name of the detected class. + - confidence: Confidence score of the detection. + - points: List of points describing the segmented mask's boundary. + - class_id: ID corresponding to the detected class. + If `return_image_dims` is True, the function returns a tuple where the first element is the list of detections and the + second element is the list of image dimensions. + + Notes: + - The function supports processing multiple images in a batch. + - If an input list of images is provided, the function returns a list of lists, + where each inner list corresponds to the detections for a specific image. + - The function internally uses an ONNX model for inference. + """ + return super().infer( + image, + class_agnostic_nms=class_agnostic_nms, + confidence=confidence, + iou_threshold=iou_threshold, + max_candidates=max_candidates, + max_detections=max_detections, + return_image_dims=return_image_dims, + **kwargs, + ) + + def preprocess( + self, image: Any, **kwargs + ) -> Tuple[np.ndarray, PreprocessReturnMetadata]: + if isinstance(image, list): + imgs_with_dims = [self.preproc_image(i) for i in image] + imgs, img_dims = zip(*imgs_with_dims) + img_in = np.concatenate(imgs, axis=0) + unwrap = False + else: + img_in, img_dims = self.preproc_image(image) + img_dims = [img_dims] + unwrap = True + + # IN BGR order (for some reason) + mean = (103.94, 116.78, 123.68) + std = (57.38, 57.12, 58.40) + + img_in = img_in.astype(np.float32) + + # Our channels are RGB, so apply mean and std accordingly + img_in[:, 0, :, :] = (img_in[:, 0, :, :] - mean[2]) / std[2] + img_in[:, 1, :, :] = (img_in[:, 1, :, :] - mean[1]) / std[1] + img_in[:, 2, :, :] = (img_in[:, 2, :, :] - mean[0]) / std[0] + + return img_in, PreprocessReturnMetadata( + { + "img_dims": img_dims, + "im_shape": img_in.shape, + } + ) + + def predict( + self, img_in: np.ndarray, **kwargs + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + return self.onnx_session.run(None, {self.input_name: img_in}) + + def postprocess( + self, + predictions: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray], + preprocess_return_metadata: PreprocessReturnMetadata, + **kwargs, + ) -> List[InstanceSegmentationInferenceResponse]: + loc_data = np.float32(predictions[0]) + conf_data = np.float32(predictions[1]) + mask_data = np.float32(predictions[2]) + prior_data = np.float32(predictions[3]) + proto_data = np.float32(predictions[4]) + + batch_size = loc_data.shape[0] + num_priors = prior_data.shape[0] + + boxes = np.zeros((batch_size, num_priors, 4)) + for batch_idx in range(batch_size): + boxes[batch_idx, :, :] = self.decode_predicted_bboxes( + loc_data[batch_idx], prior_data + ) + + conf_preds = np.reshape( + conf_data, (batch_size, num_priors, self.num_classes + 1) + ) + class_confs = conf_preds[:, :, 1:] # remove background class + box_confs = np.expand_dims( + np.max(class_confs, axis=2), 2 + ) # get max conf for each box + + predictions = np.concatenate((boxes, box_confs, class_confs, mask_data), axis=2) + + img_in_shape = preprocess_return_metadata["im_shape"] + predictions[:, :, 0] *= img_in_shape[2] + predictions[:, :, 1] *= img_in_shape[3] + predictions[:, :, 2] *= img_in_shape[2] + predictions[:, :, 3] *= img_in_shape[3] + predictions = w_np_non_max_suppression( + predictions, + conf_thresh=kwargs["confidence"], + iou_thresh=kwargs["iou_threshold"], + class_agnostic=kwargs["class_agnostic_nms"], + max_detections=kwargs["max_detections"], + max_candidate_detections=kwargs["max_candidates"], + num_masks=32, + box_format="xyxy", + ) + predictions = np.array(predictions) + batch_preds = [] + if predictions.shape != (1, 0): + for batch_idx, img_dim in enumerate(preprocess_return_metadata["img_dims"]): + boxes = predictions[batch_idx, :, :4] + scores = predictions[batch_idx, :, 4] + classes = predictions[batch_idx, :, 6] + masks = predictions[batch_idx, :, 7:] + proto = proto_data[batch_idx] + decoded_masks = self.decode_masks(boxes, masks, proto, img_in_shape[2:]) + polys = masks2poly(decoded_masks) + infer_shape = (self.img_size_w, self.img_size_h) + boxes = post_process_bboxes( + [boxes], infer_shape, [img_dim], self.preproc, self.resize_method + )[0] + polys = post_process_polygons( + img_in_shape[2:], + polys, + img_dim, + self.preproc, + resize_method=self.resize_method, + ) + preds = [] + for box, poly, score, cls in zip(boxes, polys, scores, classes): + confidence = float(score) + class_name = self.class_names[int(cls)] + points = [{"x": round(x, 1), "y": round(y, 1)} for (x, y) in poly] + pred = { + "x": round((box[2] + box[0]) / 2, 1), + "y": round((box[3] + box[1]) / 2, 1), + "width": int(box[2] - box[0]), + "height": int(box[3] - box[1]), + "class": class_name, + "confidence": round(confidence, 3), + "points": points, + "class_id": int(cls), + } + preds.append(pred) + batch_preds.append(preds) + else: + batch_preds.append([]) + img_dims = preprocess_return_metadata["img_dims"] + responses = self.make_response(batch_preds, img_dims, **kwargs) + if kwargs["return_image_dims"]: + return responses, preprocess_return_metadata["img_dims"] + else: + return responses + + def make_response( + self, + predictions: List[List[dict]], + img_dims: List[Tuple[int, int]], + class_filter: List[str] = None, + **kwargs, + ) -> List[InstanceSegmentationInferenceResponse]: + """ + Constructs a list of InstanceSegmentationInferenceResponse objects based on the provided predictions + and image dimensions, optionally filtering by class name. + + Args: + predictions (List[List[dict]]): A list containing batch predictions, where each inner list contains + dictionaries of segmented instances for a given image. + img_dims (List[Tuple[int, int]]): List of tuples specifying the dimensions of each image in the format + (height, width). + class_filter (List[str], optional): A list of class names to filter the predictions by. If not provided, + all predictions are included. + + Returns: + List[InstanceSegmentationInferenceResponse]: A list of response objects, each containing the filtered + predictions and corresponding image dimensions for a given image. + + Examples: + >>> predictions = [[{"class_name": "cat", ...}, {"class_name": "dog", ...}], ...] + >>> img_dims = [(300, 400), ...] + >>> responses = make_response(predictions, img_dims, class_filter=["cat"]) + >>> len(responses[0].predictions) # Only predictions with "cat" class are included + 1 + """ + responses = [ + InstanceSegmentationInferenceResponse( + predictions=[ + InstanceSegmentationPrediction(**p) + for p in batch_pred + if not class_filter or p["class_name"] in class_filter + ], + image=InferenceResponseImage( + width=img_dims[i][1], height=img_dims[i][0] + ), + ) + for i, batch_pred in enumerate(predictions) + ] + return responses + + def decode_masks(self, boxes, masks, proto, img_dim): + """Decodes the masks from the given parameters. + + Args: + boxes (np.array): Bounding boxes. + masks (np.array): Masks. + proto (np.array): Proto data. + img_dim (tuple): Image dimensions. + + Returns: + np.array: Decoded masks. + """ + ret_mask = np.matmul(proto, np.transpose(masks)) + ret_mask = 1 / (1 + np.exp(-ret_mask)) + w, h, _ = ret_mask.shape + gain = min(h / img_dim[0], w / img_dim[1]) # gain = old / new + pad = (w - img_dim[1] * gain) / 2, (h - img_dim[0] * gain) / 2 # wh padding + top, left = int(pad[1]), int(pad[0]) # y, x + bottom, right = int(h - pad[1]), int(w - pad[0]) + ret_mask = np.transpose(ret_mask, (2, 0, 1)) + ret_mask = ret_mask[:, top:bottom, left:right] + if len(ret_mask.shape) == 2: + ret_mask = np.expand_dims(ret_mask, axis=0) + ret_mask = ret_mask.transpose((1, 2, 0)) + ret_mask = cv2.resize(ret_mask, img_dim, interpolation=cv2.INTER_LINEAR) + if len(ret_mask.shape) == 2: + ret_mask = np.expand_dims(ret_mask, axis=2) + ret_mask = ret_mask.transpose((2, 0, 1)) + ret_mask = crop_mask(ret_mask, boxes) # CHW + ret_mask[ret_mask < 0.5] = 0 + + return ret_mask + + def decode_predicted_bboxes(self, loc, priors): + """Decode predicted bounding box coordinates using the scheme employed by Yolov2. + + Args: + loc (np.array): The predicted bounding boxes of size [num_priors, 4]. + priors (np.array): The prior box coordinates with size [num_priors, 4]. + + Returns: + np.array: A tensor of decoded relative coordinates in point form with size [num_priors, 4]. + """ + + variances = [0.1, 0.2] + + boxes = np.concatenate( + [ + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * np.exp(loc[:, 2:] * variances[1]), + ], + 1, + ) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + + return boxes diff --git a/inference/models/yolo_world/__init__.py b/inference/models/yolo_world/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba06ce9e9eac69fca15b5eb34d3cf67503bfc13 --- /dev/null +++ b/inference/models/yolo_world/__init__.py @@ -0,0 +1 @@ +from inference.models.yolo_world.yolo_world import YOLOWorld diff --git a/inference/models/yolo_world/__pycache__/__init__.cpython-310.pyc b/inference/models/yolo_world/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c620297a82fe6c6b359aad4e8bfff343e4d20d96 Binary files /dev/null and b/inference/models/yolo_world/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/yolo_world/__pycache__/yolo_world.cpython-310.pyc b/inference/models/yolo_world/__pycache__/yolo_world.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32b8558e7fb59a2a9a3f3dac9672ef2d18c88742 Binary files /dev/null and b/inference/models/yolo_world/__pycache__/yolo_world.cpython-310.pyc differ diff --git a/inference/models/yolo_world/yolo_world.py b/inference/models/yolo_world/yolo_world.py new file mode 100644 index 0000000000000000000000000000000000000000..95452596850ac89fa8585035e119fc92e12d5cbb --- /dev/null +++ b/inference/models/yolo_world/yolo_world.py @@ -0,0 +1,143 @@ +from time import perf_counter +from typing import Any + +from ultralytics import YOLO + +from inference.core.cache import cache +from inference.core.entities.requests.yolo_world import YOLOWorldInferenceRequest +from inference.core.entities.responses.inference import ( + InferenceResponseImage, + ObjectDetectionInferenceResponse, + ObjectDetectionPrediction, +) +from inference.core.models.defaults import DEFAULT_CONFIDENCE +from inference.core.models.roboflow import RoboflowCoreModel +from inference.core.utils.hash import get_string_list_hash +from inference.core.utils.image_utils import load_image_rgb + + +class YOLOWorld(RoboflowCoreModel): + """GroundingDINO class for zero-shot object detection. + + Attributes: + model: The GroundingDINO model. + """ + + def __init__(self, *args, model_id="yolo_world/l", **kwargs): + """Initializes the YOLO-World model. + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + + super().__init__(*args, model_id=model_id, **kwargs) + + self.model = YOLO(self.cache_file("yolo-world.pt")) + self.class_names = None + + def preproc_image(self, image: Any): + """Preprocesses an image. + + Args: + image (InferenceRequestImage): The image to preprocess. + + Returns: + np.array: The preprocessed image. + """ + np_image = load_image_rgb(image) + return np_image[:, :, ::-1] + + def infer_from_request( + self, + request: YOLOWorldInferenceRequest, + ) -> ObjectDetectionInferenceResponse: + """ + Perform inference based on the details provided in the request, and return the associated responses. + """ + result = self.infer(**request.dict()) + return result + + def infer( + self, + image: Any = None, + text: list = None, + confidence: float = DEFAULT_CONFIDENCE, + **kwargs, + ): + """ + Run inference on a provided image. + + Args: + request (CVInferenceRequest): The inference request. + class_filter (Optional[List[str]]): A list of class names to filter, if provided. + + Returns: + GroundingDINOInferenceRequest: The inference response. + """ + t1 = perf_counter() + image = self.preproc_image(image) + img_dims = image.shape + + if text is not None and text != self.class_names: + self.set_classes(text) + if self.class_names is None: + raise ValueError( + "Class names not set and not provided in the request. Must set class names before inference or provide them via the argument `text`." + ) + results = self.model.predict( + image, + conf=confidence, + verbose=False, + )[0] + + t2 = perf_counter() - t1 + + predictions = [] + for i, box in enumerate(results.boxes): + x, y, w, h = box.xywh.tolist()[0] + class_id = int(box.cls) + predictions.append( + ObjectDetectionPrediction( + **{ + "x": x, + "y": y, + "width": w, + "height": h, + "confidence": float(box.conf), + "class": self.class_names[class_id], + "class_id": class_id, + } + ) + ) + + responses = ObjectDetectionInferenceResponse( + predictions=predictions, + image=InferenceResponseImage(width=img_dims[1], height=img_dims[0]), + time=t2, + ) + return responses + + def set_classes(self, text: list): + """Set the class names for the model. + + Args: + text (list): The class names. + """ + text_hash = get_string_list_hash(text) + cached_embeddings = cache.get_numpy(text_hash) + if cached_embeddings is not None: + self.model.model.txt_feats = cached_embeddings + self.model.model.model[-1].nc = len(text) + else: + self.model.set_classes(text) + cache.set_numpy(text_hash, self.model.model.txt_feats, expire=300) + self.class_names = text + + def get_infer_bucket_file_list(self) -> list: + """Get the list of required files for inference. + + Returns: + list: A list of required files for inference, e.g., ["model.pt"]. + """ + return ["yolo-world.pt"] diff --git a/inference/models/yolonas/__init__.py b/inference/models/yolonas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3dbb6f334f9cdaba3d28c1fb8cbdf33e93aaa40d --- /dev/null +++ b/inference/models/yolonas/__init__.py @@ -0,0 +1 @@ +from inference.models.yolonas.yolonas_object_detection import YOLONASObjectDetection diff --git a/inference/models/yolonas/__pycache__/__init__.cpython-310.pyc b/inference/models/yolonas/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c13c3f769710298c39adfcf8f698edca227775f Binary files /dev/null and b/inference/models/yolonas/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/yolonas/__pycache__/yolonas_object_detection.cpython-310.pyc b/inference/models/yolonas/__pycache__/yolonas_object_detection.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e0c2e8c107f263cd00fcbc1fd8e52397ee424e6 Binary files /dev/null and b/inference/models/yolonas/__pycache__/yolonas_object_detection.cpython-310.pyc differ diff --git a/inference/models/yolonas/yolonas_object_detection.py b/inference/models/yolonas/yolonas_object_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..1666dbcbe45c9d14ab6211d8f4a2ee5022b0632d --- /dev/null +++ b/inference/models/yolonas/yolonas_object_detection.py @@ -0,0 +1,36 @@ +from typing import Tuple + +import numpy as np + +from inference.core.models.object_detection_base import ( + ObjectDetectionBaseOnnxRoboflowInferenceModel, +) + + +class YOLONASObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel): + box_format = "xyxy" + + @property + def weights_file(self) -> str: + """Gets the weights file for the YOLO-NAS model. + + Returns: + str: Path to the ONNX weights file. + """ + return "weights.onnx" + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]: + """Performs object detection on the given image using the ONNX session. + + Args: + img_in (np.ndarray): Input image as a NumPy array. + + Returns: + Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, and class confidence scores. + """ + predictions = self.onnx_session.run(None, {self.input_name: img_in}) + boxes = predictions[0] + class_confs = predictions[1] + confs = np.expand_dims(np.max(class_confs, axis=2), axis=2) + predictions = np.concatenate([boxes, confs, class_confs], axis=2) + return (predictions,) diff --git a/inference/models/yolov5/__init__.py b/inference/models/yolov5/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15da9418e92465c0b981e36cc2f1525cde8b8bca --- /dev/null +++ b/inference/models/yolov5/__init__.py @@ -0,0 +1,4 @@ +from inference.models.yolov5.yolov5_instance_segmentation import ( + YOLOv5InstanceSegmentation, +) +from inference.models.yolov5.yolov5_object_detection import YOLOv5ObjectDetection diff --git a/inference/models/yolov5/__pycache__/__init__.cpython-310.pyc b/inference/models/yolov5/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2610e4092f434ecba2a55238ad5b43baf6394c03 Binary files /dev/null and b/inference/models/yolov5/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/yolov5/__pycache__/yolov5_instance_segmentation.cpython-310.pyc b/inference/models/yolov5/__pycache__/yolov5_instance_segmentation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23e45193ad506297f111f314cb1a7bd18f608c29 Binary files /dev/null and b/inference/models/yolov5/__pycache__/yolov5_instance_segmentation.cpython-310.pyc differ diff --git a/inference/models/yolov5/__pycache__/yolov5_object_detection.cpython-310.pyc b/inference/models/yolov5/__pycache__/yolov5_object_detection.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93e743d524c9806d9802b010c4097b31bea5f06e Binary files /dev/null and b/inference/models/yolov5/__pycache__/yolov5_object_detection.cpython-310.pyc differ diff --git a/inference/models/yolov5/yolov5_instance_segmentation.py b/inference/models/yolov5/yolov5_instance_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..0280d8935eff1061410e179c972c5b44f7441a62 --- /dev/null +++ b/inference/models/yolov5/yolov5_instance_segmentation.py @@ -0,0 +1,39 @@ +from typing import List, Tuple + +import numpy as np + +from inference.core.models.instance_segmentation_base import ( + InstanceSegmentationBaseOnnxRoboflowInferenceModel, +) + + +class YOLOv5InstanceSegmentation(InstanceSegmentationBaseOnnxRoboflowInferenceModel): + """YOLOv5 Instance Segmentation ONNX Inference Model. + + This class is responsible for performing instance segmentation using the YOLOv5 model + with ONNX runtime. + + Attributes: + weights_file (str): Path to the ONNX weights file. + """ + + @property + def weights_file(self) -> str: + """Gets the weights file for the YOLOv5 model. + + Returns: + str: Path to the ONNX weights file. + """ + return "yolov5s_weights.onnx" + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]: + """Performs inference on the given image using the ONNX session. + + Args: + img_in (np.ndarray): Input image as a NumPy array. + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions. + """ + predictions = self.onnx_session.run(None, {self.input_name: img_in}) + return predictions[0], predictions[1] diff --git a/inference/models/yolov5/yolov5_object_detection.py b/inference/models/yolov5/yolov5_object_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..382eccd98290cfd4ca983246228601e5b710f9d6 --- /dev/null +++ b/inference/models/yolov5/yolov5_object_detection.py @@ -0,0 +1,39 @@ +from typing import Tuple + +import numpy as np + +from inference.core.models.object_detection_base import ( + ObjectDetectionBaseOnnxRoboflowInferenceModel, +) + + +class YOLOv5ObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel): + """Roboflow ONNX Object detection model (Implements an object detection specific infer method). + + This class is responsible for performing object detection using the YOLOv5 model + with ONNX runtime. + + Attributes: + weights_file (str): Path to the ONNX weights file. + """ + + @property + def weights_file(self) -> str: + """Gets the weights file for the YOLOv5 model. + + Returns: + str: Path to the ONNX weights file. + """ + return "yolov5s_weights.onnx" + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]: + """Performs object detection on the given image using the ONNX session. + + Args: + img_in (np.ndarray): Input image as a NumPy array. + + Returns: + Tuple[np.ndarray]: NumPy array representing the predictions. + """ + predictions = self.onnx_session.run(None, {self.input_name: img_in})[0] + return (predictions,) diff --git a/inference/models/yolov7/__init__.py b/inference/models/yolov7/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c89bfc96e60377fc2cf08c0eba443b7e7404b253 --- /dev/null +++ b/inference/models/yolov7/__init__.py @@ -0,0 +1,3 @@ +from inference.models.yolov7.yolov7_instance_segmentation import ( + YOLOv7InstanceSegmentation, +) diff --git a/inference/models/yolov7/__pycache__/__init__.cpython-310.pyc b/inference/models/yolov7/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f61c780db64b493e08d0ba9474ef51ca663209f4 Binary files /dev/null and b/inference/models/yolov7/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/yolov7/__pycache__/yolov7_instance_segmentation.cpython-310.pyc b/inference/models/yolov7/__pycache__/yolov7_instance_segmentation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4f0bf201d3c6ec081c4ef2c03a2de837d6d2e82 Binary files /dev/null and b/inference/models/yolov7/__pycache__/yolov7_instance_segmentation.cpython-310.pyc differ diff --git a/inference/models/yolov7/yolov7_instance_segmentation.py b/inference/models/yolov7/yolov7_instance_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..4acda66626fc36806793989736f978c4b91e18ab --- /dev/null +++ b/inference/models/yolov7/yolov7_instance_segmentation.py @@ -0,0 +1,32 @@ +from typing import List, Tuple + +import numpy as np + +from inference.core.models.instance_segmentation_base import ( + InstanceSegmentationBaseOnnxRoboflowInferenceModel, +) + + +class YOLOv7InstanceSegmentation(InstanceSegmentationBaseOnnxRoboflowInferenceModel): + """YOLOv7 Instance Segmentation ONNX Inference Model. + + This class is responsible for performing instance segmentation using the YOLOv7 model + with ONNX runtime. + + Methods: + predict: Performs inference on the given image using the ONNX session. + """ + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]: + """Performs inference on the given image using the ONNX session. + + Args: + img_in (np.ndarray): Input image as a NumPy array. + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions and protos. + """ + predictions = self.onnx_session.run(None, {self.input_name: img_in}) + protos = predictions[4] + predictions = predictions[0] + return predictions, protos diff --git a/inference/models/yolov8/__init__.py b/inference/models/yolov8/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49ae68fd122179ff2dce981fc7d4a61b7daf3021 --- /dev/null +++ b/inference/models/yolov8/__init__.py @@ -0,0 +1,6 @@ +from inference.models.yolov8.yolov8_classification import YOLOv8Classification +from inference.models.yolov8.yolov8_instance_segmentation import ( + YOLOv8InstanceSegmentation, +) +from inference.models.yolov8.yolov8_keypoints_detection import YOLOv8KeypointsDetection +from inference.models.yolov8.yolov8_object_detection import YOLOv8ObjectDetection diff --git a/inference/models/yolov8/__pycache__/__init__.cpython-310.pyc b/inference/models/yolov8/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bdf281992e89ef6d2af18cc1e7394122579ef1f Binary files /dev/null and b/inference/models/yolov8/__pycache__/__init__.cpython-310.pyc differ diff --git a/inference/models/yolov8/__pycache__/yolov8_classification.cpython-310.pyc b/inference/models/yolov8/__pycache__/yolov8_classification.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebfcef2dd9b39e11599b9cfa16da80cc61986e2f Binary files /dev/null and b/inference/models/yolov8/__pycache__/yolov8_classification.cpython-310.pyc differ diff --git a/inference/models/yolov8/__pycache__/yolov8_instance_segmentation.cpython-310.pyc b/inference/models/yolov8/__pycache__/yolov8_instance_segmentation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43514042dfb0396eb148429f2f38eaa75b42d046 Binary files /dev/null and b/inference/models/yolov8/__pycache__/yolov8_instance_segmentation.cpython-310.pyc differ diff --git a/inference/models/yolov8/__pycache__/yolov8_keypoints_detection.cpython-310.pyc b/inference/models/yolov8/__pycache__/yolov8_keypoints_detection.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53251fb2749fbaad49af85f5656efd76d5b72bc6 Binary files /dev/null and b/inference/models/yolov8/__pycache__/yolov8_keypoints_detection.cpython-310.pyc differ diff --git a/inference/models/yolov8/__pycache__/yolov8_object_detection.cpython-310.pyc b/inference/models/yolov8/__pycache__/yolov8_object_detection.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3556f215e820c894bb95bc619cf5a272c0e2604 Binary files /dev/null and b/inference/models/yolov8/__pycache__/yolov8_object_detection.cpython-310.pyc differ diff --git a/inference/models/yolov8/yolov8_classification.py b/inference/models/yolov8/yolov8_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..74ca93fa80eae3adc2efa8e1b8ce9178c2c05504 --- /dev/null +++ b/inference/models/yolov8/yolov8_classification.py @@ -0,0 +1,13 @@ +from inference.core.models.classification_base import ( + ClassificationBaseOnnxRoboflowInferenceModel, +) + + +class YOLOv8Classification(ClassificationBaseOnnxRoboflowInferenceModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.multiclass = self.environment.get("MULTICLASS", False) + + @property + def weights_file(self) -> str: + return "weights.onnx" diff --git a/inference/models/yolov8/yolov8_instance_segmentation.py b/inference/models/yolov8/yolov8_instance_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..01b5c9370c28fa0537da79c17386ef846c1201bb --- /dev/null +++ b/inference/models/yolov8/yolov8_instance_segmentation.py @@ -0,0 +1,50 @@ +from typing import List, Tuple + +import numpy as np + +from inference.core.models.instance_segmentation_base import ( + InstanceSegmentationBaseOnnxRoboflowInferenceModel, +) + + +class YOLOv8InstanceSegmentation(InstanceSegmentationBaseOnnxRoboflowInferenceModel): + """YOLOv8 Instance Segmentation ONNX Inference Model. + + This class is responsible for performing instance segmentation using the YOLOv8 model + with ONNX runtime. + + Attributes: + weights_file (str): Path to the ONNX weights file. + + Methods: + predict: Performs inference on the given image using the ONNX session. + """ + + @property + def weights_file(self) -> str: + """Gets the weights file for the YOLOv8 model. + + Returns: + str: Path to the ONNX weights file. + """ + return "weights.onnx" + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, np.ndarray]: + """Performs inference on the given image using the ONNX session. + + Args: + img_in (np.ndarray): Input image as a NumPy array. + + Returns: + Tuple[np.ndarray, np.ndarray]: Tuple containing two NumPy arrays representing the predictions and protos. The predictions include boxes, confidence scores, class confidence scores, and masks. + """ + predictions = self.onnx_session.run(None, {self.input_name: img_in}) + protos = predictions[1] + predictions = predictions[0] + predictions = predictions.transpose(0, 2, 1) + boxes = predictions[:, :, :4] + class_confs = predictions[:, :, 4:-32] + confs = np.expand_dims(np.max(class_confs, axis=2), axis=2) + masks = predictions[:, :, -32:] + predictions = np.concatenate([boxes, confs, class_confs, masks], axis=2) + return predictions, protos diff --git a/inference/models/yolov8/yolov8_keypoints_detection.py b/inference/models/yolov8/yolov8_keypoints_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4daf912487500808c63daae066e95b02a6d342 --- /dev/null +++ b/inference/models/yolov8/yolov8_keypoints_detection.py @@ -0,0 +1,59 @@ +from typing import Tuple + +import numpy as np + +from inference.core.exceptions import ModelArtefactError +from inference.core.models.keypoints_detection_base import ( + KeypointsDetectionBaseOnnxRoboflowInferenceModel, +) +from inference.core.models.utils.keypoints import superset_keypoints_count + + +class YOLOv8KeypointsDetection(KeypointsDetectionBaseOnnxRoboflowInferenceModel): + """Roboflow ONNX keypoints detection model (Implements an object detection specific infer method). + + This class is responsible for performing keypoints detection using the YOLOv8 model + with ONNX runtime. + + Attributes: + weights_file (str): Path to the ONNX weights file. + + Methods: + predict: Performs object detection on the given image using the ONNX session. + """ + + @property + def weights_file(self) -> str: + """Gets the weights file for the YOLOv8 model. + + Returns: + str: Path to the ONNX weights file. + """ + return "weights.onnx" + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray, ...]: + """Performs object detection on the given image using the ONNX session. + + Args: + img_in (np.ndarray): Input image as a NumPy array. + + Returns: + Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, and class confidence scores. + """ + predictions = self.onnx_session.run(None, {self.input_name: img_in})[0] + predictions = predictions.transpose(0, 2, 1) + boxes = predictions[:, :, :4] + number_of_classes = len(self.get_class_names) + class_confs = predictions[:, :, 4 : 4 + number_of_classes] + keypoints_detections = predictions[:, :, 4 + number_of_classes :] + confs = np.expand_dims(np.max(class_confs, axis=2), axis=2) + bboxes_predictions = np.concatenate( + [boxes, confs, class_confs, keypoints_detections], axis=2 + ) + return (bboxes_predictions,) + + def keypoints_count(self) -> int: + """Returns the number of keypoints in the model.""" + if self.keypoints_metadata is None: + raise ModelArtefactError("Keypoints metadata not available.") + return superset_keypoints_count(self.keypoints_metadata) diff --git a/inference/models/yolov8/yolov8_object_detection.py b/inference/models/yolov8/yolov8_object_detection.py new file mode 100644 index 0000000000000000000000000000000000000000..ead43878f56fba586e4d36c443ce7dbd24b90fd3 --- /dev/null +++ b/inference/models/yolov8/yolov8_object_detection.py @@ -0,0 +1,47 @@ +from typing import Tuple + +import numpy as np + +from inference.core.models.object_detection_base import ( + ObjectDetectionBaseOnnxRoboflowInferenceModel, +) + + +class YOLOv8ObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel): + """Roboflow ONNX Object detection model (Implements an object detection specific infer method). + + This class is responsible for performing object detection using the YOLOv8 model + with ONNX runtime. + + Attributes: + weights_file (str): Path to the ONNX weights file. + + Methods: + predict: Performs object detection on the given image using the ONNX session. + """ + + @property + def weights_file(self) -> str: + """Gets the weights file for the YOLOv8 model. + + Returns: + str: Path to the ONNX weights file. + """ + return "weights.onnx" + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]: + """Performs object detection on the given image using the ONNX session. + + Args: + img_in (np.ndarray): Input image as a NumPy array. + + Returns: + Tuple[np.ndarray]: NumPy array representing the predictions, including boxes, confidence scores, and class confidence scores. + """ + predictions = self.onnx_session.run(None, {self.input_name: img_in})[0] + predictions = predictions.transpose(0, 2, 1) + boxes = predictions[:, :, :4] + class_confs = predictions[:, :, 4:] + confs = np.expand_dims(np.max(class_confs, axis=2), axis=2) + predictions = np.concatenate([boxes, confs, class_confs], axis=2) + return (predictions,)