Fucius's picture
Upload 422 files
2eafbc4 verified
import queue
from queue import Queue
from threading import Thread
from typing import Any, List, Optional
from inference.core import logger
from inference.core.active_learning.accounting import image_can_be_submitted_to_batch
from inference.core.active_learning.batching import generate_batch_name
from inference.core.active_learning.configuration import (
prepare_active_learning_configuration,
prepare_active_learning_configuration_inplace,
)
from inference.core.active_learning.core import (
execute_datapoint_registration,
execute_sampling,
)
from inference.core.active_learning.entities import (
ActiveLearningConfiguration,
Prediction,
PredictionType,
)
from inference.core.cache.base import BaseCache
from inference.core.utils.image_utils import load_image
# Default capacity of the bounded task queue that ThreadingActiveLearningMiddleware
# creates in init() / init_from_config(). While the queue is full, register() drops
# the incoming datapoint with a warning instead of blocking the caller.
MAX_REGISTRATION_QUEUE_SIZE = 512
class NullActiveLearningMiddleware:
    """Do-nothing middleware with the same public methods as the real ones.

    Every method ignores its arguments and returns ``None``. Entering the
    context manager yields the instance itself, and leaving it never
    suppresses exceptions.
    """

    def register_batch(
        self,
        inference_inputs: List[Any],
        predictions: List[Prediction],
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Accept a batch of inputs/predictions and discard it."""
        return None

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Accept a single input/prediction pair and discard it."""
        return None

    def start_registration_thread(self) -> None:
        """No thread exists for this middleware; nothing to start."""
        return None

    def stop_registration_thread(self) -> None:
        """No thread exists for this middleware; nothing to stop."""
        return None

    def __enter__(self) -> "NullActiveLearningMiddleware":
        """Enter the context and hand back this same instance."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Leave the context; a falsy return lets exceptions propagate."""
        return None
class ActiveLearningMiddleware:
    """Synchronous Active Learning middleware.

    Every ``register`` / ``register_batch`` call performs the whole
    registration pipeline on the calling thread:
      1. load the image,
      2. run the configured sampling methods,
      3. check the target batch still has room, and
      4. register the datapoint.

    It also exposes no-op registration-thread lifecycle methods and the
    context-manager protocol. That way it can be used interchangeably with
    ``NullActiveLearningMiddleware`` and ``ThreadingActiveLearningMiddleware``,
    e.g. ``with middleware: ...``.
    """

    @classmethod
    def init(
        cls, api_key: str, model_id: str, cache: BaseCache
    ) -> "ActiveLearningMiddleware":
        """Create a middleware whose configuration is prepared for ``model_id``.

        Args:
            api_key: API key used for configuration lookup and registration.
            model_id: Identifier of the model the configuration is prepared for.
            cache: Cache passed to configuration preparation and registration.

        Returns:
            A new middleware instance.
        """
        configuration = prepare_active_learning_configuration(
            api_key=api_key,
            model_id=model_id,
            cache=cache,
        )
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
        )

    @classmethod
    def init_from_config(
        cls, api_key: str, model_id: str, cache: BaseCache, config: Optional[dict]
    ) -> "ActiveLearningMiddleware":
        """Create a middleware from an explicitly supplied configuration dict.

        Args:
            api_key: API key used for configuration preparation and registration.
            model_id: Identifier of the model the configuration belongs to.
            cache: Cache used during datapoint registration.
            config: Raw Active Learning configuration (may be ``None``).

        Returns:
            A new middleware instance.
        """
        configuration = prepare_active_learning_configuration_inplace(
            api_key=api_key,
            model_id=model_id,
            active_learning_configuration=config,
        )
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
        )

    def __init__(
        self,
        api_key: str,
        configuration: Optional[ActiveLearningConfiguration],
        cache: BaseCache,
    ):
        """Store collaborators.

        A ``None`` configuration turns every registration into a no-op.
        """
        self._api_key = api_key
        self._configuration = configuration
        self._cache = cache

    def register_batch(
        self,
        inference_inputs: List[Any],
        predictions: List[Prediction],
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Register each ``(input, prediction)`` pair via ``self.register``.

        Pairs are formed with ``zip``, so if the two lists differ in length
        the surplus items of the longer one are ignored.
        """
        for inference_input, prediction in zip(inference_inputs, predictions):
            self.register(
                inference_input=inference_input,
                prediction=prediction,
                prediction_type=prediction_type,
                disable_preproc_auto_orient=disable_preproc_auto_orient,
            )

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Register a single datapoint synchronously on the calling thread."""
        self._execute_registration(
            inference_input=inference_input,
            prediction=prediction,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )

    def start_registration_thread(self) -> None:
        """No-op: registration runs synchronously, so there is no thread to start.

        Provided for interface parity with the other middleware classes.
        """
        return None

    def stop_registration_thread(self) -> None:
        """No-op: registration runs synchronously, so there is no thread to stop.

        Provided for interface parity with the other middleware classes.
        """
        return None

    def __enter__(self) -> "ActiveLearningMiddleware":
        """Start the (possibly no-op) registration thread and return ``self``."""
        self.start_registration_thread()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Stop the (possibly no-op) registration thread; never suppresses errors."""
        self.stop_registration_thread()

    def _execute_registration(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Run sampling and, if a strategy matches, register the datapoint.

        Returns early (doing nothing further) when:
          - there is no configuration,
          - no sampling strategy matches, or
          - the batch-size limit for the current batch has been reached.
        """
        if self._configuration is None:
            return None
        image, is_bgr = load_image(
            value=inference_input,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        if not is_bgr:
            # Reverse the channel (last) axis so downstream code always sees BGR.
            image = image[:, :, ::-1]
        matching_strategies = execute_sampling(
            image=image,
            prediction=prediction,
            prediction_type=prediction_type,
            sampling_methods=self._configuration.sampling_methods,
        )
        if len(matching_strategies) == 0:
            return None
        batch_name = generate_batch_name(configuration=self._configuration)
        if not image_can_be_submitted_to_batch(
            batch_name=batch_name,
            workspace_id=self._configuration.workspace_id,
            dataset_id=self._configuration.dataset_id,
            max_batch_images=self._configuration.max_batch_images,
            api_key=self._api_key,
        ):
            logger.debug("Limit on Active Learning batch size reached.")
            return None
        execute_datapoint_registration(
            cache=self._cache,
            matching_strategies=matching_strategies,
            image=image,
            prediction=prediction,
            prediction_type=prediction_type,
            configuration=self._configuration,
            api_key=self._api_key,
            batch_name=batch_name,
        )
class ThreadingActiveLearningMiddleware(ActiveLearningMiddleware):
    """Active Learning middleware that performs registration on a worker thread.

    ``register`` only enqueues the work on a bounded queue. When the queue is
    full, the datapoint is dropped with a warning.

    A single worker thread drains the queue and runs the inherited synchronous
    ``_execute_registration``. That thread is started by
    ``start_registration_thread`` or by entering the context manager. It is
    stopped by ``stop_registration_thread`` or by leaving the context, which
    enqueues a ``None`` sentinel and joins the thread.
    """

    @classmethod
    def init(
        cls,
        api_key: str,
        model_id: str,
        cache: BaseCache,
        max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
    ) -> "ThreadingActiveLearningMiddleware":
        """Create a threaded middleware whose configuration is prepared for ``model_id``.

        Args:
            api_key: API key used for configuration lookup and registration.
            model_id: Identifier of the model the configuration is prepared for.
            cache: Cache passed to configuration preparation and registration.
            max_queue_size: Capacity of the pending-registration queue.

        Returns:
            A new middleware instance (thread not yet started).
        """
        configuration = prepare_active_learning_configuration(
            api_key=api_key,
            model_id=model_id,
            cache=cache,
        )
        task_queue = Queue(max_queue_size)
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
            task_queue=task_queue,
        )

    @classmethod
    def init_from_config(
        cls,
        api_key: str,
        model_id: str,
        cache: BaseCache,
        config: Optional[dict],
        max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
    ) -> "ThreadingActiveLearningMiddleware":
        """Create a threaded middleware from an explicitly supplied config dict.

        Args:
            api_key: API key used for configuration preparation and registration.
            model_id: Identifier of the model the configuration belongs to.
            cache: Cache used during datapoint registration.
            config: Raw Active Learning configuration (may be ``None``).
            max_queue_size: Capacity of the pending-registration queue.

        Returns:
            A new middleware instance (thread not yet started).
        """
        configuration = prepare_active_learning_configuration_inplace(
            api_key=api_key,
            model_id=model_id,
            active_learning_configuration=config,
        )
        task_queue = Queue(max_queue_size)
        return cls(
            api_key=api_key,
            configuration=configuration,
            cache=cache,
            task_queue=task_queue,
        )

    def __init__(
        self,
        api_key: str,
        configuration: Optional[ActiveLearningConfiguration],
        cache: BaseCache,
        task_queue: Queue,
    ):
        """Store collaborators; the worker thread is created lazily on start."""
        super().__init__(api_key=api_key, configuration=configuration, cache=cache)
        self._task_queue = task_queue
        self._registration_thread: Optional[Thread] = None

    def register(
        self,
        inference_input: Any,
        prediction: dict,
        prediction_type: PredictionType,
        disable_preproc_auto_orient: bool = False,
    ) -> None:
        """Enqueue a registration task without blocking.

        If the queue is full, the datapoint is dropped and a warning is logged.
        """
        logger.debug("Putting registration task into queue")
        try:
            self._task_queue.put_nowait(
                (
                    inference_input,
                    prediction,
                    prediction_type,
                    disable_preproc_auto_orient,
                )
            )
        except queue.Full:
            logger.warning(
                "Dropping datapoint registered in Active Learning due to insufficient processing "
                "capabilities."
            )

    def start_registration_thread(self) -> None:
        """Start the worker thread; warn and do nothing if it is already running."""
        if self._registration_thread is not None:
            logger.warning("Registration thread already started.")
            return None
        logger.debug("Starting registration thread")
        self._registration_thread = Thread(target=self._consume_queue)
        self._registration_thread.start()

    def stop_registration_thread(self) -> None:
        """Signal the worker to finish and wait for it.

        Warns and does nothing if the thread was never started. Pending tasks
        queued before the sentinel are processed before the worker exits.
        """
        if self._registration_thread is None:
            logger.warning("Registration thread is already stopped.")
            return None
        logger.debug("Stopping registration thread")
        # Blocking put: waits for space if the queue is currently full.
        self._task_queue.put(None)
        self._registration_thread.join()
        # join() has no timeout, so this is only a defensive check.
        if self._registration_thread.is_alive():
            logger.warning("Registration thread stopping was unsuccessful.")
        self._registration_thread = None

    def _consume_queue(self) -> None:
        """Worker loop: process tasks until the ``None`` sentinel is received."""
        queue_closed = False
        while not queue_closed:
            queue_closed = self._consume_queue_task()

    def _consume_queue_task(self) -> bool:
        """Take one item off the queue and process it.

        Exceptions raised by registration are logged and suppressed so the
        worker keeps running.

        Returns:
            ``True`` if the item was the ``None`` stop sentinel, else ``False``.
        """
        logger.debug("Consuming registration task")
        task = self._task_queue.get()
        logger.debug("Received registration task")
        try:
            if task is None:
                logger.debug("Terminating registration thread")
                return True
            (
                inference_input,
                prediction,
                prediction_type,
                disable_preproc_auto_orient,
            ) = task
            try:
                self._execute_registration(
                    inference_input=inference_input,
                    prediction=prediction,
                    prediction_type=prediction_type,
                    disable_preproc_auto_orient=disable_preproc_auto_orient,
                )
            except Exception as error:
                # Best-effort: a failed registration must not kill the worker.
                logger.warning(
                    f"Error in datapoint registration for Active Learning. Details: {error}. "
                    f"Error is suppressed in favour of normal operations of registration thread."
                )
            return False
        finally:
            # Always acknowledge the item so Queue.join() cannot hang, even if
            # something other than an Exception subclass escapes above.
            self._task_queue.task_done()

    def __enter__(self) -> "ThreadingActiveLearningMiddleware":
        """Start the worker thread and return ``self``."""
        self.start_registration_thread()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Stop the worker thread; exceptions from the ``with`` body propagate."""
        self.stop_registration_thread()