File size: 5,502 Bytes
df6c67d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import asyncio
from asyncio import BoundedSemaphore
from time import perf_counter, time
from typing import Any, Dict, List, Optional

import orjson
from redis.asyncio import Redis

from inference.core.entities.requests.inference import (
    InferenceRequest,
    request_from_type,
)
from inference.core.entities.responses.inference import response_from_type
from inference.core.env import NUM_PARALLEL_TASKS
from inference.core.managers.base import ModelManager
from inference.core.registries.base import ModelRegistry
from inference.core.registries.roboflow import get_model_type
from inference.enterprise.parallel.tasks import preprocess
from inference.enterprise.parallel.utils import FAILURE_STATE, SUCCESS_STATE


class ResultsChecker:
    """
    Queues asynchronous inference runs, keeps track of in-flight tasks,
    and awaits their results via the Redis "results" pub/sub channel.

    Attributes:
        tasks: Maps task id -> event that is set when a final result arrives.
        dones: Success payloads keyed by task id (consumed by `get_result`).
        errors: Failure payloads (error messages) keyed by task id.
        running: Flag external components may use to observe/stop the checker.
        redis: Async Redis client used to subscribe to the results channel.
        semaphore: Caps the number of concurrently queued tasks.
    """

    def __init__(self, redis: Redis):
        self.tasks: Dict[str, asyncio.Event] = {}
        self.dones: Dict[str, Any] = {}
        self.errors: Dict[str, Any] = {}
        self.running = True
        self.redis = redis
        self.semaphore: BoundedSemaphore = BoundedSemaphore(NUM_PARALLEL_TASKS)

    async def add_task(self, task_id: str, request: InferenceRequest):
        """
        Wait until there's an available cycle to queue a task.
        When there are cycles, register the task id so its result can be
        tracked, then launch the preprocess celery task.
        """
        await self.semaphore.acquire()
        self.tasks[task_id] = asyncio.Event()
        preprocess.s(request.dict()).delay()

    def get_result(self, task_id: str) -> Any:
        """
        Return (and remove) the stored result for ``task_id``.

        Raises:
            Exception: If the task finished in the failure state; the message
                is the error payload reported by the worker.
            RuntimeError: If the task id is in neither results dict — should
                be unreachable when called after `wait_for_response`.
        """
        if task_id in self.dones:
            return self.dones.pop(task_id)
        elif task_id in self.errors:
            message = self.errors.pop(task_id)
            raise Exception(message)
        else:
            raise RuntimeError(
                "Task result not found in either success or error dict. Unreachable"
            )

    async def loop(self):
        """
        Main loop. Listen on the "results" pub/sub channel; for each final
        status message (failure or success) belonging to a tracked task,
        release a semaphore slot, file the payload in the appropriate results
        dict, and wake the waiter via its event.
        """
        async with self.redis.pubsub() as pubsub:
            await pubsub.subscribe("results")
            async for message in pubsub.listen():
                # Ignore subscribe/unsubscribe confirmations.
                if message["type"] != "message":
                    continue
                message = orjson.loads(message["data"])
                task_id = message.pop("task_id")
                # Results for tasks dispatched by other processes are skipped.
                if task_id not in self.tasks:
                    continue
                self.semaphore.release()
                status = message.pop("status")
                if status == FAILURE_STATE:
                    self.errors[task_id] = message["payload"]
                elif status == SUCCESS_STATE:
                    self.dones[task_id] = message["payload"]
                else:
                    raise RuntimeError(
                        "Task result not found in possible states. Unreachable"
                    )
                self.tasks[task_id].set()
                # Yield control so waiters can consume the result promptly.
                await asyncio.sleep(0)

    async def wait_for_response(self, key: str):
        """Block until the task's result is final, then return (or raise) it."""
        event = self.tasks[key]
        await event.wait()
        del self.tasks[key]
        return self.get_result(key)

class DispatchModelManager(ModelManager):
    """
    Model manager that dispatches inference requests to distributed workers
    (via a `ResultsChecker`) instead of running models locally.
    """

    def __init__(
        self,
        model_registry: ModelRegistry,
        checker: ResultsChecker,
        models: Optional[dict] = None,
    ):
        super().__init__(model_registry, models)
        self.checker = checker

    async def model_infer(self, model_id: str, request: InferenceRequest, **kwargs):
        """
        Dispatch ``request`` to the worker pipeline and await its response(s).

        If ``request.image`` is a list, one sub-request is dispatched per image
        and a list of responses is returned; otherwise a single response is
        returned.

        Raises:
            NotImplementedError: If prediction visualisation was requested.
        """
        if request.visualize_predictions:
            raise NotImplementedError("Visualisation of prediction is not supported")
        request.start = time()
        t = perf_counter()
        task_type = self.get_task_type(model_id, request.api_key)

        # Fan a batched request (list of images) out into one request per image.
        list_mode = isinstance(request.image, list)
        if list_mode:
            request_dict = request.dict()
            images = request_dict.pop("image")
            # Drop the parent id so each sub-request gets its own unique id.
            del request_dict["id"]
            requests = [
                request_from_type(task_type, dict(**request_dict, image=image))
                for image in images
            ]
        else:
            requests = [request]

        start_task_awaitables = []
        results_awaitables = []
        for r in requests:
            start_task_awaitables.append(self.checker.add_task(r.id, r))
            results_awaitables.append(self.checker.wait_for_response(r.id))

        # Queue every task first, then await all responses concurrently.
        await asyncio.gather(*start_task_awaitables)
        response_jsons = await asyncio.gather(*results_awaitables)

        responses = []
        for response_json in response_jsons:
            response = response_from_type(task_type, response_json)
            # Report wall-clock time from dispatch, not per-item worker time.
            response.time = perf_counter() - t
            responses.append(response)

        return responses if list_mode else responses[0]

    def add_model(
        self, model_id: str, api_key: str, model_id_alias: Optional[str] = None
    ) -> None:
        """No-op: models are loaded by the workers, not by this manager."""
        pass

    def __contains__(self, model_id: str) -> bool:
        """Accept every model id; resolution happens in the workers."""
        return True

    def get_task_type(self, model_id: str, api_key: Optional[str] = None) -> str:
        """Return the task type string for ``model_id`` via the registry."""
        return get_model_type(model_id, api_key)[0]