Joshua Lochner committed on
Commit 07690ba
1 Parent(s): 63f1925

Improve parallel execution of functions

Files changed (1):
  src/utils.py +153 -65
src/utils.py CHANGED
@@ -1,92 +1,180 @@
  import re
- import asyncio
  import os


- class Job:
      def __init__(self, function, *args, **kwargs) -> None:
          self.function = function
          self.args = args
          self.kwargs = kwargs

-         self.result = None


- class InterruptibleThreadPool:
      def __init__(self,
                   num_workers=None,
-                  loop=None,
-                  shutdown_message='\nAttempting graceful shutdown, press Ctrl+C again to exit...',
-                  on_job_complete=None,  # Useful for monitoring progress
-                  raise_after_interrupt=False,
-                  ) -> None:
          self.num_workers = os.cpu_count() if num_workers is None else num_workers
-         self.loop = asyncio.get_event_loop() if loop is None else loop
-         self.shutdown_message = shutdown_message
-
-         self.sem = asyncio.Semaphore(num_workers)
-
-         self.jobs = []
-
-         self.on_job_complete = on_job_complete
-         self.raise_after_interrupt = raise_after_interrupt
-
-     async def _sync_to_async(self, job):
-         async with self.sem:  # Limit number of parallel tasks
-             job.result = await self.loop.run_in_executor(None, job.function, *job.args, **job.kwargs)
-
-             if callable(self.on_job_complete):
-                 self.on_job_complete(job)
-
-         return job
-
-     def add_job(self, job):
-         self.jobs.append(job)
-
-     def run(self):
-         try:
-             tasks = [
-                 # creating task starts coroutine
-                 asyncio.ensure_future(self._sync_to_async(job))
-                 for job in self.jobs
-             ]
-
-             # https://stackoverflow.com/a/42097478
-             self.loop.run_until_complete(
-                 asyncio.gather(*tasks, return_exceptions=True)
-             )
-
-         except KeyboardInterrupt:
-             # Optionally show a message if the shutdown may take a while
-             print(self.shutdown_message, flush=True)
-
-             # Do not show `asyncio.CancelledError` exceptions during shutdown
-             # (a lot of these may be generated, skip this if you prefer to see them)
-             def shutdown_exception_handler(loop, context):
-                 if "exception" not in context \
-                         or not isinstance(context["exception"], asyncio.CancelledError):
-                     loop.default_exception_handler(context)
-             self.loop.set_exception_handler(shutdown_exception_handler)
-
-             # Handle shutdown gracefully by waiting for all tasks to be cancelled
-             cancelled_tasks = asyncio.gather(
-                 *asyncio.all_tasks(loop=self.loop), loop=self.loop, return_exceptions=True)
-             cancelled_tasks.add_done_callback(lambda t: self.loop.stop())
-             cancelled_tasks.cancel()
-
-             # Keep the event loop running until it is either destroyed or all
-             # tasks have really terminated
-             while not cancelled_tasks.done() and not self.loop.is_closed():
-                 self.loop.run_forever()
-
-             if self.raise_after_interrupt:
-                 raise
-         finally:
-             self.loop.run_until_complete(self.loop.shutdown_asyncgens())
-             self.loop.close()
-
-         return self.jobs


- def re_findall(pattern, string):
-     return [m.groupdict() for m in re.finditer(pattern, string)]

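For reference, the removed asyncio-based pool above was driven by queuing Job objects and calling run(), which blocked until every job finished. A minimal sketch of that old usage (the `square` function and its arguments are hypothetical, not part of this commit):

    def square(n):
        return n * n

    pool = InterruptibleThreadPool(num_workers=4)
    for i in range(10):
        pool.add_job(Job(square, i))  # Queue work; nothing runs until run()
    jobs = pool.run()                 # Blocks until done; Ctrl+C shuts down gracefully
    results = [job.result for job in jobs]

The new multiprocessing-based version of src/utils.py follows.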
  import re
+
  import os
+ import signal
+ import logging
+ import sys
+ from time import sleep, time
+ from random import random, randint
+ from multiprocessing import JoinableQueue, Event, Process
+ from queue import Empty
+ from typing import Optional
+
+ logger = logging.getLogger(__name__)
+
+ def re_findall(pattern, string):
+     return [m.groupdict() for m in re.finditer(pattern, string)]


+ class Task:
      def __init__(self, function, *args, **kwargs) -> None:
          self.function = function
          self.args = args
          self.kwargs = kwargs

+     def run(self):
+         return self.function(*self.args, **self.kwargs)
+
+
+ class CallbackGenerator:
+     def __init__(self, generator, callback):
+         self.generator = generator
+         self.callback = callback
+
+     def __iter__(self):
+         if self.callback is not None and callable(self.callback):
+             for t in self.generator:
+                 self.callback(t)
+                 yield t
+         else:
+             yield from self.generator
+
+ def start_worker(q: JoinableQueue, stop_event: Event):  # TODO make class?
+     logger.info('Starting worker...')
+     while True:
+         if stop_event.is_set():
+             logger.info('Worker exiting because of stop_event')
+             break
+         # We set a timeout so we loop past 'stop_event' even if the queue is empty
+         try:
+             task = q.get(timeout=.01)
+         except Empty:
+             # Run next iteration of loop
+             continue
+
+         # Exit if end of queue
+         if task is None:
+             logger.info('Worker exiting because of None on queue')
+             q.task_done()
+             break
+
+         try:
+             task.run()  # Do the task
+         except:  # Will also catch KeyboardInterrupt
+             logger.exception(f'Failed to process task {task}')
+             # Can implement some kind of retry handling here
+         finally:
+             q.task_done()

+ class InterruptibleTaskPool:
+
+     # https://the-fonz.gitlab.io/posts/python-multiprocessing/
      def __init__(self,
+                  tasks=None,
                   num_workers=None,
+                  callback=None,  # Fired as each task is queued
+                  max_queue_size=1,
+                  grace_period=2,
+                  kill_period=30,
+                  ):
+         # num_workers: Start this many processes
+         # max_queue_size: If the queue exceeds this size, block when putting items on the queue
+         # grace_period: Send SIGINT to processes if they don't exit within this time after SIGINT/SIGTERM
+         # kill_period: Send SIGKILL to processes if they don't exit after this many seconds
+
+         self.tasks = CallbackGenerator(
+             [] if tasks is None else tasks, callback)
          self.num_workers = os.cpu_count() if num_workers is None else num_workers
+         self.max_queue_size = max_queue_size
+         self.grace_period = grace_period
+         self.kill_period = kill_period
+
+         # The JoinableQueue has an internal counter that increments when an item is put on the
+         # queue and decrements when q.task_done() is called. This allows us to wait until it's
+         # empty using .join()
+         self.queue = JoinableQueue(maxsize=self.max_queue_size)
+         # This is a process-safe version of the 'panic' variable from the linked post
+         self.stop_event = Event()
+
+         # self.on_task_complete = on_task_complete
+         # self.raise_after_interrupt = raise_after_interrupt

+     def __enter__(self):
+         self.start()
+         return self
+
+     def __exit__(self, exc_type, exc_value, exc_traceback):
+         pass

+     def start(self) -> None:
+         def handler(signalname):
+             """
+             Python 3.9 has `signal.strsignal(signalnum)` so this closure would not be needed.
+             Also, 3.8 includes `signal.valid_signals()` that can be used to create a mapping for the same purpose.
+             """
+             def f(signal_received, frame):
+                 raise KeyboardInterrupt(f'{signalname} received')
+             return f
+
+         # This will be inherited by the child process if it is forked (not spawned)
+         signal.signal(signal.SIGINT, handler('SIGINT'))
+         signal.signal(signal.SIGTERM, handler('SIGTERM'))
+
+         procs = []
+
+         for i in range(self.num_workers):
+             # Make it a daemon process so it is definitely terminated when this process exits;
+             # might be overkill, but it is a nice feature. See
+             # https://docs.python.org/3.8/library/multiprocessing.html#multiprocessing.Process.daemon
+             p = Process(name=f'Worker-{i:02d}', daemon=True,
+                         target=start_worker, args=(self.queue, self.stop_event))
+             procs.append(p)
+             p.start()
+
+         try:
+             # Put tasks on queue
+             for task in self.tasks:
+                 logger.info(f'Put task {task} on queue')
+                 self.queue.put(task)
+
+             # Put exit tasks on queue
+             for i in range(self.num_workers):
+                 self.queue.put(None)
+
+             # Wait until all tasks are processed
+             self.queue.join()
+
+         except KeyboardInterrupt:
+             logger.warning('Caught KeyboardInterrupt! Setting stop event...')
+             # raise  # TODO add option
+         finally:
+             self.stop_event.set()
+             t = time()
+             # Send SIGINT if a process doesn't exit quickly enough, and kill it as a last resort.
+             # The kill deadline must be checked first: kill_period > grace_period, so testing the
+             # grace deadline first would make the SIGKILL branch unreachable.
+             # .is_alive() also implicitly joins the process (good practice in linux)
+             while alive_procs := [p for p in procs if p.is_alive()]:
+                 if time() > t + self.kill_period:
+                     for p in alive_procs:
+                         logger.warning(f'Sending SIGKILL to {p}')
+                         # Queues and other inter-process communication primitives can break
+                         # when a process is killed, but we don't care here
+                         p.kill()
+                 elif time() > t + self.grace_period:
+                     for p in alive_procs:
+                         os.kill(p.pid, signal.SIGINT)
+                         logger.warning(f'Sending SIGINT to {p}')
+                 sleep(.01)
+
+             sleep(.1)
+             for p in procs:
+                 logger.info(f'Process status: {p}')
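
The new pool consumes an iterable of Task objects, runs them in worker processes, and installs SIGINT/SIGTERM handlers so shutdown escalates from a stop event to SIGINT and finally SIGKILL. A minimal usage sketch (`square` is again hypothetical; a task's return value is discarded by the worker, so results must be communicated out-of-band, e.g. via an extra queue or files):

    def square(n):
        print(n * n)

    tasks = (Task(square, i) for i in range(10))
    InterruptibleTaskPool(tasks, num_workers=4,
                          callback=lambda t: logger.info(f'Queued {t}')).start()

Since __enter__ calls start(), `with InterruptibleTaskPool(tasks, num_workers=4): pass` is equivalent.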