"""End-to-end tests for ``pipeline.Pipeline``: argument validation and job ordering."""
import asyncio
import os
import random
import sys
import time
import unittest

# Make the parent of this file's directory importable so ``pipeline`` resolves
# when the tests are run directly from this directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from pipeline import Pipeline, Node, Job  # noqa: E402  (needs the sys.path tweak above)


class Node1(Node):
    """First stage: append this worker's tag to ``job.data`` and pass the job on."""

    async def process_job(self, job: Job):
        job.data += f' (processed by node 1, worker {self.worker_id})'
        yield job


class Node2(Node):
    """Second stage: sleep a random 80-120 ms, then tag ``job.data``.

    The random delay lets jobs handled by different workers of this stage
    finish out of submission order, which is what the ordering test relies on.
    """

    async def process_job(self, job: Job):
        sleep_duration = 0.08 + 0.04 * random.random()
        await asyncio.sleep(sleep_duration)
        job.data += f' (processed by node 2, worker {self.worker_id})'
        yield job


class Node3(Node):
    """Final stage: tag ``job.data``, print the job, and pass it on."""

    async def process_job(self, job: Job):
        job.data += f' (processed by node 3, worker {self.worker_id})'
        print(f'{job.id} - {job.data}')
        yield job


class TestPipeline(unittest.TestCase):
    """Tests for ``Pipeline.add_node`` validation and end-to-end job ordering."""

    # Upper bound (seconds) on how long _test_pipeline waits for all jobs to
    # reach job_sync, so a stalled pipeline fails the test instead of hanging.
    DRAIN_TIMEOUT = 30.0

    def setUp(self):
        # Fresh pipeline and result list for every test. job_sync is handed to
        # the last stage via add_node(job_sync=...); presumably the pipeline
        # appends finished jobs to it (the ordering test assumes so).
        self.pipeline = Pipeline()
        self.job_sync = []

    async def _test_pipeline_edge_cases(self):
        """Check that ``add_node`` rejects invalid queue configurations."""
        # A node must have an input queue.
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, None, None)

        # A node's output queue must not be the same object as its input queue.
        node1_queue = asyncio.Queue()
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, node1_queue, node1_queue)

    async def _wait_for_jobs(self, num_jobs, poll_interval=0.1):
        """Poll until at least ``num_jobs`` jobs have been appended to ``self.job_sync``."""
        while len(self.job_sync) < num_jobs:
            await asyncio.sleep(poll_interval)

    async def _test_pipeline(self, num_jobs, timeout=None):
        """Build a 3-stage pipeline, push ``num_jobs`` jobs through, and wait for them.

        Args:
            num_jobs: Number of empty jobs to enqueue.
            timeout: Seconds to wait for all jobs to reach ``self.job_sync``;
                defaults to ``DRAIN_TIMEOUT``. The test fails (rather than
                hangs) if the pipeline does not deliver them in time.
        """
        if timeout is None:
            timeout = self.DRAIN_TIMEOUT

        node1_queue = asyncio.Queue()
        node2_queue = asyncio.Queue()
        node3_queue = asyncio.Queue()
        await self.pipeline.add_node(Node1, 1, node1_queue, node2_queue)
        # Five parallel workers with random delays scramble completion order.
        await self.pipeline.add_node(Node2, 5, node2_queue, node3_queue)
        await self.pipeline.add_node(Node3, 1, node3_queue, job_sync=self.job_sync, sequential_node=True)

        try:
            for _ in range(num_jobs):
                await self.pipeline.enqueue_job(Job(""))
            try:
                await asyncio.wait_for(self._wait_for_jobs(num_jobs), timeout)
            except asyncio.TimeoutError:
                self.fail(
                    f"only {len(self.job_sync)} of {num_jobs} jobs completed "
                    f"within {timeout} seconds"
                )
        finally:
            # Always shut the pipeline down, even when the wait failed.
            await self.pipeline.close()

    def test_pipeline_edge_cases(self):
        asyncio.run(self._test_pipeline_edge_cases())

    def test_pipeline_keeps_order(self):
        num_jobs = 100
        start_time = time.perf_counter()
        asyncio.run(self._test_pipeline(num_jobs))
        elapsed = time.perf_counter() - start_time
        print(f"Pipeline processed in {elapsed} seconds.")

        self.assertEqual(len(self.job_sync), num_jobs)
        # Jobs must come out in submission order; ids are expected to be 0..n-1.
        self.assertEqual([job.id for job in self.job_sync], list(range(num_jobs)))


if __name__ == '__main__':
    unittest.main()