import asyncio from asyncio import BoundedSemaphore from time import perf_counter, time from typing import Any, Dict, List, Optional import orjson from redis.asyncio import Redis from inference.core.entities.requests.inference import ( InferenceRequest, request_from_type, ) from inference.core.entities.responses.inference import response_from_type from inference.core.env import NUM_PARALLEL_TASKS from inference.core.managers.base import ModelManager from inference.core.registries.base import ModelRegistry from inference.core.registries.roboflow import get_model_type from inference.enterprise.parallel.tasks import preprocess from inference.enterprise.parallel.utils import FAILURE_STATE, SUCCESS_STATE class ResultsChecker: """ Class responsible for queuing asyncronous inference runs, keeping track of running requests, and awaiting their results. """ def __init__(self, redis: Redis): self.tasks: Dict[str, asyncio.Event] = {} self.dones = dict() self.errors = dict() self.running = True self.redis = redis self.semaphore: BoundedSemaphore = BoundedSemaphore(NUM_PARALLEL_TASKS) async def add_task(self, task_id: str, request: InferenceRequest): """ Wait until there's available cylce to queue a task. When there are cycles, add the task's id to a list to keep track of its results, launch the preprocess celeryt task, set the task's status to in progress in redis. """ await self.semaphore.acquire() self.tasks[task_id] = asyncio.Event() preprocess.s(request.dict()).delay() def get_result(self, task_id: str) -> Any: """ Check the done tasks and errored tasks for this task id. """ if task_id in self.dones: return self.dones.pop(task_id) elif task_id in self.errors: message = self.errors.pop(task_id) raise Exception(message) else: raise RuntimeError( "Task result not found in either success or error dict. Unreachable" ) async def loop(self): """ Main loop. Check all in progress tasks for their status, and if their status is final, (either failure or success) then add their results to the appropriate results dictionary. """ async with self.redis.pubsub() as pubsub: await pubsub.subscribe("results") async for message in pubsub.listen(): if message["type"] != "message": continue message = orjson.loads(message["data"]) task_id = message.pop("task_id") if task_id not in self.tasks: continue self.semaphore.release() status = message.pop("status") if status == FAILURE_STATE: self.errors[task_id] = message["payload"] elif status == SUCCESS_STATE: self.dones[task_id] = message["payload"] else: raise RuntimeError( "Task result not found in possible states. Unreachable" ) self.tasks[task_id].set() await asyncio.sleep(0) async def wait_for_response(self, key: str): event = self.tasks[key] await event.wait() del self.tasks[key] return self.get_result(key) class DispatchModelManager(ModelManager): def __init__( self, model_registry: ModelRegistry, checker: ResultsChecker, models: Optional[dict] = None, ): super().__init__(model_registry, models) self.checker = checker async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs): if request.visualize_predictions: raise NotImplementedError("Visualisation of prediction is not supported") request.start = time() t = perf_counter() task_type = self.get_task_type(model_id, request.api_key) list_mode = False if isinstance(request.image, list): list_mode = True request_dict = request.dict() images = request_dict.pop("image") del request_dict["id"] requests = [ request_from_type(task_type, dict(**request_dict, image=image)) for image in images ] else: requests = [request] start_task_awaitables = [] results_awaitables = [] for r in requests: start_task_awaitables.append(self.checker.add_task(r.id, r)) results_awaitables.append(self.checker.wait_for_response(r.id)) await asyncio.gather(*start_task_awaitables) response_jsons = await asyncio.gather(*results_awaitables) responses = [] for response_json in response_jsons: response = response_from_type(task_type, response_json) response.time = perf_counter() - t responses.append(response) if list_mode: return responses return responses[0] def add_model( self, model_id: str, api_key: str, model_id_alias: str = None ) -> None: pass def __contains__(self, model_id: str) -> bool: return True def get_task_type(self, model_id: str, api_key: str = None) -> str: return get_model_type(model_id, api_key)[0]