import queue
from queue import Queue
from threading import Thread
from typing import Any, List, Optional
from inference.core import logger
from inference.core.active_learning.accounting import image_can_be_submitted_to_batch
from inference.core.active_learning.batching import generate_batch_name
from inference.core.active_learning.configuration import (
prepare_active_learning_configuration,
prepare_active_learning_configuration_inplace,
)
from inference.core.active_learning.core import (
execute_datapoint_registration,
execute_sampling,
)
from inference.core.active_learning.entities import (
ActiveLearningConfiguration,
Prediction,
PredictionType,
)
from inference.core.cache.base import BaseCache
from inference.core.utils.image_utils import load_image
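# Upper bound for the registration task queue; once full, new tasks are
# dropped instead of blocking the caller (see ThreadingActiveLearningMiddleware.register).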
MAX_REGISTRATION_QUEUE_SIZE = 512
class NullActiveLearningMiddleware:
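    """No-op stand-in exposing the ActiveLearningMiddleware interface (null-object pattern)."""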
def register_batch(
self,
inference_inputs: List[Any],
predictions: List[Prediction],
prediction_type: PredictionType,
disable_preproc_auto_orient: bool = False,
) -> None:
pass
def register(
self,
inference_input: Any,
prediction: dict,
prediction_type: PredictionType,
disable_preproc_auto_orient: bool = False,
) -> None:
pass
def start_registration_thread(self) -> None:
pass
def stop_registration_thread(self) -> None:
pass
def __enter__(self) -> "NullActiveLearningMiddleware":
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
pass
class ActiveLearningMiddleware:
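    """Synchronous middleware: sampling and registration run in the caller's thread."""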
@classmethod
def init(
cls, api_key: str, model_id: str, cache: BaseCache
) -> "ActiveLearningMiddleware":
configuration = prepare_active_learning_configuration(
api_key=api_key,
model_id=model_id,
cache=cache,
)
return cls(
api_key=api_key,
configuration=configuration,
cache=cache,
)
@classmethod
def init_from_config(
cls, api_key: str, model_id: str, cache: BaseCache, config: Optional[dict]
) -> "ActiveLearningMiddleware":
configuration = prepare_active_learning_configuration_inplace(
api_key=api_key,
model_id=model_id,
active_learning_configuration=config,
)
return cls(
api_key=api_key,
configuration=configuration,
cache=cache,
)
def __init__(
self,
api_key: str,
configuration: Optional[ActiveLearningConfiguration],
cache: BaseCache,
):
self._api_key = api_key
self._configuration = configuration
self._cache = cache
def register_batch(
self,
inference_inputs: List[Any],
predictions: List[Prediction],
prediction_type: PredictionType,
disable_preproc_auto_orient: bool = False,
) -> None:
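        """Register many datapoints; inputs and predictions are paired positionally."""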
for inference_input, prediction in zip(inference_inputs, predictions):
self.register(
inference_input=inference_input,
prediction=prediction,
prediction_type=prediction_type,
disable_preproc_auto_orient=disable_preproc_auto_orient,
)
def register(
self,
inference_input: Any,
prediction: dict,
prediction_type: PredictionType,
disable_preproc_auto_orient: bool = False,
) -> None:
self._execute_registration(
inference_input=inference_input,
prediction=prediction,
prediction_type=prediction_type,
disable_preproc_auto_orient=disable_preproc_auto_orient,
)
def _execute_registration(
self,
inference_input: Any,
prediction: dict,
prediction_type: PredictionType,
disable_preproc_auto_orient: bool = False,
) -> None:
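        # Full registration flow: load the image, run the configured sampling
        # strategies, verify batch capacity, then persist the datapoint.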
if self._configuration is None:
return None
image, is_bgr = load_image(
value=inference_input,
disable_preproc_auto_orient=disable_preproc_auto_orient,
)
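        # Normalize the channel order to BGR before sampling and registration.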
if not is_bgr:
image = image[:, :, ::-1]
matching_strategies = execute_sampling(
image=image,
prediction=prediction,
prediction_type=prediction_type,
sampling_methods=self._configuration.sampling_methods,
)
if len(matching_strategies) == 0:
return None
batch_name = generate_batch_name(configuration=self._configuration)
if not image_can_be_submitted_to_batch(
batch_name=batch_name,
workspace_id=self._configuration.workspace_id,
dataset_id=self._configuration.dataset_id,
max_batch_images=self._configuration.max_batch_images,
api_key=self._api_key,
):
logger.debug(f"Limit on Active Learning batch size reached.")
return None
execute_datapoint_registration(
cache=self._cache,
matching_strategies=matching_strategies,
image=image,
prediction=prediction,
prediction_type=prediction_type,
configuration=self._configuration,
api_key=self._api_key,
batch_name=batch_name,
)
class ThreadingActiveLearningMiddleware(ActiveLearningMiddleware):
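    """Variant that offloads sampling and registration to a background thread.

    `register()` only enqueues the task; a consumer thread (started via
    `start_registration_thread()` or by entering the context manager) performs
    the actual work, so inference calls are not blocked by registration.
    """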
@classmethod
def init(
cls,
api_key: str,
model_id: str,
cache: BaseCache,
max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
) -> "ThreadingActiveLearningMiddleware":
configuration = prepare_active_learning_configuration(
api_key=api_key,
model_id=model_id,
cache=cache,
)
task_queue = Queue(max_queue_size)
return cls(
api_key=api_key,
configuration=configuration,
cache=cache,
task_queue=task_queue,
)
@classmethod
def init_from_config(
cls,
api_key: str,
model_id: str,
cache: BaseCache,
config: Optional[dict],
max_queue_size: int = MAX_REGISTRATION_QUEUE_SIZE,
) -> "ThreadingActiveLearningMiddleware":
configuration = prepare_active_learning_configuration_inplace(
api_key=api_key,
model_id=model_id,
active_learning_configuration=config,
)
task_queue = Queue(max_queue_size)
return cls(
api_key=api_key,
configuration=configuration,
cache=cache,
task_queue=task_queue,
)
def __init__(
self,
api_key: str,
configuration: ActiveLearningConfiguration,
cache: BaseCache,
task_queue: Queue,
):
super().__init__(api_key=api_key, configuration=configuration, cache=cache)
self._task_queue = task_queue
self._registration_thread: Optional[Thread] = None
def register(
self,
inference_input: Any,
prediction: dict,
prediction_type: PredictionType,
disable_preproc_auto_orient: bool = False,
) -> None:
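        """Enqueue the task without blocking; drop it if the queue is full."""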
logger.debug(f"Putting registration task into queue")
try:
self._task_queue.put_nowait(
(
inference_input,
prediction,
prediction_type,
disable_preproc_auto_orient,
)
)
except queue.Full:
            logger.warning(
                "Dropping datapoint registered in Active Learning due to "
                "insufficient processing capabilities."
            )
def start_registration_thread(self) -> None:
if self._registration_thread is not None:
logger.warning(f"Registration thread already started.")
return None
logger.debug("Staring registration thread")
self._registration_thread = Thread(target=self._consume_queue)
self._registration_thread.start()
def stop_registration_thread(self) -> None:
if self._registration_thread is None:
logger.warning("Registration thread is already stopped.")
return None
logger.debug("Stopping registration thread")
self._task_queue.put(None)
        # Join with a timeout so the is_alive() check below can actually detect
        # a consumer that failed to terminate; with an unbounded join the check
        # could never fire. The 60s bound is an assumption, not a tuned value.
        self._registration_thread.join(timeout=60)
        if self._registration_thread.is_alive():
            logger.warning("Registration thread stopping was unsuccessful.")
self._registration_thread = None
def _consume_queue(self) -> None:
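        # Keep consuming tasks until one signals that the queue is closed.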
queue_closed = False
while not queue_closed:
queue_closed = self._consume_queue_task()
def _consume_queue_task(self) -> bool:
logger.debug("Consuming registration task")
task = self._task_queue.get()
logger.debug("Received registration task")
if task is None:
logger.debug("Terminating registration thread")
self._task_queue.task_done()
return True
inference_input, prediction, prediction_type, disable_preproc_auto_orient = task
try:
self._execute_registration(
inference_input=inference_input,
prediction=prediction,
prediction_type=prediction_type,
disable_preproc_auto_orient=disable_preproc_auto_orient,
)
except Exception as error:
# Error handling to be decided
logger.warning(
f"Error in datapoint registration for Active Learning. Details: {error}. "
f"Error is suppressed in favour of normal operations of registration thread."
)
self._task_queue.task_done()
return False
def __enter__(self) -> "ThreadingActiveLearningMiddleware":
self.start_registration_thread()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.stop_registration_thread()
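
# A minimal usage sketch (illustrative only: the `MemoryCache` import path, the
# API key, the model id, the image and the prediction payload below are
# assumptions, not part of this module):
#
#     from inference.core.cache.memory import MemoryCache
#
#     middleware = ThreadingActiveLearningMiddleware.init(
#         api_key="<API-KEY>",
#         model_id="some-project/1",
#         cache=MemoryCache(),
#     )
#     with middleware:  # starts and stops the registration thread
#         middleware.register(
#             inference_input=image,
#             prediction=prediction,
#             prediction_type="object-detection",
#         )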