File size: 18,219 Bytes
ec0c8fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
from typing import *
from abc import abstractmethod
from queue import Empty, Full
from threading import Thread
from queue import Queue
from multiprocessing import Process
from threading import Thread, Event
import multiprocessing
import threading
import inspect
import time
import uuid
from copy import deepcopy
import itertools
import functools

__all__ = [
    'Node', 
    'Link',
    'ConcurrentNode',
    'Worker', 
    'WorkerFunction',
    'Provider',
    'ProviderFunction',
    'Sequential',
    'Batch',
    'Unbatch',
    'Parallel',
    'Graph', 
    'Buffer',
]

TERMINATE_CHECK_INTERVAL = 0.5


class _ItemWrapper:
    def __init__(self, data: Any, id: Union[int, List[int]] = None):
        self.data = data
        self.id = id


class Terminate(Exception):
    pass


def _get_queue_item(queue: Queue, terminate_flag: Event, timeout: float = None) -> _ItemWrapper:
    while True:
        try:
            item: _ItemWrapper = queue.get(block=True, timeout=TERMINATE_CHECK_INTERVAL if timeout is None else min(timeout, TERMINATE_CHECK_INTERVAL))
            if terminate_flag.is_set():
                raise Terminate()
            return item
        except Empty:
            if terminate_flag.is_set():
                raise Terminate()
            
        if timeout is not None:
            timeout -= TERMINATE_CHECK_INTERVAL
            if timeout <= 0:
                raise Empty()


def _put_queue_item(queue: Queue, item: _ItemWrapper, terminate_flag: Event):
    while True:
        try:
            queue.put(item, block=True, timeout=TERMINATE_CHECK_INTERVAL)
            if terminate_flag.is_set():
                raise Terminate()
            return
        except Full:
            if terminate_flag.is_set():
                raise Terminate()

class Node:
    def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
        self.input: Queue = Queue(maxsize=in_buffer_size)
        self.output: Queue = Queue(maxsize=out_buffer_size)
        self.in_buffer_size = in_buffer_size
        self.out_buffer_size = out_buffer_size

    @abstractmethod
    def start(self):
        pass

    @abstractmethod
    def terminate(self):
        pass

    def stop(self):
        self.terminate()
        self.join()

    @abstractmethod
    def join(self):
        pass
    
    def put(self, data: Any, key: str = None, block: bool = True) -> None:
        item = _ItemWrapper(data)
        self.input.put(item, block=block)
    
    def get(self, key: str = None, block: bool = True) -> Any:
        item: _ItemWrapper = self.output.get(block=block)
        return item.data

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.terminate()
        self.join()


