add: pipeline nodes can now spawn one-to-many jobs via yield
- pipeline.py +12 -12
- tests/test_pipeline.py +14 -11
pipeline.py
CHANGED
@@ -27,24 +27,24 @@ class Node:
             job: Job = await self.input_queue.get()
             self._jobs_dequeued += 1
             if self.sequential_node == False:
-                job = await self.process_job(job)
-                if self.output_queue is not None:
-                    await self.output_queue.put(job)
-                if self.job_sync is not None:
-                    self.job_sync.append(job)
-                self._jobs_processed += 1
+                async for job in self.process_job(job):
+                    if self.output_queue is not None:
+                        await self.output_queue.put(job)
+                    if self.job_sync is not None:
+                        self.job_sync.append(job)
+                    self._jobs_processed += 1
             else:
                 # ensure that jobs are processed in order
                 self.buffer[job.id] = job
                 while self.next_i in self.buffer:
                     job = self.buffer.pop(self.next_i)
-                    job = await self.process_job(job)
+                    async for job in self.process_job(job):
+                        if self.output_queue is not None:
+                            await self.output_queue.put(job)
+                        if self.job_sync is not None:
+                            self.job_sync.append(job)
+                        self._jobs_processed += 1
                     self.next_i += 1
-                    if self.output_queue is not None:
-                        await self.output_queue.put(job)
-                    if self.job_sync is not None:
-                        self.job_sync.append(job)
-                    self._jobs_processed += 1
 
     async def process_job(self, job: Job):
         raise NotImplementedError()
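The change above turns Node.process_job into an async generator: the dispatch loop now drains it with async for, so one dequeued job can fan out into zero, one, or many downstream jobs. A minimal standalone sketch of that pattern follows; it is not code from this repo, and FanOutNode and the string "jobs" are made up for illustration.

# Standalone sketch, not code from this repo: a node-like class whose
# process_job is an async generator, drained with `async for` exactly like
# the new dispatch loop above. FanOutNode and the string jobs are invented.
import asyncio


class FanOutNode:
    async def process_job(self, job: str):
        # One input job can now yield many output jobs.
        for i, word in enumerate(job.split()):
            await asyncio.sleep(0)  # stand-in for real async work
            yield f'{word} (part {i})'


async def main():
    node = FanOutNode()
    async for out in node.process_job('one input many outputs'):
        print(out)  # each yielded job would go onto the output queue


asyncio.run(main())

A node that yields nothing acts as a filter, and a node that yields exactly once behaves like the old return-based version.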
tests/test_pipeline.py
CHANGED
@@ -12,6 +12,7 @@ from pipeline import Pipeline, Node, Job
 class Node1(Node):
     async def process_job(self, job: Job):
         job.data += f' (processed by node 1, worker {self.worker_id})'
+        yield job
 
 
 class Node2(Node):
@@ -19,12 +20,14 @@ class Node2(Node):
         sleep_duration = 0.08 + 0.04 * random.random()
         await asyncio.sleep(sleep_duration)
         job.data += f' (processed by node 2, worker {self.worker_id})'
+        yield job
 
 
 class Node3(Node):
     async def process_job(self, job: Job):
         job.data += f' (processed by node 3, worker {self.worker_id})'
         print(f'{job.id} - {job.data}')
+        yield job
 
 
 class TestPipeline(unittest.TestCase):
@@ -63,17 +66,17 @@ class TestPipeline(unittest.TestCase):
         asyncio.run(self._test_pipeline_edge_cases())
 
 
-
-
-
-
-
-
-
-
-
-
-
+    def test_pipeline_keeps_order(self):
+        self.pipeline = Pipeline()
+        self.job_sync = []
+        num_jobs = 100
+        start_time = time.time()
+        asyncio.run(self._test_pipeline(num_jobs))
+        end_time = time.time()
+        print(f"Pipeline processed in {end_time - start_time} seconds.")
+        self.assertEqual(len(self.job_sync), num_jobs)
+        for i, job in enumerate(self.job_sync):
+            self.assertEqual(i, job.id)
 
 
 if __name__ == '__main__':
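Why every test node gained a `yield job` line: with the new async-for dispatch, process_job must be an async generator. An async def without a yield is a plain coroutine and cannot be iterated, so a subclass that forgot the yield would fail at dispatch time. A standalone sketch of the difference follows; WithYield and WithoutYield are made-up names, not repo code.

# Standalone sketch, not code from this repo: an async def containing `yield`
# is an async generator and supports `async for`; one without `yield` is a
# coroutine and raises TypeError when iterated.
import asyncio


class WithYield:
    async def process_job(self, job):
        yield job + ' (processed)'


class WithoutYield:
    async def process_job(self, job):
        return job + ' (processed)'


async def main():
    async for out in WithYield().process_job('job-0'):
        print('ok:', out)

    coro = WithoutYield().process_job('job-0')
    try:
        async for out in coro:
            print(out)
    except TypeError as err:
        print('error:', err)  # 'async for' needs __aiter__, which a coroutine lacks
    finally:
        coro.close()  # close the never-started coroutine so it is not left dangling


asyncio.run(main())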