import time
from typing import Dict, List, Optional, Tuple
import numpy as np
from fastapi.encoders import jsonable_encoder
from inference.core.cache import cache
from inference.core.cache.serializers import to_cachable_inference_item
from inference.core.devices.utils import GLOBAL_INFERENCE_SERVER_ID
from inference.core.entities.requests.inference import InferenceRequest
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import (
DISABLE_INFERENCE_CACHE,
METRICS_ENABLED,
METRICS_INTERVAL,
ROBOFLOW_SERVER_UUID,
)
from inference.core.exceptions import InferenceModelNotFound
from inference.core.logger import logger
from inference.core.managers.entities import ModelDescription
from inference.core.managers.pingback import PingbackInfo
from inference.core.models.base import Model, PreprocessReturnMetadata
from inference.core.registries.base import ModelRegistry
class ModelManager:
"""Model managers keep track of a dictionary of Model objects and is responsible for passing requests to the right model using the infer method."""
def __init__(self, model_registry: ModelRegistry, models: Optional[dict] = None):
self.model_registry = model_registry
self._models: Dict[str, Model] = models if models is not None else {}
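    # Example (hypothetical usage sketch): the manager is built around any concrete
    # ModelRegistry implementation; `registry` below is a placeholder for one such
    # instance, not something defined in this module.
    #
    #     registry: ModelRegistry = ...  # a concrete registry implementation
    #     manager = ModelManager(model_registry=registry)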
def init_pingback(self):
"""Initializes pingback mechanism."""
self.num_errors = 0 # in the device
self.uuid = ROBOFLOW_SERVER_UUID
if METRICS_ENABLED:
self.pingback = PingbackInfo(self)
self.pingback.start()
def add_model(
self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
) -> None:
"""Adds a new model to the manager.
Args:
model_id (str): The identifier of the model.
            api_key (str): The API key used to resolve and load the model.
            model_id_alias (Optional[str]): Optional alias under which the model is resolved and stored.
"""
logger.debug(
f"ModelManager - Adding model with model_id={model_id}, model_id_alias={model_id_alias}"
)
if model_id in self._models:
logger.debug(
f"ModelManager - model with model_id={model_id} is already loaded."
)
return
logger.debug("ModelManager - model initialisation...")
model = self.model_registry.get_model(
model_id if model_id_alias is None else model_id_alias, api_key
)(
model_id=model_id,
api_key=api_key,
)
logger.debug("ModelManager - model successfully loaded.")
self._models[model_id if model_id_alias is None else model_id_alias] = model
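    # Example (sketch; the model id, alias, and api_key are hypothetical):
    #
    #     manager.add_model("my-project/1", api_key=api_key)
    #     # or resolve and store the same weights under an alias:
    #     manager.add_model("my-project/1", api_key=api_key, model_id_alias="my-alias")
    #     assert "my-alias" in manager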
def check_for_model(self, model_id: str) -> None:
"""Checks whether the model with the given ID is in the manager.
Args:
model_id (str): The identifier of the model.
Raises:
InferenceModelNotFound: If the model is not found in the manager.
"""
if model_id not in self:
raise InferenceModelNotFound(f"Model with id {model_id} not loaded.")
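    # Example (sketch): `check_for_model` can serve as a guard before touching a model
    # directly; identifiers below are hypothetical.
    #
    #     try:
    #         manager.check_for_model("my-project/1")
    #     except InferenceModelNotFound:
    #         manager.add_model("my-project/1", api_key=api_key)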
async def infer_from_request(
self, model_id: str, request: InferenceRequest, **kwargs
) -> InferenceResponse:
"""Runs inference on the specified model with the given request.
Args:
model_id (str): The identifier of the model.
request (InferenceRequest): The request to process.
Returns:
InferenceResponse: The response from the inference.
"""
logger.debug(
f"ModelManager - inference from request started for model_id={model_id}."
)
try:
rtn_val = await self.model_infer(
model_id=model_id, request=request, **kwargs
)
logger.debug(
f"ModelManager - inference from request finished for model_id={model_id}."
)
finish_time = time.time()
if not DISABLE_INFERENCE_CACHE:
logger.debug(
f"ModelManager - caching inference request started for model_id={model_id}"
)
cache.zadd(
f"models",
value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
score=finish_time,
expire=METRICS_INTERVAL * 2,
)
if (
hasattr(request, "image")
and hasattr(request.image, "type")
and request.image.type == "numpy"
):
request.image.value = str(request.image.value)
cache.zadd(
f"inference:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
value=to_cachable_inference_item(request, rtn_val),
score=finish_time,
expire=METRICS_INTERVAL * 2,
)
logger.debug(
f"ModelManager - caching inference request finished for model_id={model_id}"
)
return rtn_val
except Exception as e:
finish_time = time.time()
if not DISABLE_INFERENCE_CACHE:
cache.zadd(
f"models",
value=f"{GLOBAL_INFERENCE_SERVER_ID}:{request.api_key}:{model_id}",
score=finish_time,
expire=METRICS_INTERVAL * 2,
)
cache.zadd(
f"error:{GLOBAL_INFERENCE_SERVER_ID}:{model_id}",
value={
"request": jsonable_encoder(
request.dict(exclude={"image", "subject", "prompt"})
),
"error": str(e),
},
score=finish_time,
expire=METRICS_INTERVAL * 2,
)
raise
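    # Example (sketch, assuming a model is already loaded and `request` is an instance
    # of a concrete InferenceRequest subclass; identifiers are illustrative):
    #
    #     response = await manager.infer_from_request("my-project/1", request)
    #
    # When DISABLE_INFERENCE_CACHE is unset, both successful responses and errors are
    # written to server/model-scoped sorted sets with a TTL of METRICS_INTERVAL * 2,
    # presumably so the metrics/pingback machinery can read them back later.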
async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs):
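        """Checks that the model is loaded and delegates the request to its
        infer_from_request method.
        Args:
            model_id (str): The identifier of the model.
            request (InferenceRequest): The request to process.
        Returns:
            InferenceResponse: The response produced by the model.
        """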
self.check_for_model(model_id)
return self._models[model_id].infer_from_request(request)
def make_response(
self, model_id: str, predictions: List[List[float]], *args, **kwargs
) -> InferenceResponse:
"""Creates a response object from the model's predictions.
Args:
model_id (str): The identifier of the model.
predictions (List[List[float]]): The model's predictions.
Returns:
InferenceResponse: The created response object.
"""
self.check_for_model(model_id)
return self._models[model_id].make_response(predictions, *args, **kwargs)
def postprocess(
self,
model_id: str,
predictions: Tuple[np.ndarray, ...],
preprocess_return_metadata: PreprocessReturnMetadata,
*args,
**kwargs,
) -> List[List[float]]:
"""Processes the model's predictions after inference.
Args:
model_id (str): The identifier of the model.
            predictions (Tuple[np.ndarray, ...]): The model's raw predictions.
            preprocess_return_metadata (PreprocessReturnMetadata): Metadata returned by the preprocessing step.
Returns:
List[List[float]]: The post-processed predictions.
"""
self.check_for_model(model_id)
return self._models[model_id].postprocess(
predictions, preprocess_return_metadata, *args, **kwargs
)
def predict(self, model_id: str, *args, **kwargs) -> Tuple[np.ndarray, ...]:
"""Runs prediction on the specified model.
Args:
model_id (str): The identifier of the model.
Returns:
            Tuple[np.ndarray, ...]: The predictions from the model.
"""
self.check_for_model(model_id)
self._models[model_id].metrics["num_inferences"] += 1
tic = time.perf_counter()
res = self._models[model_id].predict(*args, **kwargs)
toc = time.perf_counter()
self._models[model_id].metrics["avg_inference_time"] += toc - tic
return res
def preprocess(
self, model_id: str, request: InferenceRequest
) -> Tuple[np.ndarray, PreprocessReturnMetadata]:
"""Preprocesses the request before inference.
Args:
model_id (str): The identifier of the model.
request (InferenceRequest): The request to preprocess.
Returns:
            Tuple[np.ndarray, PreprocessReturnMetadata]: The preprocessed image and the metadata needed for postprocessing.
"""
self.check_for_model(model_id)
return self._models[model_id].preprocess(**request.dict())
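    # Example (sketch): preprocess/predict/postprocess/make_response mirror the stages
    # of a single inference pass, so a caller holding a raw `request` could in principle
    # drive them manually (identifiers are illustrative):
    #
    #     img, meta = manager.preprocess("my-project/1", request)
    #     raw = manager.predict("my-project/1", img)
    #     post = manager.postprocess("my-project/1", raw, meta)
    #     response = manager.make_response("my-project/1", post)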
    def get_class_names(self, model_id: str) -> List[str]:
"""Retrieves the class names for a given model.
Args:
model_id (str): The identifier of the model.
Returns:
List[str]: The class names of the model.
"""
self.check_for_model(model_id)
return self._models[model_id].class_names
    def get_task_type(self, model_id: str, api_key: Optional[str] = None) -> str:
        """Retrieves the task type for a given model.
        Args:
            model_id (str): The identifier of the model.
            api_key (Optional[str]): Optional API key; not used by this base implementation.
Returns:
str: The task type of the model.
"""
self.check_for_model(model_id)
return self._models[model_id].task_type
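    # Example (sketch; the model id and returned values are illustrative):
    #
    #     manager.get_task_type("my-project/1")    # e.g. "object-detection"
    #     manager.get_class_names("my-project/1")  # e.g. ["car", "person"]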
def remove(self, model_id: str) -> None:
"""Removes a model from the manager.
Args:
model_id (str): The identifier of the model.
"""
try:
self.check_for_model(model_id)
self._models[model_id].clear_cache()
del self._models[model_id]
except InferenceModelNotFound:
logger.warning(
f"Attempted to remove model with id {model_id}, but it is not loaded. Skipping..."
)
def clear(self) -> None:
"""Removes all models from the manager."""
for model_id in list(self.keys()):
self.remove(model_id)
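    # Example (sketch): removing a model clears its cache and drops it from the
    # manager; `clear()` does the same for every loaded model.
    #
    #     manager.remove("my-project/1")  # logs a warning and skips if not loaded
    #     manager.clear()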
def __contains__(self, model_id: str) -> bool:
"""Checks if the model is contained in the manager.
Args:
model_id (str): The identifier of the model.
Returns:
bool: Whether the model is in the manager.
"""
return model_id in self._models
def __getitem__(self, key: str) -> Model:
"""Retrieve a model from the manager by key.
Args:
key (str): The identifier of the model.
Returns:
Model: The model corresponding to the key.
"""
self.check_for_model(model_id=key)
return self._models[key]
def __len__(self) -> int:
"""Retrieve the number of models in the manager.
Returns:
int: The number of models in the manager.
"""
return len(self._models)
def keys(self):
"""Retrieve the keys (model identifiers) from the manager.
Returns:
List[str]: The keys of the models in the manager.
"""
return self._models.keys()
def models(self) -> Dict[str, Model]:
"""Retrieve the models dictionary from the manager.
Returns:
            Dict[str, Model]: The mapping of model identifiers to loaded model instances.
"""
return self._models
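    # Example (sketch): the dunder methods above give the manager a read-only,
    # dict-like interface (identifiers are hypothetical):
    #
    #     if "my-project/1" in manager:
    #         model = manager["my-project/1"]
    #     print(len(manager), list(manager.keys()))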
def describe_models(self) -> List[ModelDescription]:
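        """Summarizes every loaded model.
        Returns:
            List[ModelDescription]: One entry per loaded model with its task type and,
                where available, batch size and input resolution.
        """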
return [
ModelDescription(
model_id=model_id,
task_type=model.task_type,
batch_size=getattr(model, "batch_size", None),
input_width=getattr(model, "img_size_w", None),
input_height=getattr(model, "img_size_h", None),
)
for model_id, model in self._models.items()
]
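# Example (sketch): `describe_models` is a convenient way to expose what a server has
# loaded, e.g. from a web route; the route below is hypothetical, not part of this module.
#
#     @app.get("/model/registry")
#     async def registry_state():
#         return [jsonable_encoder(d) for d in model_manager.describe_models()]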