import time
from typing import Dict, Optional

from fastapi import BackgroundTasks

from inference.core import logger
from inference.core.active_learning.middlewares import ActiveLearningMiddleware
from inference.core.cache.base import BaseCache
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT
from inference.core.managers.base import ModelManager
from inference.core.registries.base import ModelRegistry

# kwargs / request attribute names used to opt predictions in or out of AL.
ACTIVE_LEARNING_ELIGIBLE_PARAM = "active_learning_eligible"
DISABLE_ACTIVE_LEARNING_PARAM = "disable_active_learning"
BACKGROUND_TASKS_PARAM = "background_tasks"


class ActiveLearningManager(ModelManager):
    """Model manager that registers eligible predictions for Active Learning.

    Registration is best-effort: any failure is logged and suppressed so the
    inference response is always returned to the caller.
    """

    def __init__(
        self,
        model_registry: ModelRegistry,
        cache: BaseCache,
        middlewares: Optional[Dict[str, ActiveLearningMiddleware]] = None,
    ):
        super().__init__(model_registry=model_registry)
        self._cache = cache
        # Per-model AL middlewares, created lazily on first eligible request
        # (see ensure_middleware_initialised).
        self._middlewares = middlewares if middlewares is not None else {}

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Run inference, then synchronously register the datapoint for AL.

        Registration is skipped when the request is not flagged eligible via
        kwargs, explicitly opted out on the request, or carries no API key.
        """
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        self.register(prediction=prediction, model_id=model_id, request=request)
        return prediction

    def register(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        """Best-effort AL registration of a single prediction.

        Any exception is logged and swallowed deliberately - AL must never
        break normal API operation.
        """
        try:
            self.ensure_middleware_initialised(model_id=model_id, request=request)
            self.register_datapoint(
                prediction=prediction,
                model_id=model_id,
                request=request,
            )
        except Exception as error:
            # Error handling to be decided
            logger.warning(
                f"Error in datapoint registration for Active Learning. Details: {error}. "
                f"Error is suppressed in favour of normal operations of API."
            )

    def ensure_middleware_initialised(
        self, model_id: str, request: InferenceRequest
    ) -> None:
        """Create the AL middleware for `model_id` on first use (no-op after)."""
        if model_id in self._middlewares:
            return None
        start = time.perf_counter()
        logger.debug(f"Initialising AL middleware for {model_id}")
        self._middlewares[model_id] = ActiveLearningMiddleware.init(
            api_key=request.api_key,
            model_id=model_id,
            cache=self._cache,
        )
        end = time.perf_counter()
        logger.debug(f"Middleware init latency: {(end - start) * 1000} ms")

    def register_datapoint(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        """Normalise inputs/predictions to lists and hand them to the middleware."""
        start = time.perf_counter()
        inference_inputs = getattr(request, "image", None)
        if inference_inputs is None:
            logger.warning(
                "Could not register datapoint, as inference input has no `image` field."
            )
            return None
        # isinstance(...) is the idiomatic, behaviourally-equivalent form of
        # the original issubclass(type(...), list) checks.
        if not isinstance(inference_inputs, list):
            inference_inputs = [inference_inputs]
        if not isinstance(prediction, list):
            results_dicts = [prediction.dict(by_alias=True, exclude={"visualization"})]
        else:
            results_dicts = [
                e.dict(by_alias=True, exclude={"visualization"}) for e in prediction
            ]
        prediction_type = self.get_task_type(model_id=model_id)
        # The env-level DISABLE_PREPROC_AUTO_ORIENT flag force-disables auto
        # orientation regardless of the per-request setting.
        disable_preproc_auto_orient = (
            getattr(request, "disable_preproc_auto_orient", False)
            or DISABLE_PREPROC_AUTO_ORIENT
        )
        self._middlewares[model_id].register_batch(
            inference_inputs=inference_inputs,
            predictions=results_dicts,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        end = time.perf_counter()
        logger.debug(f"Registration: {(end - start) * 1000} ms")


class BackgroundTaskActiveLearningManager(ActiveLearningManager):
    """AL manager that defers registration to FastAPI background tasks.

    When a `background_tasks` kwarg is provided, registration is scheduled to
    run after the response is sent; otherwise it falls back to the parent's
    sequential registration (with a warning).
    """

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False  # disabling AL in super-classes
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        if BACKGROUND_TASKS_PARAM not in kwargs:
            logger.warning(
                "BackgroundTaskActiveLearningManager used against rules - `background_tasks` argument not "
                "provided making Active Learning registration running sequentially."
            )
            self.register(prediction=prediction, model_id=model_id, request=request)
        else:
            background_tasks: BackgroundTasks = kwargs["background_tasks"]
            background_tasks.add_task(
                self.register, prediction=prediction, model_id=model_id, request=request
            )
        return prediction