import multiprocessing
import numpy as np
import psutil
from typing import *


def parallel(
    function: Callable, n_jobs: int, x: List, *args
) -> Union[List, np.ndarray]:
    """Higher order function to run other functions on multiple processes.

    Simple parallelization utility: slices the input list x in chunks and
    executes the function on each chunk in a different process. Not suited
    for functions that already implement multithreading/multiprocessing.

    Args:
        function:   callable to run on different processes. It is called as
                    function(chunk, *args), where chunk is a list.
        n_jobs:     how many cores to use. Values <= 1 run function directly
                    in the current process.
        x:          list (M,) to use as input for function
        *args:      optional arguments for function, passed unchanged to
                    every call

    Returns:
        Object (M,) containing the output of function. Content and type
        depend on function. If function returns a list, then parallel will
        also return a list. If function returns a numpy array, then parallel
        will return an array (chunk outputs concatenated along axis 0).

    Raises:
        TypeError: if the function was run on several processes and its
            output is neither a list nor a numpy array, since the per-chunk
            outputs cannot be merged in that case.
    """
    # check that parallelization is required. n_jobs might be passed as 1 by
    # e.g. Dataset methods if they notice that the loaded HTS is too large
    # to be used on different cores.
    if n_jobs <= 1:
        return function(x, *args)

    # never start more workers than there are items: the surplus workers
    # would only receive empty chunks, which many functions cannot handle
    n_workers = min(n_jobs, len(x))
    if n_workers <= 1:
        return function(x, *args)

    # split list in chunks and pair each chunk with the optional arguments
    chunks = split_list(x, n_workers)
    job_args = stitch_args(chunks, args)

    # starmap blocks until every chunk is processed, so the pool can be
    # released as soon as the with-block exits (also when function raises)
    with multiprocessing.Pool(n_workers) as pool:
        output = pool.starmap(function, job_args)

    # unroll output (list of function outputs) into a single object
    # of size M
    if isinstance(output[0], list):
        return [item for chunk_output in output for item in chunk_output]
    if isinstance(output[0], np.ndarray):
        return np.concatenate(output, axis=0)
    raise TypeError(
        "parallel can only merge list or numpy.ndarray outputs, "
        f"but function returned {type(output[0]).__name__}"
    )


def stitch_args(chunks: List[List], args: Tuple) -> List[Tuple]:
    """
    Stitches together the chunks to be run in parallel and the optional
    function arguments into tuples.

    Args:
        chunks: list of chunks, one per process
        args:   extra arguments, repeated unchanged after every chunk

    Returns:
        One tuple (chunk, *args) per chunk, in the same order as chunks.
    """
    return [(chunk, *args) for chunk in chunks]


def split_list(x: List, n_jobs: int) -> List[List]:
    """
    Splits a list into n_jobs contiguous sublists of near-equal length.

    The split follows numpy.array_split: when len(x) is not divisible by
    n_jobs the first sublists are one element longer, and when n_jobs is
    larger than len(x) the trailing sublists are empty.

    Args:
        x:      list to split
        n_jobs: number of sublists to create

    Returns:
        List with n_jobs sublists that together contain the elements of x
        in their original order.
    """
    index_groups = np.array_split(np.arange(len(x)), n_jobs)
    return [[x[k] for k in group] for group in index_groups]