import hashlib from dataclasses import asdict from typing import Any, Dict, List, Optional from inference.core import logger from inference.core.active_learning.entities import ( ActiveLearningConfiguration, RoboflowProjectMetadata, SamplingMethod, ) from inference.core.active_learning.samplers.close_to_threshold import ( initialize_close_to_threshold_sampling, ) from inference.core.active_learning.samplers.contains_classes import ( initialize_classes_based_sampling, ) from inference.core.active_learning.samplers.number_of_detections import ( initialize_detections_number_based_sampling, ) from inference.core.active_learning.samplers.random import initialize_random_sampling from inference.core.cache.base import BaseCache from inference.core.exceptions import ( ActiveLearningConfigurationDecodingError, ActiveLearningConfigurationError, RoboflowAPINotAuthorizedError, RoboflowAPINotNotFoundError, ) from inference.core.roboflow_api import ( get_roboflow_active_learning_configuration, get_roboflow_dataset_type, get_roboflow_workspace, ) from inference.core.utils.roboflow import get_model_id_chunks TYPE2SAMPLING_INITIALIZERS = { "random": initialize_random_sampling, "close_to_threshold": initialize_close_to_threshold_sampling, "classes_based": initialize_classes_based_sampling, "detections_number_based": initialize_detections_number_based_sampling, } ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE = 900 # 15 min def prepare_active_learning_configuration( api_key: str, model_id: str, cache: BaseCache, ) -> Optional[ActiveLearningConfiguration]: project_metadata = get_roboflow_project_metadata( api_key=api_key, model_id=model_id, cache=cache, ) if not project_metadata.active_learning_configuration.get("enabled", False): return None logger.info( f"Configuring active learning for workspace: {project_metadata.workspace_id}, " f"project: {project_metadata.dataset_id} of type: {project_metadata.dataset_type}. " f"AL configuration: {project_metadata.active_learning_configuration}" ) return initialise_active_learning_configuration( project_metadata=project_metadata, ) def prepare_active_learning_configuration_inplace( api_key: str, model_id: str, active_learning_configuration: Optional[dict], ) -> Optional[ActiveLearningConfiguration]: if ( active_learning_configuration is None or active_learning_configuration.get("enabled", False) is False ): return None dataset_id, version_id = get_model_id_chunks(model_id=model_id) workspace_id = get_roboflow_workspace(api_key=api_key) dataset_type = get_roboflow_dataset_type( api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id, ) project_metadata = RoboflowProjectMetadata( dataset_id=dataset_id, version_id=version_id, workspace_id=workspace_id, dataset_type=dataset_type, active_learning_configuration=active_learning_configuration, ) return initialise_active_learning_configuration( project_metadata=project_metadata, ) def get_roboflow_project_metadata( api_key: str, model_id: str, cache: BaseCache, ) -> RoboflowProjectMetadata: logger.info(f"Fetching active learning configuration.") config_cache_key = construct_cache_key_for_active_learning_config( api_key=api_key, model_id=model_id ) cached_config = cache.get(config_cache_key) if cached_config is not None: logger.info("Found Active Learning configuration in cache.") return parse_cached_roboflow_project_metadata(cached_config=cached_config) dataset_id, version_id = get_model_id_chunks(model_id=model_id) workspace_id = get_roboflow_workspace(api_key=api_key) dataset_type = get_roboflow_dataset_type( api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id, ) try: roboflow_api_configuration = get_roboflow_active_learning_configuration( api_key=api_key, workspace_id=workspace_id, dataset_id=dataset_id ) except (RoboflowAPINotAuthorizedError, RoboflowAPINotNotFoundError): # currently backend returns HTTP 404 if dataset does not exist # or workspace_id from api_key indicate that the owner is different, # so in the situation when we query for Universe dataset. # We want the owner of public dataset to be able to set AL configs # and use them, but not other people. At this point it's known # that HTTP 404 means not authorised (which will probably change # in future iteration of backend) - so on both NotAuth and NotFound # errors we assume that we simply cannot use AL with this model and # this api_key. roboflow_api_configuration = {"enabled": False} configuration = RoboflowProjectMetadata( dataset_id=dataset_id, version_id=version_id, workspace_id=workspace_id, dataset_type=dataset_type, active_learning_configuration=roboflow_api_configuration, ) cache.set( key=config_cache_key, value=asdict(configuration), expire=ACTIVE_LEARNING_CONFIG_CACHE_EXPIRE, ) return configuration def construct_cache_key_for_active_learning_config(api_key: str, model_id: str) -> str: dataset_id = model_id.split("/")[0] api_key_hash = hashlib.md5(api_key.encode("utf-8")).hexdigest() return f"active_learning:configurations:{api_key_hash}:{dataset_id}" def parse_cached_roboflow_project_metadata( cached_config: dict, ) -> RoboflowProjectMetadata: try: return RoboflowProjectMetadata(**cached_config) except Exception as error: raise ActiveLearningConfigurationDecodingError( f"Failed to initialise Active Learning configuration. Cause: {str(error)}" ) from error def initialise_active_learning_configuration( project_metadata: RoboflowProjectMetadata, ) -> ActiveLearningConfiguration: sampling_methods = initialize_sampling_methods( sampling_strategies_configs=project_metadata.active_learning_configuration[ "sampling_strategies" ], ) target_workspace_id = project_metadata.active_learning_configuration.get( "target_workspace", project_metadata.workspace_id ) target_dataset_id = project_metadata.active_learning_configuration.get( "target_project", project_metadata.dataset_id ) return ActiveLearningConfiguration.init( roboflow_api_configuration=project_metadata.active_learning_configuration, sampling_methods=sampling_methods, workspace_id=target_workspace_id, dataset_id=target_dataset_id, model_id=f"{project_metadata.dataset_id}/{project_metadata.version_id}", ) def initialize_sampling_methods( sampling_strategies_configs: List[Dict[str, Any]] ) -> List[SamplingMethod]: result = [] for sampling_strategy_config in sampling_strategies_configs: sampling_type = sampling_strategy_config["type"] if sampling_type not in TYPE2SAMPLING_INITIALIZERS: logger.warn( f"Could not identify sampling method `{sampling_type}` - skipping initialisation." ) continue initializer = TYPE2SAMPLING_INITIALIZERS[sampling_type] result.append(initializer(sampling_strategy_config)) names = set(m.name for m in result) if len(names) != len(result): raise ActiveLearningConfigurationError( "Detected duplication of Active Learning strategies names." ) return result