Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,558 Bytes
84bfd88 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import multiprocessing
import numpy as np
import psutil
from typing import *
def parallel(function: Callable, n_jobs: int, x: List, *args) -> List:
    """Higher order function to run other functions on multiple processes

    Simple parallelization utility, slices the input list x in chunks and
    executes the function on each chunk in different processes. Not suited
    for functions that have already multithreading/processing implemented.

    Args:
        function:   callable to run on different processes. It is invoked as
                    function(chunk, *args); when more than one process is
                    used it must be picklable by multiprocessing
        n_jobs:     how many cores to use. Values <= 1 run function in the
                    current process
        x:          list (M,) to use as input for function
        *args:      optional arguments for function

    Returns:
        Object (M,) containing the output of function. Content and type depend
        on function. If function returns list, then parallel will also return
        a list. If function returns a numpy array, then parallel will return an
        array.

    Raises:
        TypeError:  if function was run on several processes and its output
                    is neither a list nor a numpy array, since the per-chunk
                    results cannot be merged
    """
    # Check that parallelization is required. Callers may deliberately pass
    # n_jobs=1 (e.g. when the input is too large to be shared across cores).
    # Never use more workers than there are items: the surplus processes
    # would only receive empty chunks. len(x) is only evaluated when a
    # parallel run was requested, so the sequential path puts no extra
    # requirement on x.
    workers = min(n_jobs, len(x)) if n_jobs > 1 else 1

    if workers <= 1:
        # run function normally
        return function(x, *args)

    # split list in chunks and pair every chunk with the optional arguments
    chunks = split_list(x, workers)
    job_args = stitch_args(chunks, args)

    # The context manager tears the pool down even if function raises in a
    # worker, so no processes are leaked.
    with multiprocessing.Pool(workers) as pool:
        output = pool.starmap(function, job_args)

    # unroll output (list of function outputs) into a single object of size M
    first = output[0]
    if isinstance(first, list):
        return [item for part in output for item in part]
    if isinstance(first, np.ndarray):
        return np.concatenate(output, axis=0)

    raise TypeError(
        "parallel can only merge list or numpy.ndarray outputs, got "
        f"{type(first).__name__}"
    )
def stitch_args(chunks: List[List], args: Tuple) -> List[Tuple]:
    """
    Builds one argument tuple per chunk: the chunk itself comes first,
    followed by every element of args in its original order. The result
    has the same length and order as chunks.
    """
    shared = tuple(args)
    return [(chunk,) + shared for chunk in chunks]
def split_list(x: List, n_jobs: int) -> List[List]:
    """
    Partitions x into n_jobs consecutive sub-lists, preserving order.
    The positions of x are divided with np.array_split, so the sub-lists
    may differ in length and some are empty when n_jobs exceeds len(x).
    """
    index_groups = np.array_split(range(len(x)), n_jobs)
    return [[x[pos] for pos in group] for group in index_groups]
|