import time
from typing import Dict, List, Optional, Tuple

import numpy as np
from fastapi.encoders import jsonable_encoder

from inference.core.cache import cache
from inference.core.cache.serializers import to_cachable_inference_item
from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import (
    DISABLE_INFERENCE_CACHE,
    METRICS_ENABLED,
    METRICS_INTERVAL,
    ROBOFLOW_SERVER_UUID,
)
from inference.core.exceptions import InferenceModelNotFound
from inference.core.logger import logger
from inference.core.managers.entities import ModelDescription
from inference.core.managers.pingback import PingbackInfo
from inference.core.models.base import Model, PreprocessReturnMetadata
from inference.core.registries.base import ModelRegistry


class ModelManager:
    """Model managers keep track of a dictionary of Model objects and is responsible for passing requests to the right model using the infer method."""

    def __init__(self, model_registry: ModelRegistry, models: Optional[dict] = None):
        self.model_registry = model_registry
        self._models: Dict[str, Model] = models if models is not None else {}

    def init_pingback(self):
        """Initializes pingback mechanism."""
        self.num_errors = 0  # errors recorded on this device
        self.uuid = ROBOFLOW_SERVER_UUID
        if METRICS_ENABLED:
            self.pingback = PingbackInfo(self)
            self.pingback.start()

    def add_model(
        self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
    ) -> None:
        """Adds a new model to the manager.

        Args:
            model_id (str): The identifier of the model.
            api_key (str): API key used to retrieve and initialize the model.
            model_id_alias (Optional[str]): Optional alias used to resolve the model type and as the key under which the model is stored.
        """
        logger.debug(
            f"ModelManager - Adding model with model_id={model_id}, model_id_alias={model_id_alias}"
        )
        if model_id in self._models:
            logger.debug(
                f"ModelManager - model with model_id={model_id} is already loaded."
            )
            return
        logger.debug("ModelManager - model initialisation...")
        model = self.model_registry.get_model(
            model_id if model_id_alias is None else model_id_alias, api_key
        )(
            model_id=model_id,
            api_key=api_key,
        )
        logger.debug("ModelManager - model successfully loaded.")
        self._models[model_id if model_id_alias is None else model_id_alias] = model

    def check_for_model(self, model_id: str) -> None:
        """Checks whether the model with the given ID is in the manager.

        Args:
            model_id (str): The identifier of the model.

        Raises:
            InferenceModelNotFound: If the model is not found in the manager.
        """
        if model_id not in self:
            raise InferenceModelNotFound(f"Model with id {model_id} not loaded.")

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Runs inference on the specified model with the given request.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response from the inference.
        """
        logger.debug(
            f"ModelManager - inference from request started for model_id={model_id}."
        )
        try:
            rtn_val = await self.model_infer(
                model_id=model_id, request=request, **kwargs
            )
            logger.debug(
                f"ModelManager - inference from request finished for model_id={model_id}."
            )
            finish_time = time.time()
            if not DISABLE_INFERENCE_CACHE:
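                # Record this model's activity and the request/response pair as
                # time-scored cache entries so they can be aggregated into usage
                # metrics; entries expire after two metrics intervals.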
                logger.debug(
                    f"ModelManager - caching inference request started for model_id={model_id}"
                )
                cache.zadd(
                    f"models",
                    value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
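                # Numpy image payloads cannot be serialized as-is for caching, so the
                # value is replaced with its string representation before storing.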
                if (
                    hasattr(request, "image")
                    and hasattr(request.image, "type")
                    and request.image.type == "numpy"
                ):
                    request.image.value = str(request.image.value)
                cache.zadd(
                    f"inference:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                    value=to_cachable_inference_item(request, rtn_val),
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                logger.debug(
                    f"ModelManager - caching inference request finished for model_id={model_id}"
                )
            return rtn_val
        except Exception as e:
            finish_time = time.time()
            if not DISABLE_INFERENCE_CACHE:
                cache.zadd(
                    f"models",
                    value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
                cache.zadd(
                    f"error:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
                    value={
                        "request": jsonable_encoder(
                            request.dict(exclude={"image", "subject", "prompt"})
                        ),
                        "error": str(e),
                    },
                    score=finish_time,
                    expire=METRICS_INTERVAL * 2,
                )
            raise

    async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs):
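        """Dispatches an inference request to the loaded model identified by model_id.

        Raises:
            InferenceModelNotFound: If the model is not loaded in the manager.
        """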
        self.check_for_model(model_id)
        return self._models[model_id].infer_from_request(request)

    def make_response(
        self, model_id: str, predictions: List[List[float]], *args, **kwargs
    ) -> InferenceResponse:
        """Creates a response object from the model's predictions.

        Args:
            model_id (str): The identifier of the model.
            predictions (List[List[float]]): The model's predictions.

        Returns:
            InferenceResponse: The created response object.
        """
        self.check_for_model(model_id)
        return self._models[model_id].make_response(predictions, *args, **kwargs)

    def postprocess(
        self,
        model_id: str,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        *args,
        **kwargs,
    ) -> List[List[float]]:
        """Processes the model's predictions after inference.

        Args:
            model_id (str): The identifier of the model.
            predictions (Tuple[np.ndarray, ...]): The model's raw predictions.
            preprocess_return_metadata (PreprocessReturnMetadata): Metadata produced during preprocessing.

        Returns:
            List[List[float]]: The post-processed predictions.
        """
        self.check_for_model(model_id)
        return self._models[model_id].postprocess(
            predictions, preprocess_return_metadata, *args, **kwargs
        )

    def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]:
        """Runs prediction on the specified model.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            Tuple[np.ndarray, ...]: The predictions from the model.
        """
        self.check_for_model(model_id)
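        # Per-model usage metrics: count this call and accumulate wall-clock inference
        # time (stored as a running sum; averaging, if any, is expected to happen at
        # reporting time).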
        self._models[model_id].metrics["num_inferences"] += 1
        tic = time.perf_counter()
        res = self._models[model_id].predict(*args, **kwargs)
        toc = time.perf_counter()
        self._models[model_id].metrics["avg_inference_time"] += toc - tic
        return res

    def preprocess(
        self, model_id: str, request: InferenceRequest
    ) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
        """Preprocesses the request before inference.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to preprocess.

        Returns:
            Tuple[np.ndarray, PreprocessReturnMetadata]: The preprocessed data and preprocessing metadata.
        """
        self.check_for_model(model_id)
        return self._models[model_id].preprocess(**request.dict())

    def get_class_names(self, model_id: str):
        """Retrieves the class names for a given model.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            List[str]: The class names of the model.
        """
        self.check_for_model(model_id)
        return self._models[model_id].class_names

    def get_task_type(self, model_id: str, api_key: Optional[str] = None) -> str:
        """Retrieves the task type for a given model.

        Args:
            model_id (str): The identifier of the model.
            api_key (Optional[str]): Optional API key; not used by this base implementation.

        Returns:
            str: The task type of the model.
        """
        self.check_for_model(model_id)
        return self._models[model_id].task_type

    def remove(self, model_id: str) -> None:
        """Removes a model from the manager.

        Args:
            model_id (str): The identifier of the model.
        """
        try:
            self.check_for_model(model_id)
            self._models[model_id].clear_cache()
            del self._models[model_id]
        except InferenceModelNotFound:
            logger.warning(
                f"Attempted to remove model with id {model_id}, but it is not loaded. Skipping..."
            )

    def clear(self) -> None:
        """Removes all models from the manager."""
        for model_id in list(self.keys()):
            self.remove(model_id)

    def __contains__(self, model_id: str) -> bool:
        """Checks if the model is contained in the manager.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            bool: Whether the model is in the manager.
        """
        return model_id in self._models

    def __getitem__(self, key: str) -> Model:
        """Retrieve a model from the manager by key.

        Args:
            key (str): The identifier of the model.

        Returns:
            Model: The model corresponding to the key.
        """
        self.check_for_model(model_id=key)
        return self._models[key]

    def __len__(self) -> int:
        """Retrieve the number of models in the manager.

        Returns:
            int: The number of models in the manager.
        """
        return len(self._models)

    def keys(self):
        """Retrieve the keys (model identifiers) from the manager.

        Returns:
            KeysView[str]: The identifiers of the models in the manager.
        """
        return self._models.keys()

    def models(self) -> Dict[str, Model]:
        """Retrieve the models dictionary from the manager.

        Returns:
            Dict[str, Model]: Mapping of model identifiers to loaded models.
        """
        return self._models

    def describe_models(self) -> List[ModelDescription]:
        return [
            ModelDescription(
                model_id=model_id,
                task_type=model.task_type,
                batch_size=getattr(model, "batch_size", None),
                input_width=getattr(model, "img_size_w", None),
                input_height=getattr(model, "img_size_h", None),
            )
            for model_id, model in self._models.items()
        ]
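

# Minimal usage sketch (illustrative only; `build_registry` and the request object
# are hypothetical stand-ins for whichever ModelRegistry implementation and
# InferenceRequest subclass an application actually uses):
#
#     registry = build_registry()                      # any ModelRegistry implementation
#     manager = ModelManager(model_registry=registry)
#     manager.init_pingback()
#     manager.add_model(model_id="some-project/1", api_key="<api-key>")
#     response = await manager.infer_from_request(
#         model_id="some-project/1",
#         request=request,                             # an InferenceRequest instance
#     )
#     manager.remove("some-project/1")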