Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import sys | |
| from collections.abc import Iterable | |
| from multiprocessing import Pool | |
| from shutil import get_terminal_size | |
| from typing import Callable, Sequence | |
| from .timer import Timer | |
class ProgressBar:
    """A progress bar which can print the progress.

    Args:
        task_num (int): Number of total steps. When it is not positive the
            bar only reports the completed count, elapsed time and speed.
            Defaults to 0.
        bar_width (int): Maximum width of the progress bar; it is shrunk to
            fit the terminal. Defaults to 50.
        start (bool): Whether to start the progress bar in the constructor.
            Defaults to True.
        file (file-like object): Stream the progress is written to; it must
            provide ``write`` and ``flush``. Defaults to ``sys.stdout``.

    Examples:
        >>> import mmengine
        >>> import time
        >>> bar = mmengine.ProgressBar(10)
        >>> for i in range(10):
        >>>     bar.update()
        >>>     time.sleep(1)
    """

    def __init__(self,
                 task_num: int = 0,
                 bar_width: int = 50,
                 start: bool = True,
                 file=sys.stdout):
        self.task_num = task_num
        self.bar_width = bar_width
        self.completed = 0
        self.file = file
        if start:
            self.start()

    @property
    def terminal_width(self) -> int:
        """int: Current width (in columns) of the terminal."""
        width, _ = get_terminal_size()
        return width

    def start(self):
        """Print the initial (empty) progress line and start the timer.

        ``update`` relies on the ``timer`` attribute created here, so this
        must run (via the constructor or explicitly) before ``update``.
        """
        if self.task_num > 0:
            self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
                            'elapsed: 0s, ETA:')
        else:
            self.file.write('completed: 0, elapsed: 0s')
        self.file.flush()
        self.timer = Timer()

    def update(self, num_tasks: int = 1):
        """Advance the progress bar and redraw it in place.

        Args:
            num_tasks (int): Update step size. Must be positive.
                Defaults to 1.
        """
        assert num_tasks > 0
        self.completed += num_tasks
        elapsed = self.timer.since_start()
        fps = self.completed / elapsed if elapsed > 0 else float('inf')
        if self.task_num > 0:
            # Clamp so that over-completion (completed > task_num) neither
            # overflows the bar nor produces a negative ETA.
            percentage = min(self.completed / float(self.task_num), 1.0)
            eta = int(elapsed * (1 - percentage) / percentage + 0.5)
            # '{{}}' leaves a '{}' placeholder that is filled with the bar.
            msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
                  f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
                  f'ETA: {eta:5}s'
            # Shrink the bar so the whole line fits in the terminal; the +2
            # compensates for the two placeholder characters counted in msg.
            terminal_width = self.terminal_width
            bar_width = min(self.bar_width,
                            int(terminal_width - len(msg)) + 2,
                            int(terminal_width * 0.6))
            bar_width = max(2, bar_width)
            mark_width = int(bar_width * percentage)
            bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
            self.file.write(msg.format(bar_chars))
        else:
            # '\r' returns to the start of the line so the counter is redrawn
            # in place instead of being appended to the previous update.
            self.file.write(
                f'\rcompleted: {self.completed}, '
                f'elapsed: {int(elapsed + 0.5)}s,'
                f' {fps:.1f} tasks/s')
        self.file.flush()