class ConcurrentNode(Node):
    job: Union[Thread, Process]

    def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
        super().__init__(in_buffer_size, out_buffer_size)
        self.running_as = running_as

    @abstractmethod
    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        pass

    def start(self):
        if self.running_as == 'thread':
            terminate_flag = threading.Event()
            job = Thread(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
        elif self.running_as == 'process':
            terminate_flag = multiprocessing.Event()
            job = Process(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
        job.start()
        self.job = job
        self.terminate_flag = terminate_flag

    def terminate(self):
        self.terminate_flag.set()

    def join(self):
        self.job.join()


class Worker(ConcurrentNode):
    def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 0, out_buffer_size: int = 0) -> None:
        super().__init__(running_as, in_buffer_size, out_buffer_size)

    def init(self) -> None:
        """
        This method is called the the thread is started, to initialize any resources that is only held in the thread.
        """
        pass

    @abstractmethod
    def work(self, *args, **kwargs) -> Union[Any, Dict[str, Any]]:
        """
        This method defines the job that the node should do for each input item. 
        A item obtained from the input queue is passed as arguments to this method, and the result is placed in the output queue.
        The method is executed concurrently with other nodes.
        """
        pass

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        self.init()
        try:
            while True:
                item = _get_queue_item(input, terminate_flag)
                result = self.work(item.data)
                _put_queue_item(output, _ItemWrapper(result, item.id), terminate_flag)
                
        except Terminate:
            return


class Provider(ConcurrentNode):
    """
    A node that provides data to successive nodes. It takes no input and provides data to the output queue.
    """
    def __init__(self, running_as: Literal['thread', 'process'], out_buffer_size: int = 1) -> None:
        super().__init__(running_as, 0, out_buffer_size)

    def init(self) -> None:
        """
        This method is called the the thread or process is started, to initialize any resources that is only held in the thread or process.
        """
        pass

    @abstractmethod
    def provide(self) -> Generator[Any, None, None]:
        pass

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        self.init()
        try:
            for data in self.provide():
                _put_queue_item(output, _ItemWrapper(data), terminate_flag)
        except Terminate:
            return


class WorkerFunction(Worker):
    def __init__(self, fn: Callable, running_as: 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
        super().__init__(running_as, in_buffer_size, out_buffer_size)
        self.fn = fn

    def work(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class ProviderFunction(Provider):
    def __init__(self, fn: Callable, running_as: 'thread', out_buffer_size: int = 1) -> None:
        super().__init__(running_as, out_buffer_size)
        self.fn = fn

    def provide(self):
        for item in self.fn():
            yield item


class Link:
    def __init__(self, src: Queue, dst: Queue):
        self.src = src
        self.dst = dst
        
    def _thread_fn(self):
        try:
            while True:
                item = _get_queue_item(self.src, self.terminate_flag)
                _put_queue_item(self.dst, item, self.terminate_flag)
        except Terminate:
            return
    
    def start(self):
        self.terminate_flag = threading.Event()
        self.thread = Thread(target=self._thread_fn)
        self.thread.start()

    def terminate(self):
        self.terminate_flag.set()

    def join(self):
        self.thread.join()


class Graph(Node):
    """
    Graph pipeline of nodes and links
    """
    nodes: List[Node]
    links: List[Link]

    def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
        super().__init__(in_buffer_size, out_buffer_size)
        self.nodes = []
        self.links = []

    def add(self, node: Node):
        self.nodes.append(node)

    def link(self, src: Union[Node, Tuple[Node, str]], dst: Union[Node, Tuple[Node, str]]):
        """
        Links the output of the source node to the input of the destination node.
        If the source or destination node is None, the pipeline's input or output is used.
        """
        src_queue = self.input if src is None else src.output
        dst_queue = self.output if dst is None else dst.input
        self.links.append(Link(src_queue, dst_queue))

    def chain(self, nodes: Iterable[Node]):
        """
        Link the output of each node to the input of the next node.
        """
        nodes = list(nodes)
        for i in range(len(nodes) - 1):
            self.link(nodes[i], nodes[i + 1])

    def start(self):
        for node in self.nodes:
            node.start()
        for link in self.links:
            link.start()

    def terminate(self):
        for node in self.nodes:
            node.terminate()
        for link in self.links:
            link.terminate()

    def join(self):
        for node in self.nodes:
            node.join()
        for link in self.links:
            link.join()

    def __iter__(self):
        providers = [node for node in self.nodes if isinstance(node, Provider)]
        if len(providers) == 0:
            raise ValueError("No provider node found in the pipeline. If you want to iterate over the pipeline, the pipeline must be driven by a provider node.")
        with self:
            # while all(provider.job.is_alive() for provider in providers):
            while True:
                yield self.get()

    def __call__(self, data: Any) -> Any:
        """
        Submit data to the pipeline's input queue, and return the output data asynchronously.
        NOTE: The pipeline must be streamed (i.e., every output item is uniquely associated with an input item) for this to work.
        """
        # TODO


class Sequential(Graph):
    """
    Pipeline of nodes in sequential order, where each node takes the output of the previous node as input.
    The order of input and output items is preserved (FIFO)
    """
    def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
        """
        Initialize the pipeline with a list of nodes to execute sequentially.
        ### Parameters:
        - nodes: List of nodes or functions to execute sequentially. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
        - function_running_as: Whether to wrap the function as a thread or process worker. Default is 'thread'.
        - in_buffer_size: Maximum size of the input queue of the pipeline. Default is 0 (unlimited).
        - out_buffer_size: Maximum size of the output queue of the pipeline. Default is 0 (unlimited).
        """
        super().__init__(in_buffer_size, out_buffer_size)
        for node in nodes:
            if isinstance(node, Node):
                pass
            elif isinstance(node, Callable):
                if inspect.isgeneratorfunction(node):
                    node = ProviderFunction(node, function_running_as)
                else:
                    node = WorkerFunction(node, function_running_as)
            else:
                raise ValueError(f"Invalid node type: {type(node)}")
            self.add(node)
        self.chain([None, *self.nodes, None])


class Parallel(Node):
    """
    A FIFO node that runs multiple nodes in parallel to process the input items. Each input item is handed to one of the nodes whoever is available.
    NOTE: It is FIFO if and only if all the nested nodes are FIFO.
    """
    nodes: List[Node]

    def __init__(self, nodes: Iterable[Node], in_buffer_size: int = 1, out_buffer_size: int = 1, function_running_as: Literal['thread', 'process'] = 'thread'):
        super().__init__(in_buffer_size, out_buffer_size)
        self.nodes = []
        for node in nodes:
            if isinstance(node, Node):
                pass
            elif isinstance(node, Callable):
                if inspect.isgeneratorfunction(node):
                    node = ProviderFunction(node, function_running_as)
                else:
                    node = WorkerFunction(node, function_running_as)
            else:
                raise ValueError(f"Invalid node type: {type(node)}")
            self.nodes.append(node)
        self.output_order = Queue()
        self.lock = threading.Lock()

    def _in_thread_fn(self, node: Node):
        try:
            while True:
                with self.lock:
                    # A better idea: first make sure its node is vacant, then get it a new item. 
                    # Currently we will not be able to know which node is busy util there is at least one item already waiting in the queue of the node.
                    # This could lead to suboptimal scheduling.
                    item = _get_queue_item(self.input, self.terminate_flag)
                    self.output_order.put(node.output)
                _put_queue_item(node.input, item, self.terminate_flag)
        except Terminate:
            return
    
    def _out_thread_fn(self):
        try:
            while True:
                queue = _get_queue_item(self.output_order, self.terminate_flag)
                item = _get_queue_item(queue, self.terminate_flag)
                _put_queue_item(self.output, item, self.terminate_flag)
        except Terminate:
            return

    def start(self):
        self.terminate_flag = threading.Event()
        self.in_threads = []
        for node in self.nodes:
            thread = Thread(target=self._in_thread_fn, args=(node,))
            thread.start()
            self.in_threads.append(thread)
        thread = Thread(target=self._out_thread_fn)
        thread.start()
        self.out_thread = thread
        for node in self.nodes:
            node.start()

    def terminate(self):
        self.terminate_flag.set()
        for node in self.nodes:
            node.terminate()

    def join(self):
        for thread in self.in_threads:
            thread.join()
        self.out_thread.join()


class UnorderedParallel(Graph):
    """
    Pipeline of nodes in parallel, where each input item is handed to one of the nodes whoever is available.
    NOTE: The order of the output items is NOT guaranteed to be the same as the input items, depending on how fast the nodes handle their input.
    """
    def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
        """
        Initialize the pipeline with a list of nodes to execute in parallel. If a function is given, it is wrapped in a worker node.
        ### Parameters:
        - nodes: List of nodes or functions to execute in parallel. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
        - function_running_as: Whether to wrap the function as a thread or process worker. Default is 'thread'.
        - in_buffer_size: Maximum size of the input queue of the pipeline. Default is 0 (unlimited).
        - out_buffer_size: Maximum size of the output queue of the pipeline. Default is 0 (unlimited).
        """
        super().__init__(in_buffer_size, out_buffer_size)
        for node in nodes:
            if isinstance(node, Node):
                pass
            elif isinstance(node, Callable):
                if inspect.isgeneratorfunction(node):
                    node = ProviderFunction(node, function_running_as)
                else:
                    node = WorkerFunction(node, function_running_as)
            else:
                raise ValueError(f"Invalid node type: {type(node)}")
            self.add(node)
        for i in range(len(nodes)):
            self.chain([None, self.nodes[i], None])


class Batch(ConcurrentNode):
    """
    Groups every `batch_size` items into a batch (a list of items) and passes the batch to successive nodes.
    The `patience` parameter specifies the maximum time to wait for a batch to be filled before sending it to the next node,
    i.e., when the earliest item in the batch is out of `patience` seconds, the batch is sent regardless of its size.
    """
    def __init__(self, batch_size: int, patience: float = None, in_buffer_size: int = 1, out_buffer_size: int = 1):
        assert batch_size > 0, "Batch size must be greater than 0."
        super().__init__('thread', in_buffer_size, out_buffer_size)
        self.batch_size = batch_size
        self.patience = patience

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        try:
            while True:
                batch_id, batch_data = [], []
                # Try to fill the batch
                for i in range(self.batch_size):
                    if i == 0 or self.patience is None:
                        timeout = None
                    else:
                        timeout = self.patience - (time.time() - earliest_time)
                        if timeout < 0:
                            break
                    try:
                        item = _get_queue_item(input, terminate_flag, timeout)
                    except Empty:
                        break

                    if i == 0:
                        earliest_time = time.time()
                    batch_data.append(item.data)
                    batch_id.append(item.id)

                batch = _ItemWrapper(batch_data, batch_id)
                _put_queue_item(output, batch, terminate_flag)
        except Terminate:
            return


class Unbatch(ConcurrentNode):
    """
    Ungroups every batch (a list of items) into individual items and passes them to successive nodes.
    """
    def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
        super().__init__('thread', in_buffer_size, out_buffer_size)

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        try:
            while True:
                batch = _get_queue_item(input, terminate_flag)
                for id, data in zip(batch.id or itertools.repeat(None), batch.data):
                    item = _ItemWrapper(data, id)
                    _put_queue_item(output, item, terminate_flag)
        except Terminate:
            return


class Buffer(Node):
    "A FIFO node that buffers items in a queue. Usefull achieve better temporal balance when its successor node has a variable processing time."
    def __init__(self, size: int):
        super().__init__(size, size)
        self.size = size
        self.input = self.output = Queue(maxsize=size)