project_charles / pipeline.py
sohojoe's picture
basic POC
162d5c8
raw
history blame
No virus
4.53 kB
import asyncio
import traceback
class Job:
    """A unit of work passed between pipeline nodes.

    Attributes:
        data: Caller-supplied payload; nodes read and replace it as they see fit.
        id: Sequence number of the job, ``None`` until assigned.
            ``Pipeline.enqueue_job`` assigns consecutive ids starting at 0,
            and sequential nodes use the id to restore submission order.
    """

    def __init__(self, data):
        # Backing storage for the ``id`` property. Initialised here so that
        # ``job.id`` is always readable, even before an id is assigned.
        self._id = None
        self.data = data

    @property
    def id(self):
        """Return the job's sequence number, or ``None`` if not yet assigned."""
        return self._id

    @id.setter
    def id(self, value):
        self._id = value
class Node:
    """Base class for one asynchronous pipeline stage.

    A worker repeatedly takes jobs from ``input_queue`` and iterates
    ``process_job(job)`` with ``async for``. Every result is forwarded to
    ``output_queue`` (if given) and appended to ``job_sync`` (if given).

    Subclasses must override :meth:`process_job` as an *async generator*
    (an ``async def`` that ``yield``s), because :meth:`run` consumes it with
    ``async for``.

    When ``sequential_node`` is true, incoming jobs are buffered by ``job.id``
    and processed strictly in id order starting from 0. This lets a single
    worker restore the order of jobs that arrive out of order.
    """

    def __init__(self, worker_id: int, input_queue, output_queue=None, job_sync=None, sequential_node=False):
        """Configure the worker.

        Args:
            worker_id: Index of this worker within its node (used in error output).
            input_queue: Queue the worker reads jobs from.
            output_queue: Optional queue that receives every result yielded by
                ``process_job``.
            job_sync: Optional list that every result is appended to. Only
                allowed for sequential nodes, where the append order is
                well defined.
            sequential_node: If true, process jobs in ``job.id`` order.

        Raises:
            ValueError: If ``job_sync`` is given for a non-sequential node.
        """
        self.worker_id = worker_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        # Sequential mode only: jobs that arrived before their turn, keyed by job.id.
        self.buffer = {}
        self.job_sync = job_sync
        self.sequential_node = sequential_node
        # Sequential mode only: id of the next job allowed to be processed.
        self.next_i = 0
        # Number of jobs taken from input_queue.
        self._jobs_dequeued = 0
        # Number of results yielded by process_job (one input may yield several).
        self._jobs_processed = 0
        if self.job_sync is not None and not self.sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')

    async def run(self):
        """Consume jobs from ``input_queue`` forever.

        The loop never returns normally. It ends only through cancellation or
        an exception. Exceptions are printed with a traceback and then
        re-raised.
        """
        try:
            while True:
                job: "Job" = await self.input_queue.get()
                self._jobs_dequeued += 1
                if not self.sequential_node:
                    await self._process_and_emit(job)
                    continue
                # Park the job, then drain every buffered job whose turn has come.
                # NOTE: a job whose id is below next_i (or never reached) stays
                # in the buffer indefinitely.
                self.buffer[job.id] = job
                while self.next_i in self.buffer:
                    await self._process_and_emit(self.buffer.pop(self.next_i))
                    self.next_i += 1
        except Exception as e:
            print(f"An error occurred in node: {self.__class__.__name__} worker: {self.worker_id}: {e}")
            traceback.print_exc()
            raise  # Re-raise so the owning task records the failure.

    async def _process_and_emit(self, job: "Job"):
        """Run ``process_job`` on *job* and forward every result it yields."""
        async for result in self.process_job(job):
            if self.output_queue is not None:
                await self.output_queue.put(result)
            if self.job_sync is not None:
                self.job_sync.append(result)
            self._jobs_processed += 1

    async def process_job(self, job: "Job"):
        """Transform *job*; subclasses must override this as an async generator.

        Yields:
            Zero or more result jobs to pass downstream.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError()
        # Unreachable, but it makes this an async generator. Without it,
        # ``async for`` in run() would fail with TypeError instead of
        # surfacing NotImplementedError.
        yield
class Pipeline:
    """Wire node classes together through queues and run their workers as asyncio tasks.

    Each :meth:`add_node` call instantiates ``num_workers`` workers of a node
    class and starts their ``run()`` loops as tasks. The input queue of the
    *first* node added becomes the root queue that :meth:`enqueue_job` feeds.
    """

    def __init__(self):
        # Every input queue passed to add_node, in registration order.
        self.input_queues = []
        # Input queue of the first node added; enqueue_job() puts jobs here.
        self.root_queue = None
        # Distinct node class names, in the order they were first added.
        self.nodes = []
        # Node class name -> worker instances created for that class.
        self.node_workers = {}
        # Tasks running each worker's run() coroutine.
        self.tasks = []
        # Next id handed out by enqueue_job().
        self._job_id = 0

    async def add_node(self, node: "type[Node]", num_workers=1, input_queue=None, output_queue=None, job_sync=None, sequential_node=False):
        """Instantiate *node* ``num_workers`` times and start each worker as a task.

        Must be awaited inside a running event loop, because it uses
        ``asyncio.create_task``.

        Args:
            node: The node *class* to instantiate. It is called as
                ``node(worker_id, input_queue, output_queue, job_sync, sequential_node)``.
            num_workers: Number of workers to start; must be at least 1, and
                exactly 1 for sequential nodes.
            input_queue: Queue the workers read from (required).
            output_queue: Optional queue the workers write results to; must
                differ from ``input_queue``.
            job_sync: Optional result list; only valid for sequential nodes.
            sequential_node: Whether the workers process jobs in ``job.id`` order.

        Raises:
            ValueError: If the configuration is invalid. All checks run
                before any pipeline state is modified.
        """
        if input_queue is None:
            raise ValueError('input_queue is None')
        # job_sync records results in order, which only sequential nodes guarantee.
        if job_sync is not None and not sequential_node:
            raise ValueError('job_sync is not None and sequential_node is False')
        # Ordering relies on a single worker owning the reorder buffer.
        if sequential_node and num_workers != 1:
            raise ValueError('sequential nodes can only have one worker (sequential_node is True and num_workers is not 1)')
        # Writing back into the input queue would loop results forever.
        if output_queue is input_queue:
            raise ValueError('output_queue must not be the same as input_queue')
        # With no workers, the registered input queue would never be drained.
        if num_workers < 1:
            raise ValueError('num_workers must be at least 1')

        node_name = node.__name__
        if node_name not in self.nodes:
            self.nodes.append(node_name)
        # The first node ever added defines the pipeline's entry point.
        if not self.input_queues:
            self.root_queue = input_queue
        self.input_queues.append(input_queue)
        for worker_id in range(num_workers):
            node_worker = node(worker_id, input_queue, output_queue, job_sync, sequential_node)
            self.node_workers.setdefault(node_name, []).append(node_worker)
            self.tasks.append(asyncio.create_task(node_worker.run()))

    async def enqueue_job(self, job: "Job"):
        """Assign *job* the next sequential id and put it on the root queue.

        Raises:
            RuntimeError: If no node has been added yet, so there is no root queue.
        """
        if self.root_queue is None:
            raise RuntimeError('enqueue_job called before any node was added')
        job.id = self._job_id
        self._job_id += 1
        await self.root_queue.put(job)

    async def close(self):
        """Cancel every worker task and wait for all of them to finish.

        Exceptions raised by the tasks (including cancellation) are collected
        by ``gather(return_exceptions=True)`` rather than propagated.
        """
        for task in self.tasks:
            task.cancel()
        await asyncio.gather(*self.tasks, return_exceptions=True)