from typing import List, Optional, Tuple

import numpy as np

from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import API_KEY
from inference.core.managers.base import Model, ModelManager
from inference.core.models.types import PreprocessReturnMetadata


class ModelManagerDecorator(ModelManager):
    """Basic decorator, it acts like a `ModelManager` and contains a `ModelManager`.

    Args:
        model_manager (ModelManager): Instance of a ModelManager.

    Methods:
        add_model: Adds a model to the manager.
        infer: Processes a complete inference request.
        infer_only: Performs only the inference part of a request.
        preprocess: Processes the preprocessing part of a request.
        get_task_type: Gets the task type associated with a model.
        get_class_names: Gets the class names for a given model.
        remove: Removes a model from the manager.
        __len__: Returns the number of models in the manager.
        __getitem__: Retrieves a model by its ID.
        __contains__: Checks if a model exists in the manager.
        keys: Returns the keys (model IDs) from the manager.
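
    Example:
        A minimal usage sketch, assuming a concrete `ModelManager` instance is
        available as `base_manager`, `request` is an `InferenceRequest`, and the
        model ID / API key values are placeholders (run inside an async context):

            manager = ModelManagerDecorator(base_manager)
            manager.add_model("some/model-id", api_key="<API_KEY>")
            # All calls are forwarded to the wrapped manager.
            response = await manager.infer_from_request("some/model-id", request)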
    """

    @property
    def _models(self):
        raise ValueError("Should only be accessing self.model_manager._models")

    @property
    def model_registry(self):
        raise ValueError("Should only be accessing self.model_manager.model_registry")

    def __init__(self, model_manager: ModelManager):
        """Initializes the decorator with an instance of a ModelManager."""
        self.model_manager = model_manager

    def add_model(
        self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
    ):
        """Adds a model to the manager.

        Args:
            model_id (str): The identifier of the model.
            model (Model): The model instance.
        """
        if model_id in self:
            return
        self.model_manager.add_model(model_id, api_key, model_id_alias=model_id_alias)

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Processes a complete inference request.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.

        Returns:
            InferenceResponse: The response from the inference.
        """
        return await self.model_manager.infer_from_request(model_id, request, **kwargs)

    def infer_only(self, model_id: str, request, img_in, img_dims, batch_size=None):
        """Performs only the inference part of a request.

        Args:
            model_id (str): The identifier of the model.
            request: The request to process.
            img_in: Input image.
            img_dims: Image dimensions.
            batch_size (int, optional): Batch size.

        Returns:
            Response from the inference-only operation.
        """
        return self.model_manager.infer_only(
            model_id, request, img_in, img_dims, batch_size
        )

    def preprocess(self, model_id: str, request: InferenceRequest):
        """Processes the preprocessing part of a request.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to preprocess.
        """
        return self.model_manager.preprocess(model_id, request)

    def get_task_type(self, model_id: str, api_key: Optional[str] = None) -> str:
        """Gets the task type associated with a model.

        Args:
            model_id (str): The identifier of the model.
            api_key (Optional[str]): API key to use; falls back to the API_KEY
                environment value when not provided.

        Returns:
            str: The task type.
        """
        if api_key is None:
            api_key = API_KEY
        return self.model_manager.get_task_type(model_id, api_key=api_key)

    def get_class_names(self, model_id):
        """Gets the class names for a given model.

        Args:
            model_id: The identifier of the model.

        Returns:
            List of class names.
        """
        return self.model_manager.get_class_names(model_id)

    def remove(self, model_id: str) -> Model:
        """Removes a model from the manager.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            Model: The removed model.
        """
        return self.model_manager.remove(model_id)

    def __len__(self) -> int:
        """Returns the number of models in the manager.

        Returns:
            int: Number of models.
        """
        return len(self.model_manager)

    def __getitem__(self, key: str) -> Model:
        """Retrieves a model by its ID.

        Args:
            key (str): The identifier of the model.

        Returns:
            Model: The model instance.
        """
        return self.model_manager[key]

    def __contains__(self, model_id: str):
        """Checks if a model exists in the manager.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            bool: True if the model exists, False otherwise.
        """
        return model_id in self.model_manager

    def keys(self):
        """Returns the keys (model IDs) from the manager.

        Returns:
            List of keys (model IDs).
        """
        return self.model_manager.keys()

    def models(self):
        """Returns the models held by the wrapped manager."""
        return self.model_manager.models()

    def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]:
        """Runs prediction for the given model, delegating to the wrapped manager."""
        return self.model_manager.predict(model_id, *args, **kwargs)

    def postprocess(
        self,
        model_id: str,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        *args,
        **kwargs,
    ) -> List[List[float]]:
        """Postprocesses raw predictions, delegating to the wrapped manager."""
        return self.model_manager.postprocess(
            model_id, predictions, preprocess_return_metadata, *args, **kwargs
        )

    def make_response(
        self, model_id: str, predictions: List[List[float]], *args, **kwargs
    ) -> InferenceResponse:
        """Builds an `InferenceResponse` from predictions, delegating to the wrapped manager."""
        return self.model_manager.make_response(model_id, predictions, *args, **kwargs)