File size: 5,907 Bytes
2eafbc4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import time
from typing import Dict, Optional

from fastapi import BackgroundTasks

from inference.core import logger
from inference.core.active_learning.middlewares import ActiveLearningMiddleware
from inference.core.cache.base import BaseCache
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import DISABLE_PREPROC_AUTO_ORIENT
from inference.core.managers.base import ModelManager
from inference.core.registries.base import ModelRegistry

# Keyword-argument name read by `infer_from_request` to decide whether a call
# may be registered for Active Learning (treated as False when absent).
ACTIVE_LEARNING_ELIGIBLE_PARAM = "active_learning_eligible"
# Attribute name looked up on the request object; a truthy value skips
# Active Learning registration for that request.
DISABLE_ACTIVE_LEARNING_PARAM = "disable_active_learning"
# Keyword-argument name under which BackgroundTaskActiveLearningManager looks
# for the FastAPI `BackgroundTasks` object.
BACKGROUND_TASKS_PARAM = "background_tasks"


class ActiveLearningManager(ModelManager):
    """Model manager that registers predictions for Active Learning.

    Inference itself is delegated to the parent manager. Afterwards, eligible
    request/prediction pairs are handed to a per-model
    ``ActiveLearningMiddleware`` which is created lazily on first use and kept
    in ``self._middlewares`` keyed by model id.
    """

    def __init__(
        self,
        model_registry: ModelRegistry,
        cache: BaseCache,
        middlewares: Optional[Dict[str, ActiveLearningMiddleware]] = None,
    ):
        """Create the manager.

        Args:
            model_registry: Forwarded to the parent ``ModelManager``.
            cache: Cache passed to every middleware initialised by this manager.
            middlewares: Optional pre-built middlewares keyed by model id; an
                empty dict is used when omitted.
        """
        super().__init__(model_registry=model_registry)
        self._cache = cache
        self._middlewares = middlewares if middlewares is not None else {}

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Run inference and, when eligible, register the datapoint.

        Registration is attempted only when the ``active_learning_eligible``
        keyword argument is truthy, the request does not carry a truthy
        ``disable_active_learning`` attribute, and ``request.api_key`` is not
        ``None``.

        Args:
            model_id: Identifier of the model to run.
            request: Inference request forwarded to the parent manager.
            **kwargs: Forwarded unchanged to the parent manager.

        Returns:
            The prediction returned by the parent manager, unmodified.
        """
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        self.register(prediction=prediction, model_id=model_id, request=request)
        return prediction

    def register(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        """Register a prediction for Active Learning without ever raising.

        Makes sure the middleware for ``model_id`` exists, then submits the
        datapoint. Any exception is logged as a warning and suppressed, so a
        registration failure cannot break the inference response.
        """
        try:
            self.ensure_middleware_initialised(model_id=model_id, request=request)
            self.register_datapoint(
                prediction=prediction,
                model_id=model_id,
                request=request,
            )
        except Exception as error:
            # Deliberate best-effort: registration must not affect API results.
            # Error handling to be decided
            logger.warning(
                f"Error in datapoint registration for Active Learning. Details: {error}. "
                f"Error is suppressed in favour of normal operations of API."
            )

    def ensure_middleware_initialised(
        self, model_id: str, request: InferenceRequest
    ) -> None:
        """Create and store the middleware for ``model_id`` if it is missing.

        The middleware is built with ``request.api_key`` and the manager's
        cache; the initialisation latency is logged at debug level.
        """
        if model_id in self._middlewares:
            return None
        start = time.perf_counter()
        logger.debug(f"Initialising AL middleware for {model_id}")
        self._middlewares[model_id] = ActiveLearningMiddleware.init(
            api_key=request.api_key,
            model_id=model_id,
            cache=self._cache,
        )
        end = time.perf_counter()
        logger.debug(f"Middleware init latency: {(end - start) * 1000} ms")

    def register_datapoint(
        self, prediction: InferenceResponse, model_id: str, request: InferenceRequest
    ) -> None:
        """Submit the request image(s) and prediction(s) to the model's middleware.

        Both the request's ``image`` attribute and ``prediction`` may be a
        single object or a list; single objects are wrapped in a one-element
        list. Predictions are serialised with ``.dict(by_alias=True)``, leaving
        out the ``visualization`` field. If the request has no ``image``
        attribute (or it is ``None``) a warning is logged and nothing is
        registered.

        The middleware for ``model_id`` must already be present in
        ``self._middlewares`` (see ``ensure_middleware_initialised``).
        """
        start = time.perf_counter()
        inference_inputs = getattr(request, "image", None)
        if inference_inputs is None:
            logger.warning(
                "Could not register datapoint, as inference input has no `image` field."
            )
            return None
        if not isinstance(inference_inputs, list):
            inference_inputs = [inference_inputs]
        if isinstance(prediction, list):
            results_dicts = [
                e.dict(by_alias=True, exclude={"visualization"}) for e in prediction
            ]
        else:
            results_dicts = [prediction.dict(by_alias=True, exclude={"visualization"})]
        # `get_task_type` is not defined in this class - presumably inherited
        # from ModelManager; confirm against the base class.
        prediction_type = self.get_task_type(model_id=model_id)
        # Auto-orient is disabled when either the request or the global
        # environment setting asks for it.
        disable_preproc_auto_orient = (
            getattr(request, "disable_preproc_auto_orient", False)
            or DISABLE_PREPROC_AUTO_ORIENT
        )
        self._middlewares[model_id].register_batch(
            inference_inputs=inference_inputs,
            predictions=results_dicts,
            prediction_type=prediction_type,
            disable_preproc_auto_orient=disable_preproc_auto_orient,
        )
        end = time.perf_counter()
        logger.debug(f"Registration: {(end - start) * 1000} ms")


class BackgroundTaskActiveLearningManager(ActiveLearningManager):
    """Active Learning manager that defers registration to background tasks.

    Behaves like ``ActiveLearningManager`` but, when a FastAPI
    ``BackgroundTasks`` object is supplied through the ``background_tasks``
    keyword argument, datapoint registration is scheduled on it instead of
    being executed inline.
    """

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Run inference and schedule Active Learning registration.

        Args:
            model_id: Identifier of the model to run.
            request: Inference request forwarded to the parent manager.
            **kwargs: Forwarded to the parent manager, with
                ``active_learning_eligible`` forced to ``False`` so the parent
                does not register the datapoint a second time. May contain
                ``background_tasks``.

        Returns:
            The prediction returned by the parent manager, unmodified.
        """
        # Capture eligibility before it is overwritten for the super() call.
        active_learning_eligible = kwargs.get(ACTIVE_LEARNING_ELIGIBLE_PARAM, False)
        active_learning_disabled_for_request = getattr(
            request, DISABLE_ACTIVE_LEARNING_PARAM, False
        )
        kwargs[ACTIVE_LEARNING_ELIGIBLE_PARAM] = False  # disabling AL in super-classes
        prediction = await super().infer_from_request(
            model_id=model_id, request=request, **kwargs
        )
        if (
            not active_learning_eligible
            or active_learning_disabled_for_request
            or request.api_key is None
        ):
            return prediction
        # A missing key and an explicit `None` are both treated as "not
        # provided"; otherwise `None.add_task(...)` would raise after the
        # prediction has already been computed.
        background_tasks: Optional[BackgroundTasks] = kwargs.get(
            BACKGROUND_TASKS_PARAM
        )
        if background_tasks is None:
            logger.warning(
                "BackgroundTaskActiveLearningManager used against rules - `background_tasks` argument not "
                "provided making Active Learning registration running sequentially."
            )
            self.register(prediction=prediction, model_id=model_id, request=request)
        else:
            background_tasks.add_task(
                self.register, prediction=prediction, model_id=model_id, request=request
            )
        return prediction