project_charles / pipeline_test.py
sohojoe's picture
added pipeline test for a simple pipeline
80eea9e
raw
history blame
2.48 kB
import asyncio
import random
import time
class Job:
    """A unit of work passed between pipeline stages.

    Attributes:
        id: Identifier given by the creator of the job.
        data: Payload carried by the job.
    """

    # NOTE: the parameter name `id` shadows the builtin; it is kept because it
    # is part of the constructor's existing signature.
    def __init__(self, id, data):
        self.id, self.data = id, data
async def node1(worker_id: int, input_queue, output_queue):
    """First pipeline stage: tag each job and pass it on.

    Repeatedly takes a job from ``input_queue``, appends a node-1 marker
    (which includes ``worker_id``) to ``job.data`` and puts the job on
    ``output_queue``. The loop has no exit condition of its own.
    """
    # The marker only depends on worker_id, so build it once.
    marker = f' (processed by node 1, worker {worker_id})'
    while True:
        item = await input_queue.get()
        item.data += marker
        await output_queue.put(item)
async def node2(worker_id: int, input_queue, output_queue,
                min_delay: float = 0.8, max_delay: float = 1.2):
    """Second pipeline stage: simulate slow work, tag the job, pass it on.

    Repeatedly takes a job from ``input_queue``, sleeps for a random
    duration, appends a node-2 marker (which includes ``worker_id``) to
    ``job.data`` and puts the job on ``output_queue``. The loop has no exit
    condition of its own.

    Args:
        worker_id: Label embedded in the marker appended to ``job.data``.
        input_queue: Queue the jobs are awaited from.
        output_queue: Queue the processed jobs are put on.
        min_delay: Lower bound, in seconds, of the simulated work time.
        max_delay: Upper bound, in seconds, of the simulated work time.
            The defaults reproduce the original 0.8-1.2 second range.
    """
    while True:
        job = await input_queue.get()
        # Random per-job delay so that several workers finish out of order.
        await asyncio.sleep(random.uniform(min_delay, max_delay))
        job.data += f' (processed by node 2, worker {worker_id})'
        await output_queue.put(job)
async def node3(worker_id: int, input_queue, job_sync):
    """Final pipeline stage: put jobs back in id order and record them.

    Jobs may arrive in any order. Each arrival is parked in a dict keyed by
    ``job.id``; whenever the next expected id (starting at 0 and advancing by
    one) is present, that job is tagged with a node-3 marker, printed, and
    appended to ``job_sync``. A job whose predecessors never arrive stays
    parked. The loop has no exit condition of its own.
    """
    pending = {}      # jobs waiting for their turn, keyed by id
    expected_id = 0   # id of the next job allowed out
    while True:
        arrived = await input_queue.get()
        pending[arrived.id] = arrived
        # Release every job that is now contiguous with what was already emitted.
        while True:
            try:
                ready = pending.pop(expected_id)
            except KeyError:
                break
            ready.data += f' (processed by node 3, worker {worker_id})'
            print(f'{ready.id} - {ready.data}')
            expected_id += 1
            job_sync.append(ready)
async def main(num_jobs: int = 100, num_workers: int = 5):
    """Run the three-stage test pipeline until every job has been emitted.

    Wiring: ``input_queue`` -> node1 -> ``buffer_queue`` -> node2 (several
    concurrent workers) -> ``output_queue`` -> node3, which appends finished
    jobs to a shared list. The coroutine polls that list until it holds
    ``num_jobs`` entries, then cancels all worker tasks.

    Args:
        num_jobs: Number of jobs (ids ``0 .. num_jobs - 1``) to push through.
        num_workers: Number of concurrent node2 workers.

    Raises:
        asyncio.CancelledError: Re-raised after cleanup if this coroutine is
            cancelled while the pipeline is running.
        Exception: Whatever a worker task raised, if one dies before all jobs
            are done.
    """
    input_queue = asyncio.Queue()
    buffer_queue = asyncio.Queue()
    output_queue = asyncio.Queue()

    job_source = [Job(i, "") for i in range(num_jobs)]
    job_sync = []  # filled by node3, in job-id order

    workers = [
        asyncio.create_task(node1(None, input_queue, buffer_queue)),
        asyncio.create_task(node3(None, output_queue, job_sync)),
    ]
    workers.extend(
        asyncio.create_task(node2(i + 1, buffer_queue, output_queue))
        for i in range(num_workers)
    )

    try:
        for job in job_source:
            await input_queue.put(job)

        while len(job_sync) < num_jobs:
            # The stages loop forever, so a finished task means it crashed or
            # was cancelled; result() re-raises that instead of letting this
            # loop poll forever.
            for worker in workers:
                if worker.done():
                    worker.result()
            await asyncio.sleep(0.1)
    except asyncio.CancelledError:
        print("Pipeline cancelled")
        # Propagate so the caller (e.g. asyncio.run) sees the cancellation.
        raise
    finally:
        # Always stop the workers, whether we finished, failed or were cancelled.
        for worker in workers:
            worker.cancel()
        await asyncio.gather(*workers, return_exceptions=True)
# Run the pipeline once and report how long it took.
# perf_counter() is the clock intended for measuring elapsed intervals; unlike
# time.time() it is not affected by system clock adjustments.
start_time = time.perf_counter()
try:
    asyncio.run(main())
except KeyboardInterrupt:
    print("Pipeline interrupted by user")
end_time = time.perf_counter()
print(f"Pipeline processed in {end_time - start_time} seconds.")