File size: 2,836 Bytes
e71a2ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import multiprocessing as mp
import time

import pytest
import torch

from src.server.runtime import Runtime
from src.server.task_pool import PrioritizedTaskPool


@pytest.mark.forked
def test_priority_pools():
    """Check that Runtime drains multiple PrioritizedTaskPools strictly by task priority.

    Tasks are submitted to two pools with varying priorities; the order in which
    dummy_pool_func actually executes them (recorded via outputs_queue) must match
    the expected priority order asserted at the end. NOTE(review): the expected
    ordering depends on the two sleeps below and on max_batch_size=1 — this test
    is timing-sensitive by construction.
    """
    # Cross-process records of (input, output) in the order tasks were executed.
    outputs_queue = mp.SimpleQueue()
    # Set by the child process only if every future resolved to the expected value.
    results_valid = mp.Event()

    def dummy_pool_func(x):
        # Simulate non-trivial work so later submissions queue up behind it,
        # letting the runtime reorder pending tasks by priority.
        time.sleep(0.1)
        y = x**2
        outputs_queue.put((x, y))
        return (y,)

    class DummyBackend:
        # Minimal stand-in for a server backend: only exposes its task pools.
        def __init__(self, pools):
            self.pools = pools

        def get_pools(self):
            return self.pools

    # Two independent pools; batch size 1 so each task is scheduled individually.
    pools = (
        PrioritizedTaskPool(dummy_pool_func, name="A", max_batch_size=1),
        PrioritizedTaskPool(dummy_pool_func, name="B", max_batch_size=1),
    )

    # prefetch_batches=0: presumably forces the runtime to pick the next task
    # only when ready, so priority order is honored — TODO confirm against Runtime.
    runtime = Runtime({str(i): DummyBackend([pool]) for i, pool in enumerate(pools)}, prefetch_batches=0)
    runtime.start()

    def process_tasks():
        # Submitted from a child process to exercise the pools' cross-process path
        # (relies on fork inheriting pools/queues; see @pytest.mark.forked).
        futures = []
        futures.append(pools[0].submit_task(torch.tensor([0]), priority=1))
        futures.append(pools[0].submit_task(torch.tensor([1]), priority=1))
        # Give the runtime time to pull task 0 into execution before the rest arrive.
        time.sleep(0.01)
        futures.append(pools[1].submit_task(torch.tensor([2]), priority=1))
        futures.append(pools[0].submit_task(torch.tensor([3]), priority=2))
        futures.append(pools[0].submit_task(torch.tensor([4]), priority=10))
        futures.append(pools[0].submit_task(torch.tensor([5]), priority=0))
        futures.append(pools[0].submit_task(torch.tensor([6]), priority=1))
        futures.append(pools[1].submit_task(torch.tensor([7]), priority=11))
        futures.append(pools[1].submit_task(torch.tensor([8]), priority=1))
        # Each task i squares its input tensor [i], so result must be i**2.
        for i, f in enumerate(futures):
            assert f.result()[0].item() == i**2
        # Signal success to the parent — a bare assert here would only fail the child.
        results_valid.set()

    proc = mp.Process(target=process_tasks)
    proc.start()
    proc.join()
    assert results_valid.is_set()

    # Drain the execution-order log; lower priority value == served earlier.
    ordered_outputs = []
    while not outputs_queue.empty():
        ordered_outputs.append(outputs_queue.get()[0].item())

    assert ordered_outputs == [0, 5, 1, 2, 6, 8, 3, 4, 7]
    #                          0 - first batch is loaded immediately, before everything else
    #                             5 - highest priority task overall
    #                                1 - first of several tasks with equal lowest priority (1)
    #                                   2 - second earliest task with priority 1, fetched from pool B
    #                                      6 - third earliest task with priority 1, fetched from pool A again
    #                                         8 - last priority-1 task, pool B
    #                                            3 - task with priority 2 from pool A
    #                                               4 - task with priority 10 from pool A
    #                                                  7 - task with priority 11 from pool B