import multiprocessing as mp
import multiprocessing.pool
import threading
from collections import defaultdict
from itertools import chain
from queue import SimpleQueue
from selectors import EVENT_READ, DefaultSelector
from statistics import mean
from time import time
from typing import Dict, NamedTuple, Optional

import torch
from hivemind.moe.server.module_backend import ModuleBackend
from hivemind.utils import get_logger
from prefetch_generator import BackgroundGenerator

logger = get_logger(__name__)


class Runtime(threading.Thread):
    """
    A worker thread that processes incoming requests for multiple module backends on a shared device,
    pulling batches from each backend's task pools. Runtime is usually created and managed by Server;
    humans need not apply.

    For debugging, you can start runtime manually with .start() or .run()

    >>> module_backends = {'block_uid': ModuleBackend(**kwargs)}
    >>> runtime = Runtime(module_backends)
    >>> runtime.start()  # start runtime in background thread. To start in current thread, use runtime.run()
    >>> runtime.ready.wait()  # wait for the runtime to load all blocks onto the device and create request pools
    >>> future = runtime.module_backends['block_uid'].forward_pool.submit_task(*module_inputs)
    >>> print("Returned:", future.result())
    >>> runtime.shutdown()

    :param module_backends: a dict [block uid -> ModuleBackend]
    :param prefetch_batches: form up to this many batches in advance
    :param sender_threads: number of asynchronous threads used to dispatch outputs from finished batches
    :param device: if specified, moves all blocks and data to this device via .to(device=device).
      If you want to manually specify devices for each block (in their forward pass), leave device=None (default)

    :param stats_report_interval: interval, in seconds, at which to collect and log runtime performance statistics
    """

    SHUTDOWN_TRIGGER = "RUNTIME SHUTDOWN TRIGGERED"

    def __init__(
        self,
        module_backends: Dict[str, ModuleBackend],
        prefetch_batches: int = 1,
        sender_threads: int = 1,
        device: Optional[torch.device] = None,
        stats_report_interval: Optional[int] = None,
    ):
        super().__init__()
        self.module_backends = module_backends
        self.pools = tuple(chain(*(backend.get_pools() for backend in module_backends.values())))
        self.device, self.prefetch_batches, self.sender_threads = device, prefetch_batches, sender_threads
        self.shutdown_recv, self.shutdown_send = mp.Pipe(duplex=False)
        self.shutdown_trigger = mp.Event()
        self.ready = mp.Event()  # event is set iff server is currently running and ready to accept batches

        self.stats_report_interval = stats_report_interval
        if self.stats_report_interval is not None:
            self.stats_reporter = StatsReporter(self.stats_report_interval)

    def run(self):
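        """Start the task pools, move modules to the target device, then process incoming batches until shutdown."""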
        for pool in self.pools:
            if not pool.is_alive():
                pool.start()
        if self.device is not None:
            for backend in self.module_backends.values():
                backend.module.to(self.device)

        with mp.pool.ThreadPool(self.sender_threads) as output_sender_pool:
            try:
                self.ready.set()
                if self.stats_report_interval is not None:
                    self.stats_reporter.start()
                logger.info("Started")

                batch_iterator = self.iterate_minibatches_from_pools()
                if self.prefetch_batches > 0:
                    batch_iterator = BackgroundGenerator(batch_iterator, self.prefetch_batches)

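                # main loop: take batches from the pools (highest-priority first) and send results asynchronously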
                for pool, batch_index, batch in batch_iterator:
                    logger.debug(f"Processing batch {batch_index} from pool {pool.name}")

                    start = time()
                    try:
                        outputs = pool.process_func(*batch)
                        output_sender_pool.apply_async(pool.send_outputs_from_runtime, args=[batch_index, outputs])

                        batch_processing_time = time() - start

                        batch_size = outputs[0].size(0)
                        logger.debug(f"Pool {pool.name}: batch {batch_index} processed, size {batch_size}")

                        if self.stats_report_interval is not None:
                            self.stats_reporter.report_stats(pool.name, batch_size, batch_processing_time)

                    except KeyboardInterrupt:
                        raise
                    except BaseException as exception:
                        logger.exception(f"Caught {exception}, attempting to recover")
                        output_sender_pool.apply_async(pool.send_exception_from_runtime, args=[batch_index, exception])

            finally:
                if not self.shutdown_trigger.is_set():
                    self.shutdown()

    def shutdown(self):
        """Gracefully terminate a running runtime."""
        logger.info("Shutting down")
        self.ready.clear()

        if self.stats_report_interval is not None:
            self.stats_reporter.stop.set()
            self.stats_reporter.join()

        logger.debug("Terminating pools")
        for pool in self.pools:
            if pool.is_alive():
                pool.shutdown()
        logger.debug("Pools terminated")

        # trigger background thread to shutdown
        self.shutdown_send.send(self.SHUTDOWN_TRIGGER)
        self.shutdown_trigger.set()

    def iterate_minibatches_from_pools(self, timeout=None):
        """
        Yields (pool, batch_index, batch_tensors): chooses the highest-priority pool with a batch ready,
        copies its exposed batch to the runtime, and frees the pool's buffer
        """
        with DefaultSelector() as selector:
            for pool in self.pools:
                selector.register(pool.batch_receiver, EVENT_READ, pool)
            selector.register(self.shutdown_recv, EVENT_READ, self.SHUTDOWN_TRIGGER)

            while True:
                # wait until at least one batch_receiver becomes available
                logger.debug("Waiting for inputs from task pools")
                ready_keys = selector.select()
                ready_objects = {key.data for (key, events) in ready_keys}
                if self.SHUTDOWN_TRIGGER in ready_objects:
                    break  # someone asked us to shutdown, break from the loop

                logger.debug("Choosing the pool with first priority")

                pool = min(ready_objects, key=lambda pool: pool.priority)

                logger.debug(f"Loading batch from {pool.name}")
                batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
                logger.debug(f"Loaded batch from {pool.name}")
                yield pool, batch_index, batch_tensors


class BatchStats(NamedTuple):
    batch_size: int
    processing_time: float


class StatsReporter(threading.Thread):
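    """A background thread that aggregates per-pool batch statistics and logs a summary every report_interval seconds."""
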
    def __init__(self, report_interval: int):
        super().__init__()
        self.report_interval = report_interval
        self.stop = threading.Event()
        self.stats_queue = SimpleQueue()

    def run(self):
        while not self.stop.wait(self.report_interval):
            pool_batch_stats = defaultdict(list)
            while not self.stats_queue.empty():
                pool_uid, batch_stats = self.stats_queue.get()
                pool_batch_stats[pool_uid].append(batch_stats)

            total_processed_batches = sum(len(pool_stats) for pool_stats in pool_batch_stats.values())
            logger.info(f"Processed {total_processed_batches} batches in last {self.report_interval} seconds:")
            for pool_uid, pool_stats in pool_batch_stats.items():
                total_batches = len(pool_stats)
                total_examples = sum(batch_stats.batch_size for batch_stats in pool_stats)
                avg_batch_size = mean(batch_stats.batch_size for batch_stats in pool_stats)
                total_time = sum(batch_stats.processing_time for batch_stats in pool_stats)
                batches_to_time = total_batches / total_time
                if batches_to_time >= 1:
                    batch_performance = f"{batches_to_time:.2f} batches/s"
                else:
                    # slower than one batch per second: report the inverse so the units stay correct
                    batch_performance = f"{total_time / total_batches:.2f} s/batch"

                examples_to_time = total_examples / total_time
                if examples_to_time >= 1:
                    example_performance = f"{examples_to_time:.2f} examples/s"
                else:
                    example_performance = f"{total_time / total_examples:.2f} s/example"

                logger.info(
                    f"{pool_uid}: "
                    f"{total_batches} batches ({batch_performance}), "
                    f"{total_examples} examples ({example_performance}), "
                    f"avg batch size {avg_batch_size:.2f}"
                )

    def report_stats(self, pool_uid, batch_size, processing_time):
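        """Enqueue a single batch measurement; the reporter thread aggregates and logs it on the next interval."""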
        batch_stats = BatchStats(batch_size, processing_time)
        self.stats_queue.put_nowait((pool_uid, batch_stats))
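

if __name__ == "__main__":
    # A minimal smoke-test sketch for manual debugging, not a production entry point:
    # in practice, Runtime is created and managed by Server. The ModuleBackend arguments
    # below assume hivemind's BatchTensorDescriptor schema API; adjust them to your setup.
    import torch.nn as nn
    from hivemind.utils import BatchTensorDescriptor

    backend = ModuleBackend(
        name="block_uid",
        module=nn.Linear(16, 16),
        args_schema=(BatchTensorDescriptor(16),),
        outputs_schema=BatchTensorDescriptor(16),
        max_batch_size=64,
    )
    runtime = Runtime({backend.name: backend})
    runtime.start()  # starts in a background thread; call runtime.run() to block instead
    runtime.ready.wait()  # wait until the pools are up and the module is on its device

    future = backend.forward_pool.submit_task(torch.randn(2, 16))
    print("Returned:", future.result())
    runtime.shutdown()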