Spaces:
Runtime error
Runtime error
import asyncio | |
import random | |
import time | |
import unittest | |
import sys | |
import os | |
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
from legacy_to_delete.pipeline import Pipeline, Node, Job | |
class Node1(Node):
    """Pipeline stage that stamps each job with a node-1 / worker marker."""

    async def process_job(self, job: Job):
        """Append a node-1 marker (including this worker's id) to ``job.data``.

        Yields the same, now-modified, job object exactly once.
        """
        marker = f' (processed by node 1, worker {self.worker_id})'
        job.data += marker
        yield job
class Node2(Node):
    """Pipeline stage that simulates slow work before stamping the job."""

    async def process_job(self, job: Job):
        """Wait a short random time, append a node-2 marker, and yield the job.

        ``random.random()`` lies in [0, 1), so the simulated delay is in
        [0.08, 0.12) seconds.
        """
        delay = 0.08 + random.random() * 0.04
        await asyncio.sleep(delay)
        marker = f' (processed by node 2, worker {self.worker_id})'
        job.data += marker
        yield job
class Node3(Node):
    """Pipeline stage that stamps the job and echoes it to stdout."""

    async def process_job(self, job: Job):
        """Append a node-3 marker, print ``"<id> - <data>"``, and yield the job."""
        marker = f' (processed by node 3, worker {self.worker_id})'
        job.data += marker
        print(f'{job.id} - {job.data}')
        yield job
class TestPipeline(unittest.TestCase):
    """End-to-end tests for ``Pipeline`` wired with the Node1 -> Node2 -> Node3 stages."""

    def setUp(self):
        # Every test gets a fresh pipeline and an empty ``job_sync`` list. The
        # list is handed to the last node as ``job_sync``; presumably the
        # pipeline appends finished jobs to it (the tests poll its length).
        self.pipeline = Pipeline()
        self.job_sync = []

    async def _test_pipeline_edge_cases(self):
        """Check that ``add_node`` rejects invalid queue configurations with ValueError."""
        # A node must have an input queue.
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, None, None)
        # The output queue must not be the same object as the input queue.
        node1_queue = asyncio.Queue()
        with self.assertRaises(ValueError):
            await self.pipeline.add_node(Node1, 1, node1_queue, node1_queue)

    async def _test_pipeline(self, num_jobs, timeout=60.0):
        """Push ``num_jobs`` empty jobs through a three-stage pipeline and wait for them.

        Args:
            num_jobs: number of ``Job("")`` instances to enqueue.
            timeout: maximum number of seconds to wait for ``self.job_sync`` to
                reach ``num_jobs`` entries.

        Raises:
            AssertionError: via ``self.fail`` if the jobs do not all arrive
                within ``timeout`` seconds, instead of hanging indefinitely.
        """
        node1_queue = asyncio.Queue()
        node2_queue = asyncio.Queue()
        node3_queue = asyncio.Queue()
        await self.pipeline.add_node(Node1, 1, node1_queue, node2_queue)
        await self.pipeline.add_node(Node2, 5, node2_queue, node3_queue)
        await self.pipeline.add_node(
            Node3, 1, node3_queue, job_sync=self.job_sync, sequential_node=True
        )
        for _ in range(num_jobs):
            await self.pipeline.enqueue_job(Job(""))

        deadline = time.monotonic() + timeout
        # Use ``<`` rather than waiting for exact equality so that an unexpected
        # surplus of jobs ends the wait (the caller's length assertion then
        # reports it) instead of spinning forever.
        while len(self.job_sync) < num_jobs:
            if time.monotonic() >= deadline:
                # Fail fast. close() is deliberately skipped here: a stuck
                # pipeline might also block on close, and asyncio.run cancels
                # any remaining tasks when the coroutine exits.
                self.fail(
                    f"only {len(self.job_sync)} of {num_jobs} jobs finished "
                    f"within {timeout} seconds"
                )
            await asyncio.sleep(0.1)
        await self.pipeline.close()

    def test_pipeline_edge_cases(self):
        """``add_node`` must reject a missing input queue and input == output."""
        asyncio.run(self._test_pipeline_edge_cases())

    def test_pipeline_keeps_order(self):
        """Finished jobs appear in ``job_sync`` with ids 0..num_jobs-1 in order."""
        num_jobs = 100
        # perf_counter is monotonic and high-resolution, suited to elapsed time.
        start_time = time.perf_counter()
        asyncio.run(self._test_pipeline(num_jobs))
        elapsed = time.perf_counter() - start_time
        print(f"Pipeline processed in {elapsed} seconds.")
        self.assertEqual(len(self.job_sync), num_jobs)
        # NOTE(review): assumes Job ids are assigned 0, 1, 2, ... in enqueue
        # order for a fresh pipeline -- confirm against Job/Pipeline.
        for expected_id, job in enumerate(self.job_sync):
            self.assertEqual(expected_id, job.id)
if __name__ == '__main__':
    # Discover and run every TestCase in this module with the standard runner.
    unittest.main()