import multiprocessing
from typing import Callable, List, Tuple, Union

import numpy as np


def parallel(
    function: Callable, n_jobs: int, x: List, *args
) -> Union[List, np.ndarray]:
    """Higher order function to run other functions on multiple processes

    Simple parallelization utility, slices the input list x in chunks and
    executes the function on each chunk in different processes. Not suited
    for functions that have already multithreading/processing implemented.

    Args:
        function:   callable to run on different processes
        n_jobs:     how many cores to use
        x:          list (M,) to use as input for function
        *args:      optional arguments for function

    Returns:
        Object (M,) containing the output of function. Content and type depend
        on function. If function returns list, then parallel will also return
        a list. If function returns a numpy array, then parallel will return an
        array.
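
    Example:
        A minimal sketch; `double` is a hypothetical worker (it must be
        defined at module level so it can be pickled):

            def double(chunk):
                return [2 * v for v in chunk]

            parallel(double, 2, [1, 2, 3, 4])  # -> [2, 4, 6, 8]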
    """

    # check whether parallelization is actually required. n_jobs might be
    # passed as 1, e.g. by Dataset methods, if they notice that the loaded
    # HTS is too large to be used on different cores.
    if n_jobs > 1:
        # split list into chunks
        chunks = split_list(x, n_jobs)

        # create list of tuples containing the chunks and *args
        job_args = stitch_args(chunks, args)

        # create multiprocessing pool and run function on the chunks; the
        # context manager guarantees the workers are cleaned up afterwards
        with multiprocessing.Pool(n_jobs) as pool:
            output = pool.starmap(function, job_args)

        # unroll output (list of per-chunk function outputs) into a single
        # object of size M
        if isinstance(output[0], list):
            unrolled = [item for chunk in output for item in chunk]
        elif isinstance(output[0], np.ndarray):
            unrolled = np.concatenate(output, axis=0)
        else:
            # fail explicitly on unsupported return types
            raise TypeError(
                f"function must return a list or a numpy array, "
                f"got {type(output[0]).__name__}"
            )

    else:
        # run function sequentially on the full input
        unrolled = function(x, *args)

    return unrolled


def stitch_args(chunks: List[List], args: Tuple) -> List[Tuple]:
    """
    Stitches each chunk together with the optional function arguments into
    tuples, producing the argument list expected by Pool.starmap.
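
    Example:
        >>> stitch_args([[1, 2], [3, 4]], ("a",))
        [([1, 2], 'a'), ([3, 4], 'a')]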
    """
    return [(chunk, *args) for chunk in chunks]


def split_list(x: List, n_jobs: int) -> List[List]:
    """
    Splits a list into n_jobs sublists of near-equal size.
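
    Example:
        >>> split_list([1, 2, 3, 4, 5], 2)
        [[1, 2, 3], [4, 5]]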
    """
    idxs = np.array_split(range(len(x)), n_jobs)
    return [[x[k] for k in chunk] for chunk in idxs]
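

# Minimal usage sketch (an assumption, not part of the original module):
# doubles each element of a list across two worker processes. The __main__
# guard is required so that worker processes do not re-run the demo when
# they import this module.
def _double(chunk: List[int]) -> List[int]:
    # hypothetical worker; defined at module level so it can be pickled
    return [2 * v for v in chunk]


if __name__ == "__main__":
    print(parallel(_double, 2, list(range(8))))  # [0, 2, 4, ..., 14]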