Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,222 Bytes
df6c67d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 |
from typing import List, Optional, Tuple
import numpy as np
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import API_KEY
from inference.core.managers.base import Model, ModelManager
from inference.core.models.types import PreprocessReturnMetadata
class ModelManagerDecorator(ModelManager):
    """Basic decorator: it acts like a `ModelManager` and contains a `ModelManager`.

    Apart from the early return in `add_model` and the API-key default in
    `get_task_type`, every method forwards its arguments unchanged to the
    wrapped manager and returns whatever the wrapped manager returns.

    Args:
        model_manager (ModelManager): Instance of a ModelManager to wrap.

    Methods:
        add_model: Adds a model to the manager.
        infer_from_request: Processes a complete inference request.
        infer_only: Performs only the inference part of a request.
        preprocess: Processes the preprocessing part of a request.
        get_task_type: Gets the task type associated with a model.
        get_class_names: Gets the class names for a given model.
        remove: Removes a model from the manager.
        __len__: Returns the number of models in the manager.
        __getitem__: Retrieves a model by its ID.
        __contains__: Checks if a model exists in the manager.
        keys: Returns the keys (model IDs) from the manager.
        models: Returns the models held by the manager.
        predict: Runs the prediction step for a model.
        postprocess: Runs the postprocessing step for a model.
        make_response: Builds the response object(s) for a model.
    """

    @property
    def _models(self):
        """Guard against reading model storage from the decorator itself.

        Raises:
            ValueError: Always; use `self.model_manager._models` instead.
        """
        raise ValueError("Should only be accessing self.model_manager._models")

    @property
    def model_registry(self):
        """Guard against reading the registry from the decorator itself.

        Raises:
            ValueError: Always; use `self.model_manager.model_registry` instead.
        """
        raise ValueError("Should only be accessing self.model_manager.model_registry")

    def __init__(self, model_manager: ModelManager) -> None:
        """Initializes the decorator with an instance of a ModelManager.

        Args:
            model_manager (ModelManager): The manager every call is forwarded to.
        """
        # NOTE(review): super().__init__() is not called here. Presumably this is
        # deliberate, since `_models` and `model_registry` are read-only guard
        # properties on this class -- confirm against the base initializer.
        self.model_manager = model_manager

    def add_model(
        self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
    ) -> None:
        """Adds a model to the manager, unless `model_id` is already present.

        Args:
            model_id (str): The identifier of the model.
            api_key (str): The API key forwarded to the wrapped manager.
            model_id_alias (Optional[str]): Optional alias forwarded to the
                wrapped manager as `model_id_alias`.
        """
        # NOTE(review): the membership check uses `model_id`, not
        # `model_id_alias` -- confirm this matches how the wrapped manager keys
        # aliased models.
        if model_id in self:
            return
        self.model_manager.add_model(model_id, api_key, model_id_alias=model_id_alias)

    async def infer_from_request(
        self, model_id: str, request: InferenceRequest, **kwargs
    ) -> InferenceResponse:
        """Processes a complete inference request.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            InferenceResponse: The response from the inference.
        """
        return await self.model_manager.infer_from_request(model_id, request, **kwargs)

    def infer_only(self, model_id: str, request, img_in, img_dims, batch_size=None):
        """Performs only the inference part of a request.

        Args:
            model_id (str): The identifier of the model.
            request: The request to process.
            img_in: Input image.
            img_dims: Image dimensions.
            batch_size (int, optional): Batch size.

        Returns:
            Response from the inference-only operation.
        """
        return self.model_manager.infer_only(
            model_id, request, img_in, img_dims, batch_size
        )

    def preprocess(self, model_id: str, request: InferenceRequest):
        """Processes the preprocessing part of a request.

        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to preprocess.

        Returns:
            Whatever the wrapped manager's `preprocess` returns.
        """
        return self.model_manager.preprocess(model_id, request)

    def get_task_type(self, model_id: str, api_key: Optional[str] = None) -> str:
        """Gets the task type associated with a model.

        Args:
            model_id (str): The identifier of the model.
            api_key (Optional[str]): API key to use. When `None`, the
                module-level `API_KEY` from the environment config is used.

        Returns:
            str: The task type.
        """
        if api_key is None:
            api_key = API_KEY
        return self.model_manager.get_task_type(model_id, api_key=api_key)

    def get_class_names(self, model_id):
        """Gets the class names for a given model.

        Args:
            model_id: The identifier of the model.

        Returns:
            List of class names.
        """
        return self.model_manager.get_class_names(model_id)

    def remove(self, model_id: str) -> Model:
        """Removes a model from the manager.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            Model: The value returned by the wrapped manager's `remove`.
        """
        return self.model_manager.remove(model_id)

    def __len__(self) -> int:
        """Returns the number of models in the manager.

        Returns:
            int: Number of models.
        """
        return len(self.model_manager)

    def __getitem__(self, key: str) -> Model:
        """Retrieves a model by its ID.

        Args:
            key (str): The identifier of the model.

        Returns:
            Model: The model instance.
        """
        return self.model_manager[key]

    def __contains__(self, model_id: str) -> bool:
        """Checks if a model exists in the manager.

        Args:
            model_id (str): The identifier of the model.

        Returns:
            bool: True if the model exists, False otherwise.
        """
        return model_id in self.model_manager

    def keys(self):
        """Returns the keys (model IDs) from the manager.

        Returns:
            The result of the wrapped manager's `keys()`.
        """
        return self.model_manager.keys()

    def models(self):
        """Returns the models held by the manager.

        Returns:
            The result of the wrapped manager's `models()`.
        """
        return self.model_manager.models()

    def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]:
        """Runs the prediction step for a model.

        Args:
            model_id (str): The identifier of the model.
            *args: Positional arguments forwarded unchanged.
            **kwargs: Keyword arguments forwarded unchanged.

        Returns:
            Tuple[np.ndarray, ...]: The result of the wrapped manager's `predict`.
        """
        return self.model_manager.predict(model_id, *args, **kwargs)

    def postprocess(
        self,
        model_id: str,
        predictions: Tuple[np.ndarray, ...],
        preprocess_return_metadata: PreprocessReturnMetadata,
        *args,
        **kwargs
    ) -> List[List[float]]:
        """Runs the postprocessing step for a model.

        Args:
            model_id (str): The identifier of the model.
            predictions (Tuple[np.ndarray, ...]): Predictions to postprocess.
            preprocess_return_metadata (PreprocessReturnMetadata): Metadata
                passed along with the predictions.
            *args: Positional arguments forwarded unchanged.
            **kwargs: Keyword arguments forwarded unchanged.

        Returns:
            List[List[float]]: The result of the wrapped manager's `postprocess`.
        """
        return self.model_manager.postprocess(
            model_id, predictions, preprocess_return_metadata, *args, **kwargs
        )

    def make_response(
        self, model_id: str, predictions: List[List[float]], *args, **kwargs
    ) -> InferenceResponse:
        """Builds the response object(s) for a model.

        Args:
            model_id (str): The identifier of the model.
            predictions (List[List[float]]): Predictions to build the response from.
            *args: Positional arguments forwarded unchanged.
            **kwargs: Keyword arguments forwarded unchanged.

        Returns:
            InferenceResponse: The result of the wrapped manager's `make_response`.
        """
        return self.model_manager.make_response(model_id, predictions, *args, **kwargs)
|