OMG / inference /core /managers /active_learning.py
Fucius's picture
Upload 422 files
df6c67d verified
raw
history blame
5.91 kB
import time
from typing import Dict, Optional
from fastapi import BackgroundTasks
from inference.core import logger
from inference.core.active_learning.middlewares import ActiveLearningMiddleware
from inference.core.cache.base import BaseCache
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT
from inference.core.managers.base import ModelManager
from inference.core.registries.base import ModelRegistry
# Key looked up in the `**kwargs` of `infer_from_request`; a missing key is
# treated as `False`, i.e. the call is not eligible for Active Learning.
ACTIVE_LEARNING_ELIGIBLE_PARAM = "active_learning_eligible"
# Attribute name read from the request object via `getattr`; a truthy value
# makes the managers skip Active Learning registration for that request.
DISABLE_ACTIVE_LEARNING_PARAM = "disable_active_learning"
# Key checked in the `**kwargs` of
# `BackgroundTaskActiveLearningManager.infer_from_request` to decide whether
# registration is scheduled as a background task or run inline.
BACKGROUND_TASKS_PARAM = "background_tasks"
class ActiveLearningManager(ModelManager):
    """Model manager that registers predictions for Active Learning.

    Inference is delegated to the parent manager. When the call is flagged as
    eligible, the request inputs and the prediction are afterwards handed to
    an `ActiveLearningMiddleware` kept per `model_id`.
    """

    def __init__(
        self,
        model_registry: ModelRegistry,
        cache: BaseCache,
        middlewares: Optional[Dict[str, ActiveLearningMiddleware]] = None,
    ):
        """Create the manager.

        Args:
            model_registry: Forwarded to the parent manager's constructor.
            cache: Passed to every middleware this manager initialises.
            middlewares: Optional pre-built middlewares keyed by model id.
                When omitted, an empty mapping is created and filled on
                demand by `ensure_middleware_initialised`.
        """
        super().__init__(model_registry=model_registry)
        self._cache = cache
        self._middlewares = middlewares if middlewares is not None else {}

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Run inference and, when allowed, register the datapoint.

        Registration is skipped when any of the following holds: the
        `active_learning_eligible` kwarg is absent or falsy, the request has a
        truthy `disable_active_learning` attribute, or `request.api_key` is
        `None`.

        Args:
            model_id: Identifier of the model to run.
            request: Inference request forwarded to the parent manager.
            **kwargs: Forwarded unchanged to the parent manager.

        Returns:
            The prediction produced by the parent manager.
        """
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        self.register(prediction=prediction, model_id=model_id, request=request)
        return prediction

    def register(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        """Register a datapoint, never letting a failure reach the caller.

        Any `Exception` raised while initialising the middleware or while
        registering the datapoint is logged as a warning and suppressed, so
        that the inference response is still returned.
        """
        try:
            self.ensure_middleware_initialised(model_id=model_id, request=request)
            self.register_datapoint(
                prediction=prediction,
                model_id=model_id,
                request=request,
            )
        except Exception as error:
            # Error handling to be decided
            logger.warning(
                f"Error in datapoint registration for Active Learning. Details: {error}. "
                f"Error is suppressed in favour of normal operations of API."
            )

    def ensure_middleware_initialised(
        self, model_id: str, request: InferenceRequest
    ) -> None:
        """Initialise and cache the middleware for `model_id` if it is missing.

        NOTE(review): middlewares are cached by `model_id` only, so the
        `api_key` of the first request that triggers initialisation is the one
        passed to `ActiveLearningMiddleware.init` - confirm this is intended.
        """
        if model_id in self._middlewares:
            return None
        start = time.perf_counter()
        logger.debug(f"Initialising AL middleware for {model_id}")
        self._middlewares[model_id] = ActiveLearningMiddleware.init(
            api_key=request.api_key,
            model_id=model_id,
            cache=self._cache,
        )
        end = time.perf_counter()
        logger.debug(f"Middleware init latency: {(end - start) * 1000} ms")

    def register_datapoint(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        """Pass the request images and predictions to the model's middleware.

        Inputs and predictions are each normalised to a list before being
        handed to `register_batch`. If the request has no `image` attribute
        (or it is `None`), a warning is logged and nothing is registered.

        Expects the middleware for `model_id` to be present in
        `self._middlewares` (see `ensure_middleware_initialised`).
        """
        start = time.perf_counter()
        inference_inputs = getattr(request, "image", None)
        if inference_inputs is None:
            logger.warning(
                "Could not register datapoint, as inference input has no `image` field."
            )
            return None
        # A single image / single response is wrapped so that the middleware
        # always receives batches.
        if not isinstance(inference_inputs, list):
            inference_inputs = [inference_inputs]
        if not isinstance(prediction, list):
            results_dicts = [prediction.dict(by_alias=True, exclude={"visualization"})]
        else:
            results_dicts = [
                e.dict(by_alias=True, exclude={"visualization"}) for e in prediction
            ]
        prediction_type = self.get_task_type(model_id=model_id)
        # Either the per-request flag or the environment-level setting
        # disables auto-orientation.
        disable_preproc_auto_orient = (
            getattr(request, "disable_preproc_auto_orient", False)
            or DISABLE_PREPROC_AUTO_ORIENT
        )
        self._middlewares[model_id].register_batch(
            inference_inputs=inference_inputs,
            predictions=results_dicts,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        end = time.perf_counter()
        logger.debug(f"Registration: {(end - start) * 1000} ms")
class BackgroundTaskActiveLearningManager(ActiveLearningManager):
    """Active Learning manager that defers registration to a background task.

    Inference is delegated to the parent class with Active Learning switched
    off there; registration is then scheduled on the `background_tasks`
    object supplied in kwargs, or executed inline when none is available.
    """

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Run inference and schedule datapoint registration.

        Registration is skipped when any of the following holds: the
        `active_learning_eligible` kwarg is absent or falsy, the request has a
        truthy `disable_active_learning` attribute, or `request.api_key` is
        `None`.

        Args:
            model_id: Identifier of the model to run.
            request: Inference request forwarded to the parent manager.
            **kwargs: Forwarded to the parent manager. May contain
                `background_tasks`, an object exposing `add_task`; when it
                is missing or `None`, registration runs sequentially and a
                warning is logged.

        Returns:
            The prediction produced by the parent manager.
        """
        # Read the eligibility flag before it is overwritten below.
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False  # disabling AL in super-classes
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        # An explicit `background_tasks=None` is treated like a missing
        # argument; otherwise `.add_task` would raise AttributeError after the
        # prediction has already been computed.
        background_tasks: Optional[BackgroundTasks] = kwargs.get(
            BACKGROUND_TASKS_PARAM
        )
        if background_tasks is None:
            logger.warning(
                "BackgroundTaskActiveLearningManager used against rules - `background_tasks` argument not "
                "provided making Active Learning registration running sequentially."
            )
            self.register(prediction=prediction, model_id=model_id, request=request)
        else:
            background_tasks.add_task(
                self.register, prediction=prediction, model_id=model_id, request=request
            )
        return prediction