import asyncio
import traceback


class Job:
    """A unit of work passed between pipeline nodes.

    Attributes:
        id: Sequence number of the job. It is ``None`` until
            ``Pipeline.enqueue_job`` assigns one; sequential nodes use it to
            restore ordering.
        data: Arbitrary payload supplied by the caller.
    """

    def __init__(self, data):
        # Named ``id`` (not ``_id``) so it matches what Node.run reads and
        # what Pipeline.enqueue_job writes.
        self.id = None
        self.data = data


class Node:
    """Worker that pulls jobs from a queue, processes them and emits results.

    Subclasses implement ``process_job`` as an async generator. Every value it
    yields is put on ``output_queue`` (when given) and appended to
    ``job_sync`` (when given).

    Args:
        worker_id: Identifier used in error messages.
        input_queue: Queue-like object whose ``get()`` is awaited for jobs.
        output_queue: Optional queue-like object whose ``put()`` is awaited
            with each result.
        job_sync: Optional list-like object; each result is ``append``-ed.
            Only allowed on sequential nodes.
        sequential_node: When true, incoming jobs are buffered by ``job.id``
            and processed strictly in the order 0, 1, 2, ...

    Raises:
        ValueError: If ``job_sync`` is given but ``sequential_node`` is false.
    """

    def __init__(self,
                 worker_id: int,
                 input_queue,
                 output_queue=None,
                 job_sync=None,
                 sequential_node=False):
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')

        self.worker_id = worker_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.job_sync = job_sync
        self.sequential_node = sequential_node

        # Sequential mode only: jobs that arrived early, keyed by job.id,
        # and the id that must be processed next.
        self.buffer = {}
        self.next_i = 0

        self._jobs_dequeued = 0
        self._jobs_processed = 0

    async def run(self):
        """Consume jobs from ``input_queue`` until cancelled or an error occurs.

        Any ``Exception`` is reported on stdout together with its traceback
        and then re-raised, so it is stored on the task running this
        coroutine.
        """
        try:
            while True:
                job = await self.input_queue.get()
                self._jobs_dequeued += 1

                if not self.sequential_node:
                    await self._process_and_emit(job)
                    continue

                # Sequential mode: park the job, then drain every job that is
                # now contiguous with the last one processed. A job id that
                # never arrives keeps all later ids parked in the buffer.
                self.buffer[job.id] = job
                while self.next_i in self.buffer:
                    await self._process_and_emit(self.buffer.pop(self.next_i))
                    self.next_i += 1
        except Exception as e:
            print(f"An error occurred in node: {self.__class__.__name__} worker: {self.worker_id}: {e}")
            traceback.print_exc()
            raise  # Re-raises the last exception.

    async def _process_and_emit(self, job):
        """Run ``process_job`` on one job and forward every result it yields."""
        async for result in self.process_job(job):
            if self.output_queue is not None:
                await self.output_queue.put(result)
            if self.job_sync is not None:
                self.job_sync.append(result)
        # Counted once per input job, regardless of how many results it yielded.
        self._jobs_processed += 1

    async def process_job(self, job: Job):
        """Async generator yielding the result(s) for ``job``; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
        # Unreachable, but its presence makes this an async generator so that
        # ``async for`` surfaces NotImplementedError instead of a TypeError
        # about a coroutine lacking ``__aiter__``.
        yield  # pragma: no cover


class Pipeline:
    """Owns the worker tasks of a set of nodes and feeds jobs to the first one.

    The input queue of the first node added becomes the root queue that
    ``enqueue_job`` writes to. Wiring between nodes is done by the caller
    through the queues passed to ``add_node``.
    """

    def __init__(self):
        self.input_queues = []
        self.root_queue = None
        self.nodes = []          # node class names, in order of first registration
        self.node_workers = {}   # node class name -> list of worker instances
        self.tasks = []          # asyncio tasks running the workers' run()
        self._job_id = 0         # next id handed out by enqueue_job

    async def add_node(self,
                       node: 'type[Node]',
                       num_workers=1,
                       input_queue=None,
                       output_queue=None,
                       job_sync=None,
                       sequential_node=False):
        """Instantiate ``num_workers`` workers of ``node`` and start them.

        Must be called while an event loop is running, because each worker is
        scheduled with ``asyncio.create_task``.

        Args:
            node: The Node subclass itself (not an instance).
            num_workers: Number of worker instances to start; at least 1.
            input_queue: Queue the workers read from. Required.
            output_queue: Optional queue the workers write results to.
            job_sync: Optional list-like result collector (sequential only).
            sequential_node: Whether the node must process jobs in id order.

        Raises:
            ValueError: On a missing input queue, ``job_sync`` on a
                non-sequential node, a sequential node with more than one
                worker, fewer than one worker, or identical input and
                output queues.
        """
        if input_queue is None:
            raise ValueError('input_queue is None')
        # job_sync nodes must be sequential nodes
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')
        # sequential nodes may only have a single worker
        if sequential_node and num_workers != 1:
            raise ValueError('sequentaial nodes can only have one node (sequential_node is True and num_workers is not 1)')
        # zero workers would leave jobs sitting in input_queue forever
        if num_workers < 1:
            raise ValueError('num_workers must be at least 1')
        # a node feeding its own input would loop forever
        if output_queue is input_queue:
            raise ValueError('output_queue must not be the same as input_queue')

        node_name = node.__name__
        if node_name not in self.nodes:
            self.nodes.append(node_name)

        # The first node ever added defines where enqueue_job puts jobs.
        if not self.input_queues:
            self.root_queue = input_queue
        self.input_queues.append(input_queue)

        workers = self.node_workers.setdefault(node_name, [])
        for worker_id in range(num_workers):
            node_worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
            workers.append(node_worker)
            self.tasks.append(asyncio.create_task(node_worker.run()))

    async def enqueue_job(self, job: Job):
        """Assign the next sequential id to ``job`` and put it on the root queue.

        Raises:
            RuntimeError: If no node has been added yet.
        """
        if self.root_queue is None:
            raise RuntimeError('cannot enqueue a job before a node has been added')
        job.id = self._job_id
        self._job_id += 1
        await self.root_queue.put(job)

    async def close(self):
        """Cancel every worker task and wait until all of them have finished.

        Exceptions raised by workers (including the cancellation itself) are
        collected by ``gather`` rather than propagated.
        """
        for task in self.tasks:
            task.cancel()
        await asyncio.gather(*self.tasks, return_exceptions=True)