File size: 2,482 Bytes
f53b39e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utility functions for multiprocessing
# --------------------------------------------------------
from tqdm import tqdm
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing import cpu_count
def parallel_threads(function, args, workers=0, star_args=False, kw_args=False, front_num=1, Pool=ThreadPool, **tqdm_kw):
    """ Map `function` over `args` with a pool of workers and a tqdm progress bar.

    The result is essentially
      res = [ function(arg)    # default
              function(*arg)   # if star_args is True
              function(**arg)  # if kw_args is True
              for arg in args]

    Args:
        function: callable applied to every element of args.
        args: sized collection or plain iterable of inputs.
        workers: pool size. While the value is <= 0, cpu_count() is added to it
            (0 -> cpu_count(), -1 -> cpu_count() - 1, ...). With exactly one
            worker everything is executed sequentially and no pool is created.
        star_args: unpack each element as positional arguments.
        kw_args: unpack each element as keyword arguments (only looked at
            when star_args is false).
        front_num: number of leading elements executed directly, one after the
            other, before the pool is started.
        Pool: pool factory, used as `with Pool(workers) as pool`.
        **tqdm_kw: forwarded to tqdm for the parallel part.

    Returns:
        The list of results: sequential ones first, then the pooled ones.

    Note:
        the <front_num> first elements of args are not parallelized and are
        not covered by the progress bar. This can be useful for debugging.
    """
    # non-positive worker counts are offsets relative to the number of CPUs
    while workers <= 0:
        workers += cpu_count()
    if workers == 1:
        # a single worker: let the sequential loop below consume everything
        front_num = float('inf')

    # progress-bar total for the parallel part (unknown for unsized iterables)
    try:
        n_args_parallel = len(args) - front_num
    except TypeError:
        n_args_parallel = None
    args = iter(args)

    def call_inline(item):
        # same dispatch priority as the pooled path: star_args, then kw_args
        if star_args:
            return function(*item)
        if kw_args:
            return function(**item)
        return function(item)

    # sequential execution of the leading elements
    exhausted = object()
    head = []
    while len(head) < front_num:
        item = next(args, exhausted)
        if item is exhausted:
            return head  # nothing left for the pool
        head.append(call_inline(item))

    # parallel execution of whatever remains
    with Pool(workers) as pool:
        if star_args:
            target, tasks = starcall, [(function, item) for item in args]
        elif kw_args:
            target, tasks = starstarcall, [(function, item) for item in args]
        else:
            target, tasks = function, args
        # drain the results inside the `with`, while the pool is still alive
        tail = list(tqdm(pool.imap(target, tasks), total=n_args_parallel, **tqdm_kw))
    return head + tail
def parallel_processes(*args, **kwargs):
    """ Same as parallel_threads, but the pool is a multiprocessing.Pool.

    All positional and keyword arguments are forwarded unchanged, except
    that the `Pool` keyword is always replaced by multiprocessing.Pool.
    """
    from multiprocessing import Pool as ProcessPool
    kwargs.update(Pool=ProcessPool)
    return parallel_threads(*args, **kwargs)
def starcall(args):
    """ convenient wrapper for Process.Pool

    `args` is a (function, positional_arguments) pair; the function is
    called with the arguments unpacked and its result is returned.
    """
    function, positional = args
    return function(*positional)
def starstarcall(args):
    """ convenient wrapper for Process.Pool

    `args` is a (function, keyword_arguments) pair; the function is called
    with the mapping unpacked as keyword arguments and its result is returned.
    """
    function, keywords = args
    return function(**keywords)
|