import time
from typing import Dict, Optional

from fastapi import BackgroundTasks

from inference.core import logger
from inference.core.active_learning.middlewares import ActiveLearningMiddleware
from inference.core.cache.base import BaseCache
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT
from inference.core.managers.base import ModelManager
from inference.core.registries.base import ModelRegistry
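
# Kwarg / request-attribute names that control Active Learning per request
# (consumed by `infer_from_request` in the managers below).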
ACTIVE_LEARNING_ELIGIBLE_PARAM = "active_learning_eligible"
DISABLE_ACTIVE_LEARNING_PARAM = "disable_active_learning"
BACKGROUND_TASKS_PARAM = "background_tasks"


class ActiveLearningManager(ModelManager):
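    """Model manager that registers predictions as Active Learning datapoints.

    After delegating inference to the parent `ModelManager`, eligible requests
    have their input image(s) and predictions forwarded to a per-model
    `ActiveLearningMiddleware`.
    """
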
    def __init__(
        self,
        model_registry: ModelRegistry,
        cache: BaseCache,
        middlewares: Optional[Dict[str, ActiveLearningMiddleware]] = None,
    ):
        super().__init__(model_registry=model_registry)
        self._cache = cache
        self._middlewares = middlewares if middlewares is not None else {}

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
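        """Run inference and, when the request is eligible, register the datapoint.

        Registration is skipped if the request is not flagged as eligible for
        Active Learning, has Active Learning disabled, or carries no API key.
        """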
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        self.register(prediction=prediction, model_id=model_id, request=request)
        return prediction

    def register(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
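        """Register a prediction with the Active Learning middleware.

        Any exception raised during registration is logged and suppressed so
        that it cannot break the inference response path.
        """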
        try:
            self.ensure_middleware_initialised(model_id=model_id, request=request)
            self.register_datapoint(
                prediction=prediction,
                model_id=model_id,
                request=request,
            )
        except Exception as error:
            # Error handling to be decided
            logger.warning(
                f"Error in datapoint registration for Active Learning. Details: {error}. "
                f"The error is suppressed so that the API keeps operating normally."
            )

    def ensure_middleware_initialised(
        self, model_id: str, request: InferenceRequest
    ) -> None:
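        """Lazily create and cache the `ActiveLearningMiddleware` for `model_id`."""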
        if model_id in self._middlewares:
            return None
        start = time.perf_counter()
        logger.debug(f"Initialising AL middleware for {model_id}")
        self._middlewares[model_id] = ActiveLearningMiddleware.init(
            api_key=request.api_key,
            model_id=model_id,
            cache=self._cache,
        )
        end = time.perf_counter()
        logger.debug(f"Middleware init latency: {(end - start) * 1000} ms")

    def register_datapoint(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
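        """Normalise image inputs and predictions to lists and register them
        as a single batch with the model's middleware."""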
        start = time.perf_counter()
        inference_inputs = getattr(request, "image", None)
        if inference_inputs is None:
            logger.warning(
                "Could not register datapoint, as inference input has no `image` field."
            )
            return None
        if not isinstance(inference_inputs, list):
            inference_inputs = [inference_inputs]
        if not isinstance(prediction, list):
            results_dicts = [prediction.dict(by_alias=True, exclude={"visualization"})]
        else:
            results_dicts = [
                e.dict(by_alias=True, exclude={"visualization"}) for e in prediction
            ]
        prediction_type = self.get_task_type(model_id=model_id)
        disable_preproc_auto_orient = (
            getattr(request, "disable_preproc_auto_orient", False)
            or DISABLE_PREPROC_AUTO_ORIENT
        )
        self._middlewares[model_id].register_batch(
            inference_inputs=inference_inputs,
            predictions=results_dicts,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        end = time.perf_counter()
        logger.debug(f"Registration: {(end - start) * 1000} ms")


class BackgroundTaskActiveLearningManager(ActiveLearningManager):
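    """`ActiveLearningManager` that defers datapoint registration to FastAPI
    background tasks.

    If no `background_tasks` kwarg is supplied, registration falls back to
    running sequentially within the request.
    """
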
    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
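        """Run inference, then schedule datapoint registration as a background
        task when `background_tasks` is provided."""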
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False  # disabling AL in super-classes
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        if BACKGROUND_TASKS_PARAM not in kwargs:
            logger.warning(
                "BackgroundTaskActiveLearningManager misused: the `background_tasks` argument "
                "was not provided, so Active Learning registration will run sequentially."
            )
            self.register(prediction=prediction, model_id=model_id, request=request)
        else:
            background_tasks: BackgroundTasks = kwargs[BACKGROUND_TASKS_PARAM]
            background_tasks.add_task(
                self.register, prediction=prediction, model_id=model_id, request=request
            )
        return prediction
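

# A minimal usage sketch (hypothetical wiring; `registry`, `cache`, `request`,
# and `background_tasks` are assumed to be constructed elsewhere):
#
#     manager = BackgroundTaskActiveLearningManager(
#         model_registry=registry, cache=cache
#     )
#     response = await manager.infer_from_request(
#         model_id="some-project/1",
#         request=request,
#         active_learning_eligible=True,
#         background_tasks=background_tasks,  # fastapi.BackgroundTasks
#     )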