def track_progress(func: Callable,
                   tasks: Sequence,
                   bar_width: int = 50,
                   file=sys.stdout,
                   **kwargs):
    """Apply ``func`` to every task one after another, showing progress.

    All work happens in the calling process with a plain loop.

    Args:
        func (callable): Called as ``func(task, **kwargs)`` for each task.
        tasks (Sequence): Either a 2-tuple ``(iterable, count)`` where
            ``count`` is the number of tasks the iterable yields, or any
            other sequence of tasks, whose length is then used as the count.
        bar_width (int): Width of progress bar.
        file: Stream the progress bar writes to. Defaults to ``sys.stdout``.
        **kwargs: Extra keyword arguments forwarded to every ``func`` call.

    Returns:
        list: The values returned by ``func``, in task order.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        task_iter, task_num = tasks
        assert isinstance(task_iter, Iterable)
        assert isinstance(task_num, int)
    elif isinstance(tasks, Sequence):
        task_iter, task_num = tasks, len(tasks)
    else:
        raise TypeError(
            '"tasks" must be a tuple object or a sequence object, but got '
            f'{type(tasks)}')

    prog_bar = ProgressBar(task_num, bar_width, file=file)
    collected = []
    for item in task_iter:
        outcome = func(item, **kwargs)
        collected.append(outcome)
        prog_bar.update()
    # Terminate the progress line so later output starts on a fresh line.
    prog_bar.file.write('\n')
    return collected
def init_pool(process_num, initializer=None, initargs=None):
    """Create a :class:`multiprocessing.Pool` with optional initializer.

    Args:
        process_num (int): Number of worker processes, passed to ``Pool``.
        initializer (callable, optional): Passed to ``Pool`` as its
            ``initializer``.
        initargs (tuple, optional): Passed to ``Pool`` as its ``initargs``;
            only consulted when ``initializer`` is given.

    Returns:
        multiprocessing.pool.Pool: The newly created pool.

    Raises:
        TypeError: If ``initializer`` is given and ``initargs`` is neither
            ``None`` nor a tuple.
    """
    if initializer is None:
        return Pool(process_num)
    if initargs is None:
        return Pool(process_num, initializer)
    if isinstance(initargs, tuple):
        return Pool(process_num, initializer, initargs)
    raise TypeError('"initargs" must be a tuple')
def track_parallel_progress(func: Callable,
                            tasks: Sequence,
                            nproc: int,
                            initializer: Callable = None,
                            initargs: tuple = None,
                            bar_width: int = 50,
                            chunksize: int = 1,
                            skip_first: bool = False,
                            keep_order: bool = True,
                            file=sys.stdout):
    """Track the progress of parallel task execution with a progress bar.

    The built-in :mod:`multiprocessing` module is used for process pools and
    tasks are done with :func:`Pool.imap` or :func:`Pool.imap_unordered`.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Sequence): If tasks is a tuple, it must contain two elements,
            the first being the tasks to be completed and the other being the
            number of tasks. If it is not a tuple, it represents the tasks to
            be completed.
        nproc (int): Process (worker) number.
        initializer (None or callable): Refer to :class:`multiprocessing.Pool`
            for details.
        initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
            details.
        bar_width (int): Width of progress bar.
        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
        skip_first (bool): Whether to leave the first ``nproc * chunksize``
            results out of the bar (and its fps estimate), since the
            initialization step may take longer.
        keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
            :func:`Pool.imap_unordered` is used.
        file: Stream the progress bar writes to. Defaults to ``sys.stdout``.

    Returns:
        list: The task results.

    Raises:
        TypeError: If ``tasks`` is neither a tuple nor a sequence.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        task_num = tasks[1]
        tasks = tasks[0]  # type: ignore
    elif isinstance(tasks, Sequence):
        task_num = len(tasks)
    else:
        raise TypeError(
            '"tasks" must be a tuple object or a sequence object, but got '
            f'{type(tasks)}')
    pool = init_pool(nproc, initializer, initargs)
    try:
        start = not skip_first
        # Results excluded from the bar when ``skip_first`` is set; the
        # expected total shrinks accordingly.
        skipped = nproc * chunksize if skip_first else 0
        task_num -= skipped
        prog_bar = ProgressBar(task_num, bar_width, start, file=file)
        results = []
        if keep_order:
            gen = pool.imap(func, tasks, chunksize)
        else:
            gen = pool.imap_unordered(func, tasks, chunksize)
        for result in gen:
            results.append(result)
            if skip_first:
                if len(results) < skipped:
                    continue
                if len(results) == skipped:
                    # Start timing only once the warm-up results are in.
                    prog_bar.start()
                    continue
            prog_bar.update()
    except BaseException:
        # A task (or the caller, e.g. KeyboardInterrupt) failed: stop the
        # workers instead of leaking them or waiting for remaining tasks.
        pool.terminate()
        pool.join()
        raise
    prog_bar.file.write('\n')
    pool.close()
    pool.join()
    return results
def track_iter_progress(tasks: Sequence, bar_width: int = 50, file=sys.stdout):
    """Track the progress of tasks iteration or enumeration with a progress
    bar.

    Tasks are yielded with a simple for-loop; the bar advances by one step
    each time the consumer asks for the next task (i.e. after it has finished
    with the previous one).

    Args:
        tasks (Sequence): If tasks is a tuple, it must contain two elements,
            the first being the tasks to be completed and the other being the
            number of tasks. If it is not a tuple, it represents the tasks to
            be completed.
        bar_width (int): Width of progress bar.
        file: Stream the progress bar writes to. Defaults to ``sys.stdout``.

    Yields:
        The tasks themselves, one at a time and unchanged.

    Raises:
        TypeError: If ``tasks`` is neither a tuple nor a sequence (raised
            when iteration starts, since this is a generator).
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        task_num = tasks[1]
        tasks = tasks[0]  # type: ignore
    elif isinstance(tasks, Sequence):
        task_num = len(tasks)
    else:
        raise TypeError(
            '"tasks" must be a tuple object or a sequence object, but got '
            f'{type(tasks)}')
    prog_bar = ProgressBar(task_num, bar_width, file=file)
    try:
        for task in tasks:
            yield task
            prog_bar.update()
    finally:
        # Also terminate the line when the consumer stops early (``break`` or
        # generator close), so later output does not land on the bar line.
        prog_bar.file.write('\n')