Spaces:
Sleeping
Sleeping
File size: 2,717 Bytes
ed232fa 572f2e6 ed232fa 8c9e2db ed232fa 8c9e2db ed232fa 8c9e2db ed232fa 8c9e2db ed232fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import asyncio
import random
import time
import unittest
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from legacy_to_delete.pipeline import Pipeline, Node, Job
class Node1(Node):
    """First pipeline stage: stamps each job with this worker's id and passes it on."""

    async def process_job(self, job: Job):
        """Append a node-1 marker to ``job.data`` and yield the same job once."""
        marker = f' (processed by node 1, worker {self.worker_id})'
        job.data += marker
        yield job
class Node2(Node):
    """Middle pipeline stage: waits a short random delay, then stamps the job.

    NOTE(review): the randomized delay presumably exists so that this stage's
    parallel workers finish jobs out of order, which makes the downstream
    ordering check meaningful — confirm.
    """

    async def process_job(self, job: Job):
        """Sleep for a random 0.08-0.12 s, append a node-2 marker, yield the job."""
        delay = 0.08 + random.random() * 0.04
        await asyncio.sleep(delay)
        marker = f' (processed by node 2, worker {self.worker_id})'
        job.data += marker
        yield job
class Node3(Node):
    """Final pipeline stage: stamps the job and prints its id with its data."""

    async def process_job(self, job: Job):
        """Append a node-3 marker, print ``<id> - <data>``, and yield the job."""
        job.data += f' (processed by node 3, worker {self.worker_id})'
        summary = f'{job.id} - {job.data}'
        print(summary)
        yield job
class TestPipeline(unittest.TestCase):
    """Tests for ``Pipeline`` wired with the stub nodes Node1/Node2/Node3."""

    def setUp(self):
        """Give every test a fresh pipeline and an empty result sink."""
        self.pipeline = Pipeline()
        # Passed to the final node as ``job_sync``; the ordering test reads
        # finished jobs back out of it.
        self.job_sync = []

    async def _test_pipeline_edge_cases(self):
        """Check that ``add_node`` rejects invalid queue wiring with ValueError."""
        # A node must have an input queue.
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, None, None)
        # A node's output queue must not be the same object as its input queue.
        node1_queue = asyncio.Queue()
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, node1_queue, node1_queue)

    async def _test_pipeline(self, num_jobs, timeout=30.0):
        """Push ``num_jobs`` empty jobs through a three-stage pipeline.

        Wiring: Node1 (1 worker) -> Node2 (5 workers) -> Node3 (1 worker,
        ``sequential_node=True``, with ``job_sync=self.job_sync``). Polls until
        ``self.job_sync`` holds at least ``num_jobs`` entries, then closes the
        pipeline.

        Args:
            num_jobs: number of jobs to enqueue.
            timeout: seconds to wait for all jobs before failing the test,
                so a pipeline that drops jobs cannot hang the suite forever.
        """
        node1_queue = asyncio.Queue()
        node2_queue = asyncio.Queue()
        node3_queue = asyncio.Queue()
        await self.pipeline.add_node(Node1, 1, node1_queue, node2_queue)
        await self.pipeline.add_node(Node2, 5, node2_queue, node3_queue)
        await self.pipeline.add_node(
            Node3, 1, node3_queue, job_sync=self.job_sync, sequential_node=True
        )
        for _ in range(num_jobs):
            await self.pipeline.enqueue_job(Job(""))

        # Poll with a deadline. Use `<` rather than `==` so that an
        # unexpected surplus of jobs ends the wait instead of spinning forever.
        deadline = time.monotonic() + timeout
        while len(self.job_sync) < num_jobs:
            if time.monotonic() > deadline:
                self.fail(
                    f"only {len(self.job_sync)} of {num_jobs} jobs reached "
                    f"job_sync within {timeout} s"
                )
            await asyncio.sleep(0.1)
        await self.pipeline.close()

    def test_pipeline_edge_cases(self):
        """Invalid node configurations are rejected."""
        asyncio.run(self._test_pipeline_edge_cases())

    def test_pipeline_keeps_order(self):
        """Jobs arrive in ``job_sync`` ordered by id despite parallel workers."""
        num_jobs = 100
        start_time = time.perf_counter()
        asyncio.run(self._test_pipeline(num_jobs))
        elapsed = time.perf_counter() - start_time
        print(f"Pipeline processed in {elapsed} seconds.")
        self.assertEqual(len(self.job_sync), num_jobs)
        # NOTE(review): assumes Job ids are assigned 0, 1, 2, ... in creation
        # order with no prior Jobs in this process — confirm against Job.
        self.assertEqual([job.id for job in self.job_sync], list(range(num_jobs)))
if __name__ == '__main__':
    # Run all TestPipeline tests through the standard unittest runner.
    unittest.main()