import os
import os.path as osp
from multiprocessing import Pool
from typing import Callable, Iterable, Optional, Sequence, Sized

import portalocker
from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task,
                           TaskProgressColumn, TextColumn, TimeRemainingColumn)
from rich.text import Text

from ..smp import load, dump


class _Worker:
    """Function wrapper for ``track_progress_rich``"""

    def __init__(self, func) -> None:
        self.func = func

    def __call__(self, inputs):
        inputs, idx = inputs
        if not isinstance(inputs, (tuple, list, dict)):
            inputs = (inputs, )

        if isinstance(inputs, dict):
            return self.func(**inputs), idx
        else:
            return self.func(*inputs), idx


class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
    """Skip calculating remaining time for the first few times.

    Args:
        skip_times (int): The number of times to skip. Defaults to 0.
    """

    def __init__(self, *args, skip_times=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_times = skip_times

    def render(self, task: Task) -> Text:
        """Show time remaining."""
        if task.completed <= self.skip_times:
            return Text('-:--:--', style='progress.remaining')
        return super().render(task)


def _tasks_with_index(tasks):
    """Add index to tasks."""
    for idx, task in enumerate(tasks):
        yield task, idx


def track_progress_rich(func: Callable,
                        tasks: Iterable = tuple(),
                        task_num: Optional[int] = None,
                        nproc: int = 1,
                        chunksize: int = 1,
                        description: str = 'Processing',
                        save: Optional[str] = None,
                        keys: Optional[Sequence] = None,
                        color: str = 'blue') -> list:
    """Track the progress of parallel task execution with a progress bar. The
    built-in :mod:`multiprocessing` module is used for process pools and tasks
    are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Iterable or Sized): A tuple of tasks. The expected format
            depends on the signature of ``func``:
            - When ``func`` accepts no arguments: tasks should be an empty
              tuple, and ``task_num`` must be specified.
            - When ``func`` accepts only one argument: tasks should be a tuple
              containing the argument.
            - When ``func`` accepts multiple arguments: tasks should be a
              tuple, with each element representing a set of arguments.
              If an element is a ``dict``, it will be parsed as a set of
              keyword arguments (see the second example below).
            Defaults to an empty tuple.
        task_num (int, optional): If ``tasks`` is an iterator which does not
            have length, the number of tasks can be provided by ``task_num``.
            Defaults to None.
        nproc (int): Number of worker processes. If ``nproc`` is 1, tasks are
            executed sequentially in the current process. Defaults to 1.
        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
            Defaults to 1.
        description (str): The description of the progress bar.
            Defaults to "Processing".
        save (str, optional): If given, the path of a file into which each
            result is dumped as soon as it is available, keyed by the
            corresponding entry of ``keys``. The file is created if it does
            not exist, and ``keys`` must be provided as well. Defaults to
            None.
        keys (Sequence, optional): Keys under which results are stored in
            ``save``; must have the same length as ``tasks``. Defaults to
            None.
        color (str): The color of the progress bar. Defaults to "blue".

    Examples:
        >>> import time

        >>> def func(x):
        ...    time.sleep(1)
        ...    return x**2
        >>> track_progress_rich(func, range(10), nproc=2)
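
        A hypothetical sketch combining keyword-style tasks with the optional
        ``save`` / ``keys`` bookkeeping. The function ``add``, the file name
        ``'results.json'`` and the key names are illustrative placeholders,
        not part of this module:

        >>> # illustrative only: the file name and keys below are placeholders
        >>> def add(a, b):
        ...    return a + b
        >>> results = track_progress_rich(
        ...    add, [dict(a=1, b=2), dict(a=3, b=4)], nproc=2,
        ...    save='results.json', keys=['first', 'second'])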

    Returns:
        list: The task results.
    """
    if save is not None:
        assert osp.exists(osp.dirname(save)) or osp.dirname(save) == ''
        if not osp.exists(save):
            dump({}, save)
    if keys is not None:
        assert len(keys) == len(tasks)

    if not callable(func):
        raise TypeError('func must be a callable object')
    if not isinstance(tasks, Iterable):
        raise TypeError(
            f'tasks must be an iterable object, but got {type(tasks)}')
    if isinstance(tasks, Sized):
        if len(tasks) == 0:
            if task_num is None:
                raise ValueError('If tasks is an empty iterable, '
                                 'task_num must be set')
            else:
                tasks = tuple(tuple() for _ in range(task_num))
        else:
            if task_num is not None and task_num != len(tasks):
                raise ValueError('task_num does not match the length of tasks')
            task_num = len(tasks)

    if nproc <= 0:
        raise ValueError('nproc must be a positive number')

    skip_times = nproc * chunksize if nproc > 1 else 0
    prog_bar = Progress(
        TextColumn('{task.description}'),
        BarColumn(),
        _SkipFirstTimeRemainingColumn(skip_times=skip_times),
        MofNCompleteColumn(),
        TaskProgressColumn(show_speed=True),
    )

    worker = _Worker(func)
    task_id = prog_bar.add_task(
        total=task_num, color=color, description=description)
    tasks = _tasks_with_index(tasks)

    # Run in the current process when nproc is 1, else use a process pool.
    with prog_bar:
        if nproc == 1:
            results = []
            for task in tasks:
                result, idx = worker(task)
                results.append(result)
                if save is not None:
                    with portalocker.Lock(save, timeout=5) as fh:
                        ans = load(save)
                        ans[keys[idx]] = result

                        if os.environ.get('VERBOSE', True):
                            print(keys[idx], result, flush=True)

                        dump(ans, save)
                        fh.flush()
                        os.fsync(fh.fileno())

                prog_bar.update(task_id, advance=1, refresh=True)
        else:
            with Pool(nproc) as pool:
                results = []
                unordered_results = []
                gen = pool.imap_unordered(worker, tasks, chunksize)
                try:
                    for result in gen:
                        result, idx = result
                        unordered_results.append((result, idx))

                        if save is not None:
                            with portalocker.Lock(save, timeout=5) as fh:
                                ans = load(save)
                                ans[keys[idx]] = result

                                if os.environ.get('VERBOSE', False):
                                    print(keys[idx], result, flush=True)

                                dump(ans, save)
                                fh.flush()
                                os.fsync(fh.fileno())

                        # Placeholder so that ``results`` has full length;
                        # values are written back in task order afterwards.
                        results.append(None)
                        prog_bar.update(task_id, advance=1, refresh=True)
                except Exception as e:
                    prog_bar.stop()
                    raise e
            # Write results back in the original task order.
            for result, idx in unordered_results:
                results[idx] = result
    return results