| | |
| | from multiprocessing import Pool |
| | from typing import Callable, Iterable, Sized |
| |
|
| | from rich.progress import (BarColumn, MofNCompleteColumn, Progress, Task, |
| | TaskProgressColumn, TextColumn, TimeRemainingColumn) |
| | from rich.text import Text |
| |
|
| |
|
| | class _Worker: |
| | """Function wrapper for ``track_progress_rich``""" |
| |
|
| | def __init__(self, func) -> None: |
| | self.func = func |
| |
|
| | def __call__(self, inputs): |
| | inputs, idx = inputs |
| | if not isinstance(inputs, (tuple, list)): |
| | inputs = (inputs, ) |
| |
|
| | return self.func(*inputs), idx |
| |
|
| |
|
class _SkipFirstTimeRemainingColumn(TimeRemainingColumn):
    """Time-remaining column that shows a placeholder for the first updates.

    While the task's completed count is not greater than ``skip_times`` the
    column renders ``-:--:--`` instead of delegating to
    :class:`TimeRemainingColumn`.

    Args:
        skip_times (int): The completed count up to which (inclusive) the
            placeholder is shown. Defaults to 0.
    """

    def __init__(self, *args, skip_times=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.skip_times = skip_times

    def render(self, task: Task) -> Text:
        """Render the remaining time, or a placeholder while skipping."""
        if task.completed > self.skip_times:
            return super().render(task)
        # Still within the skipped range: show a placeholder instead.
        return Text('-:--:--', style='progress.remaining')
| |
|
| |
|
| | def _tasks_with_index(tasks): |
| | """Add index to tasks.""" |
| | for idx, task in enumerate(tasks): |
| | yield task, idx |
| |
|
| |
|
def track_progress_rich(func: Callable,
                        tasks: Iterable = tuple(),
                        task_num: int = None,
                        nproc: int = 1,
                        chunksize: int = 1,
                        description: str = 'Processing',
                        color: str = 'blue') -> list:
    """Apply ``func`` to every task while displaying a rich progress bar.

    With ``nproc == 1`` the tasks are executed sequentially in the current
    process. Otherwise a :class:`multiprocessing.Pool` with ``nproc`` workers
    is created and the tasks are dispatched with
    :meth:`Pool.imap_unordered`; the results are put back into task order
    before they are returned.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Iterable or Sized): The tasks to process. There are several
            cases for different format tasks:

            - When ``func`` accepts no arguments: tasks should be an empty
              tuple, and ``task_num`` must be specified.
            - When ``func`` accepts only one argument: each element of
              tasks is that argument.
            - When ``func`` accepts multiple arguments: each element of
              tasks should be a ``tuple`` or ``list`` holding one set of
              positional arguments.

            Only ``tuple`` and ``list`` elements are unpacked. Any other
            element, including a ``dict``, is passed to ``func`` as a single
            positional argument. Defaults to an empty tuple.
        task_num (int, optional): If ``tasks`` is an iterator which does not
            have length, the number of tasks can be provided by ``task_num``.
            If ``tasks`` has a non-zero length, ``task_num`` must be omitted
            or equal to that length. Defaults to None.
        nproc (int): Process (worker) number. If ``nproc`` is 1, a single
            process is used. Defaults to 1.
        chunksize (int): Passed to :meth:`Pool.imap_unordered` when
            ``nproc`` is greater than 1. Defaults to 1.
        description (str): The description of progress bar.
            Defaults to "Processing".
        color (str): The color of progress bar. It is attached to the
            progress task as the ``color`` field. Defaults to "blue".

    Examples:
        >>> import time

        >>> def func(x):
        ...    time.sleep(1)
        ...    return x**2
        >>> track_progress_rich(func, range(10), nproc=2)

    Returns:
        list: The task results, in the same order as ``tasks``.

    Raises:
        TypeError: If ``func`` is not callable or ``tasks`` is not iterable.
        ValueError: If ``tasks`` is empty and ``task_num`` is not given, if
            ``task_num`` differs from ``len(tasks)``, or if ``nproc`` is not
            positive.
    """
    if not callable(func):
        raise TypeError('func must be a callable object')
    if not isinstance(tasks, Iterable):
        raise TypeError(
            f'tasks must be an iterable object, but got {type(tasks)}')
    if isinstance(tasks, Sized):
        if len(tasks) == 0:
            if task_num is None:
                raise ValueError('If tasks is an empty iterable, '
                                 'task_num must be set')
            # ``func`` takes no arguments: call it ``task_num`` times with
            # an empty argument pack.
            tasks = tuple(tuple() for _ in range(task_num))
        else:
            if task_num is not None and task_num != len(tasks):
                raise ValueError('task_num does not match the length of tasks')
            task_num = len(tasks)

    if nproc <= 0:
        raise ValueError('nproc must be a positive number')

    # With several workers, hold back the remaining-time estimate until more
    # than ``nproc * chunksize`` tasks have completed.
    skip_times = nproc * chunksize if nproc > 1 else 0
    prog_bar = Progress(
        TextColumn('{task.description}'),
        BarColumn(),
        _SkipFirstTimeRemainingColumn(skip_times=skip_times),
        MofNCompleteColumn(),
        TaskProgressColumn(show_speed=True),
    )

    worker = _Worker(func)
    # NOTE(review): ``color`` is only stored as a custom task field here; none
    # of the columns configured above appears to reference it - confirm
    # whether the bar is actually expected to be coloured.
    task_id = prog_bar.add_task(
        total=task_num, color=color, description=description)
    indexed_tasks = _tasks_with_index(tasks)

    with prog_bar:
        if nproc == 1:
            results = []
            for task in indexed_tasks:
                results.append(worker(task)[0])
                prog_bar.update(task_id, advance=1, refresh=True)
            return results

        # Results arrive in completion order; keep each one together with
        # its task index so the original order can be restored afterwards.
        indexed_results = []
        with Pool(nproc) as pool:
            try:
                for item in pool.imap_unordered(worker, indexed_tasks,
                                                chunksize):
                    indexed_results.append(item)
                    prog_bar.update(task_id, advance=1, refresh=True)
            except Exception:
                # Stop the live display before the error propagates; a bare
                # ``raise`` keeps the original traceback untouched.
                prog_bar.stop()
                raise

    results = [None] * len(indexed_results)
    for result, idx in indexed_results:
        results[idx] = result
    return results
| |
|