import queue
from queue import Queue
from threading import Thread
from typing import Any, List, Optional

from inference.core import logger
from inference.core.active_learning.accounting import image_can_be_submitted_to_batch
from inference.core.active_learning.batching import generate_batch_name
from inference.core.active_learning.configuration import (
    prepare_active_learning_configuration,
    prepare_active_learning_configuration_inplace,
)
from inference.core.active_learning.core import (
    execute_datapoint_registration,
    execute_sampling,
)
from inference.core.active_learning.entities import (
    ActiveLearningConfiguration,
    Prediction,
    PredictionType,
)
from inference.core.cache.base import BaseCache
from inference.core.utils.image_utils import load_image

MAX_REGISTRATION_QUEUE_SIZE = 512


class NullActiveLearningMiddleware:
    """No-op middleware used when Active Learning is disabled."""

    def register_batch(
        self,
        inference_inputs: List[Any],
        predictions: List[Prediction],
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        pass

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        pass

    def start_registration_thread(self) -> None:
        pass

    def stop_registration_thread(self) -> None:
        pass

    def __enter__(self) -> "NullActiveLearningMiddleware":
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        pass


class ActiveLearningMiddleware:
    @classmethod
    def init(
        cls, api_key: str, model_id: str, cache: BaseCache
    ) -> "ActiveLearningMiddleware":
        configuration = prepare_active_learning_configuration(
            api_key=api_key,
            model_id=model_id,
            cache=cache,
        )
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
        )

    @classmethod
    def init_from_config(
        cls, api_key: str, model_id: str, cache: BaseCache, config: Optional[dict]
    ) -> "ActiveLearningMiddleware":
        configuration = prepare_active_learning_configuration_inplace(
            api_key=api_key,
            model_id=model_id,
            active_learning_configuration=config,
        )
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
        )

    def __init__(
        self,
        api_key: str,
        configuration: Optional[ActiveLearningConfiguration],
        cache: BaseCache,
    ):
        self._api_key = api_key
        self._configuration = configuration
        self._cache = cache

    def register_batch(
        self,
        inference_inputs: List[Any],
        predictions: List[Prediction],
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        for inference_input, prediction in zip(inference_inputs, predictions):
            self.register(
                inference_input=inference_input,
                prediction=prediction,
                prediction_type=prediction_type,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
            )

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        self._execute_registration(
            inference_input=inference_input,
            prediction=prediction,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )

    def _execute_registration(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        if self._configuration is None:
            return None
        image, is_bgr = load_image(
            value=inference_input,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        if not is_bgr:
            # Downstream sampling and registration expect BGR channel order.
            image = image[:, :, ::-1]
        matching_strategies = execute_sampling(
            image=image,
            prediction=prediction,
            prediction_type=prediction_type,
            sampling_methods=self._configuration.sampling_methods,
        )
        if len(matching_strategies) == 0:
            # No sampling strategy selected this datapoint - nothing to register.
            return None
        batch_name = generate_batch_name(configuration=self._configuration)
        if not image_can_be_submitted_to_batch(
            batch_name=batch_name,
            workspace_id=self._configuration.workspace_id,
            dataset_id=self._configuration.dataset_id,
            max_batch_images=self._configuration.max_batch_images,
            api_key=self._api_key,
        ):
            logger.debug("Limit on Active Learning batch size reached.")
            return None
        execute_datapoint_registration(
            cache=self._cache,
            matching_strategies=matching_strategies,
            image=image,
            prediction=prediction,
            prediction_type=prediction_type,
            configuration=self._configuration,
            api_key=self._api_key,
            batch_name=batch_name,
        )
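
# Usage sketch for the synchronous middleware (illustrative only: the api_key
# and model_id values are placeholders, and `cache` stands for any concrete
# BaseCache implementation available in your installation):
#
#     middleware = ActiveLearningMiddleware.init(
#         api_key="<YOUR-API-KEY>",
#         model_id="some-project/3",
#         cache=cache,
#     )
#     middleware.register(
#         inference_input="https://example.com/image.jpg",
#         prediction=prediction,  # raw prediction dict returned by the model
#         prediction_type="object-detection",
#     )
#
# Registration here happens inline, on the caller's thread; see the threaded
# variant below for non-blocking registration.
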
class ThreadingActiveLearningMiddleware(ActiveLearningMiddleware):
    @classmethod
    def init(
        cls,
        api_key: str,
        model_id: str,
        cache: BaseCache,
        max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
    ) -> "ThreadingActiveLearningMiddleware":
        configuration = prepare_active_learning_configuration(
            api_key=api_key,
            model_id=model_id,
            cache=cache,
        )
        task_queue = Queue(max_queue_size)
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
            task_queue=task_queue,
        )

    @classmethod
    def init_from_config(
        cls,
        api_key: str,
        model_id: str,
        cache: BaseCache,
        config: Optional[dict],
        max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
    ) -> "ThreadingActiveLearningMiddleware":
        configuration = prepare_active_learning_configuration_inplace(
            api_key=api_key,
            model_id=model_id,
            active_learning_configuration=config,
        )
        task_queue = Queue(max_queue_size)
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
            task_queue=task_queue,
        )

    def __init__(
        self,
        api_key: str,
        configuration: ActiveLearningConfiguration,
        cache: BaseCache,
        task_queue: Queue,
    ):
        super().__init__(api_key=api_key, configuration=configuration, cache=cache)
        self._task_queue = task_queue
        self._registration_thread: Optional[Thread] = None

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        logger.debug("Putting registration task into queue")
        try:
            self._task_queue.put_nowait(
                (
                    inference_input,
                    prediction,
                    prediction_type,
                    disable_preproc_auto_orient,
                )
            )
        except queue.Full:
            logger.warning(
                "Dropping datapoint registered in Active Learning due to insufficient "
                "processing capabilities."
            )
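
    # Shutdown protocol: `stop_registration_thread()` enqueues a `None` sentinel,
    # and `_consume_queue()` exits once `_consume_queue_task()` dequeues it. Because
    # the queue is FIFO, tasks enqueued before the sentinel are still processed, so
    # a graceful stop flushes pending registrations instead of discarding them.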

    def start_registration_thread(self) -> None:
        if self._registration_thread is not None:
            logger.warning("Registration thread already started.")
            return None
        logger.debug("Starting registration thread")
        self._registration_thread = Thread(target=self._consume_queue)
        self._registration_thread.start()

    def stop_registration_thread(self) -> None:
        if self._registration_thread is None:
            logger.warning("Registration thread is already stopped.")
            return None
        logger.debug("Stopping registration thread")
        self._task_queue.put(None)
        self._registration_thread.join()
        if self._registration_thread.is_alive():
            logger.warning("Registration thread stopping was unsuccessful.")
        self._registration_thread = None

    def _consume_queue(self) -> None:
        queue_closed = False
        while not queue_closed:
            queue_closed = self._consume_queue_task()

    def _consume_queue_task(self) -> bool:
        logger.debug("Consuming registration task")
        task = self._task_queue.get()
        logger.debug("Received registration task")
        if task is None:
            # `None` is the shutdown sentinel - tell the consume loop to exit.
            logger.debug("Terminating registration thread")
            self._task_queue.task_done()
            return True
        inference_input, prediction, prediction_type, disable_preproc_auto_orient = task
        try:
            self._execute_registration(
                inference_input=inference_input,
                prediction=prediction,
                prediction_type=prediction_type,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
            )
        except Exception as error:
            # Error handling to be decided - for now failures are suppressed so
            # the registration thread keeps processing subsequent tasks.
            logger.warning(
                f"Error in datapoint registration for Active Learning. Details: {error}. "
                "Error is suppressed in favour of normal operations of registration thread."
            )
        self._task_queue.task_done()
        return False

    def __enter__(self) -> "ThreadingActiveLearningMiddleware":
        self.start_registration_thread()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        self.stop_registration_thread()
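
# End-to-end sketch of the threaded variant (illustrative only: credentials,
# the model id, the cache object and `frame` are placeholders). The context
# manager starts the background thread on entry and drains/stops it on exit:
#
#     with ThreadingActiveLearningMiddleware.init(
#         api_key="<YOUR-API-KEY>",
#         model_id="some-project/3",
#         cache=cache,  # any BaseCache implementation
#         max_queue_size=MAX_REGISTRATION_QUEUE_SIZE,
#     ) as middleware:
#         middleware.register(
#             inference_input=frame,  # e.g. a numpy image, file path or URL
#             prediction=prediction,
#             prediction_type="object-detection",
#         )
#
# `register()` only enqueues the task, so inference latency is unaffected; when
# the bounded queue is full, new datapoints are dropped with a warning rather
# than blocking the caller.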