"""Model managers that register predictions for Active Learning alongside inference."""

import time
from typing import Dict, Optional

from fastapi import BackgroundTasks

from inference.core import logger
from inference.core.active_learning.middlewares import ActiveLearningMiddleware
from inference.core.cache.base import BaseCache
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT
from inference.core.managers.base import ModelManager
from inference.core.registries.base import ModelRegistry

ACTIVE_LEARNING_ELIGIBLE_PARAM = "active_learning_eligible"
DISABLE_ACTIVE_LEARNING_PARAM = "disable_active_learning"
BACKGROUND_TASKS_PARAM = "background_tasks"
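
# The three keys above form the calling convention between the HTTP layer and
# the managers below, as used in this module: routes pass
# `active_learning_eligible=...` (and, for the background-task variant,
# `background_tasks=...`) through **kwargs, while `disable_active_learning`
# is read off the request object itself.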


class ActiveLearningManager(ModelManager):
    """Model manager that registers predictions for Active Learning after inference."""

    def __init__(
        self,
        model_registry: ModelRegistry,
        cache: BaseCache,
        middlewares: Optional[Dict[str, ActiveLearningMiddleware]] = None,
    ):
        super().__init__(model_registry=model_registry)
        self._cache = cache
        # One middleware per model_id, created lazily on the first eligible request.
        self._middlewares = middlewares if middlewares is not None else {}

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        # Register only when the route marked the request as eligible, the client
        # did not opt out, and an API key is present.
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        self.register(prediction=prediction, model_id=model_id, request=request)
        return prediction

    def register(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        try:
            self.ensure_middleware_initialised(model_id=model_id, request=request)
            self.register_datapoint(
                prediction=prediction,
                model_id=model_id,
                request=request,
            )
        except Exception as error:
            # Error handling to be decided - for now, registration failures must
            # never break inference, so the error is logged and suppressed.
            logger.warning(
                f"Error in datapoint registration for Active Learning. Details: {error}. "
                f"The error is suppressed so that the API keeps operating normally."
            )

    def ensure_middleware_initialised(
        self, model_id: str, request: InferenceRequest
    ) -> None:
        if model_id in self._middlewares:
            return None
        start = time.perf_counter()
        logger.debug(f"Initialising AL middleware for {model_id}")
        self._middlewares[model_id] = ActiveLearningMiddleware.init(
            api_key=request.api_key,
            model_id=model_id,
            cache=self._cache,
        )
        end = time.perf_counter()
        logger.debug(f"Middleware init latency: {(end - start) * 1000} ms")

    def register_datapoint(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        start = time.perf_counter()
        inference_inputs = getattr(request, "image", None)
        if inference_inputs is None:
            logger.warning(
                "Could not register datapoint, as inference input has no `image` field."
            )
            return None
        # Normalise single inputs and single predictions into batch form.
        if not isinstance(inference_inputs, list):
            inference_inputs = [inference_inputs]
        if not isinstance(prediction, list):
            results_dicts = [prediction.dict(by_alias=True, exclude={"visualization"})]
        else:
            results_dicts = [
                e.dict(by_alias=True, exclude={"visualization"}) for e in prediction
            ]
        prediction_type = self.get_task_type(model_id=model_id)
        disable_preproc_auto_orient = (
            getattr(request, "disable_preproc_auto_orient", False)
            or DISABLE_PREPROC_AUTO_ORIENT
        )
        self._middlewares[model_id].register_batch(
            inference_inputs=inference_inputs,
            predictions=results_dicts,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        end = time.perf_counter()
        logger.debug(f"Registration: {(end - start) * 1000} ms")


class BackgroundTaskActiveLearningManager(ActiveLearningManager):
    """Variant that defers Active Learning registration to FastAPI background tasks."""

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False  # disabling AL in super-classes
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        if BACKGROUND_TASKS_PARAM not in kwargs:
            logger.warning(
                "BackgroundTaskActiveLearningManager used against the rules - the "
                "`background_tasks` argument was not provided, so Active Learning "
                "registration runs sequentially."
            )
            self.register(prediction=prediction, model_id=model_id, request=request)
        else:
            background_tasks: BackgroundTasks = kwargs[BACKGROUND_TASKS_PARAM]
            background_tasks.add_task(
                self.register, prediction=prediction, model_id=model_id, request=request
            )
        return prediction
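

# Usage sketch (illustrative, assuming a FastAPI route; the endpoint path,
# request parsing, and the `SomeModelRegistry` / `SomeCache` placeholders are
# assumptions, not part of the original module):
#
#     from fastapi import FastAPI
#
#     app = FastAPI()
#     manager = BackgroundTaskActiveLearningManager(
#         model_registry=SomeModelRegistry(),
#         cache=SomeCache(),
#     )
#
#     @app.post("/infer")
#     async def infer(body: dict, background_tasks: BackgroundTasks):
#         request = InferenceRequest(**body)  # placeholder parsing
#         return await manager.infer_from_request(
#             model_id="some-project/1",  # placeholder model id
#             request=request,
#             active_learning_eligible=True,
#             background_tasks=background_tasks,  # registration runs after the response
#         )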