# rawalkhirodkar's picture
# Add initial commit
# 28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
from collections.abc import Iterable
from multiprocessing import Pool
from shutil import get_terminal_size
from typing import Callable, Sequence
from .timer import Timer
class ProgressBar:
    """A progress bar which can print the progress.

    Args:
        task_num (int): Number of total steps. If it is not positive, only
            the number of completed steps is reported. Defaults to 0.
        bar_width (int): Maximum width of the progress bar. Defaults to 50.
        start (bool): Whether to start the progress bar in the constructor.
            Defaults to True.
        file: File-like object (needs ``write`` and ``flush``) that the
            progress bar is written to. Defaults to ``sys.stdout``.

    Examples:
        >>> import mmengine
        >>> import time
        >>> bar = mmengine.ProgressBar(10)
        >>> for i in range(10):
        >>>     bar.update()
        >>>     time.sleep(1)
    """

    def __init__(self,
                 task_num: int = 0,
                 bar_width: int = 50,
                 start: bool = True,
                 file=sys.stdout):
        self.task_num = task_num
        self.bar_width = bar_width
        self.completed = 0
        self.file = file
        if start:
            self.start()

    @property
    def terminal_width(self):
        """int: Current number of columns reported by the terminal."""
        width, _ = get_terminal_size()
        return width

    def start(self):
        """Print the initial (empty) progress line and create the timer.

        ``update`` reads ``self.timer``, so this must run (directly, or via
        ``start=True`` in the constructor) before the first ``update`` call.
        """
        if self.task_num > 0:
            self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
                            'elapsed: 0s, ETA:')
        else:
            self.file.write('completed: 0, elapsed: 0s')
        self.file.flush()
        self.timer = Timer()

    def update(self, num_tasks: int = 1):
        """Advance the progress bar and redraw the current line in place.

        Args:
            num_tasks (int): Update step size. Must be positive.
        """
        assert num_tasks > 0
        self.completed += num_tasks
        elapsed = self.timer.since_start()
        fps = self.completed / elapsed if elapsed > 0 else float('inf')
        if self.task_num > 0:
            percentage = self.completed / float(self.task_num)
            # Clamp at 0: if more tasks are reported than ``task_num``,
            # ``1 - percentage`` is negative and would give a negative ETA.
            eta = max(0, int(elapsed * (1 - percentage) / percentage + 0.5))
            msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
                  f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
                  f'ETA: {eta:5}s'
            # ``msg`` still holds the 2-char ``{}`` placeholder, hence ``+ 2``.
            bar_width = min(self.bar_width,
                            int(self.terminal_width - len(msg)) + 2,
                            int(self.terminal_width * 0.6))
            bar_width = max(2, bar_width)
            # Never draw more marks than the bar is wide (over-completion).
            mark_width = min(bar_width, int(bar_width * percentage))
            bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
            self.file.write(msg.format(bar_chars))
        else:
            # Leading '\r' returns to the line start so the status is redrawn
            # instead of being appended to the previous one.
            self.file.write(
                f'\rcompleted: {self.completed}, '
                f'elapsed: {int(elapsed + 0.5)}s, {fps:.1f} tasks/s')
        self.file.flush()
def track_progress(func: Callable,
                   tasks: Sequence,
                   bar_width: int = 50,
                   file=sys.stdout,
                   **kwargs):
    """Apply ``func`` to every task sequentially while showing a progress bar.

    Args:
        func (callable): Called as ``func(task, **kwargs)`` for each task.
        tasks (Sequence | tuple): Either a sequence of tasks, or a 2-tuple
            ``(iterable_of_tasks, number_of_tasks)`` for iterables whose
            length is not known in advance.
        bar_width (int): Width of progress bar.
        file: Stream the progress bar is written to. Defaults to
            ``sys.stdout``.
        **kwargs: Extra keyword arguments forwarded to every ``func`` call.

    Returns:
        list: The task results, in the order the tasks were iterated.

    Raises:
        TypeError: If ``tasks`` is neither a tuple nor a sequence.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        task_iter, task_num = tasks
        assert isinstance(task_iter, Iterable)
        assert isinstance(task_num, int)
    elif isinstance(tasks, Sequence):
        task_iter, task_num = tasks, len(tasks)
    else:
        raise TypeError(
            '"tasks" must be a tuple object or a sequence object, but got '
            f'{type(tasks)}')

    bar = ProgressBar(task_num, bar_width, file=file)
    outputs = []
    for item in task_iter:
        outputs.append(func(item, **kwargs))
        bar.update()
    bar.file.write('\n')
    return outputs
def init_pool(process_num, initializer=None, initargs=None):
    """Create a :class:`multiprocessing.Pool`, forwarding only given options.

    Args:
        process_num (int): Number of worker processes.
        initializer (callable, optional): Passed to ``Pool`` as its
            initializer. When None, ``initargs`` is ignored.
        initargs (tuple, optional): Arguments for ``initializer``.

    Returns:
        multiprocessing.pool.Pool: The newly created pool.

    Raises:
        TypeError: If both ``initializer`` and ``initargs`` are given and
            ``initargs`` is not a tuple.
    """
    if initializer is not None and initargs is not None:
        # Validate before spawning any worker process.
        if not isinstance(initargs, tuple):
            raise TypeError('"initargs" must be a tuple')
        return Pool(process_num, initializer, initargs)
    if initializer is not None:
        return Pool(process_num, initializer)
    return Pool(process_num)
def track_parallel_progress(func: Callable,
                            tasks: Sequence,
                            nproc: int,
                            initializer: Callable = None,
                            initargs: tuple = None,
                            bar_width: int = 50,
                            chunksize: int = 1,
                            skip_first: bool = False,
                            keep_order: bool = True,
                            file=sys.stdout):
    """Track the progress of parallel task execution with a progress bar.

    The built-in :mod:`multiprocessing` module is used for process pools and
    tasks are done with :func:`Pool.imap` or :func:`Pool.imap_unordered`.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Sequence): If tasks is a tuple, it must contain two elements,
            the first being the tasks to be completed and the other being the
            number of tasks. If it is not a tuple, it represents the tasks to
            be completed.
        nproc (int): Process (worker) number.
        initializer (None or callable): Refer to :class:`multiprocessing.Pool`
            for details.
        initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
            details.
        bar_width (int): Width of progress bar.
        chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
        skip_first (bool): Whether to skip the first sample for each worker
            when estimating fps, since the initialization step may take
            longer.
        keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
            :func:`Pool.imap_unordered` is used.
        file: Stream the progress bar is written to. Defaults to
            ``sys.stdout``.

    Returns:
        list: The task results.

    Raises:
        TypeError: If ``tasks`` is neither a tuple nor a sequence.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        task_num = tasks[1]
        tasks = tasks[0]  # type: ignore
    elif isinstance(tasks, Sequence):
        task_num = len(tasks)
    else:
        raise TypeError(
            '"tasks" must be a tuple object or a sequence object, but got '
            f'{type(tasks)}')

    # With ``skip_first`` the first ``nproc * chunksize`` results are excluded
    # from the total, and the bar (and its timer) starts only once they are in.
    num_skipped = nproc * chunksize if skip_first else 0
    task_num -= num_skipped

    pool = init_pool(nproc, initializer, initargs)
    results = []
    try:
        prog_bar = ProgressBar(task_num, bar_width, not skip_first, file=file)
        if keep_order:
            gen = pool.imap(func, tasks, chunksize)
        else:
            gen = pool.imap_unordered(func, tasks, chunksize)
        for result in gen:
            results.append(result)
            if skip_first:
                if len(results) < num_skipped:
                    continue
                elif len(results) == num_skipped:
                    prog_bar.start()
                    continue
            prog_bar.update()
        prog_bar.file.write('\n')
    except BaseException:
        # A worker exception (re-raised by the result iterator) or an
        # interrupt: stop outstanding work instead of leaking the workers.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
    return results
def track_iter_progress(tasks: Sequence, bar_width: int = 50, file=sys.stdout):
    """Yield tasks one at a time, advancing a progress bar after each.

    The bar is advanced when control comes back to the generator, i.e. after
    the caller has finished handling the yielded task.

    Args:
        tasks (Sequence | tuple): Either a sequence of tasks, or a 2-tuple
            ``(iterable_of_tasks, number_of_tasks)`` for iterables whose
            length is not known in advance.
        bar_width (int): Width of progress bar.
        file: Stream the progress bar is written to. Defaults to
            ``sys.stdout``.

    Yields:
        Each task from ``tasks``, unchanged.

    Raises:
        TypeError: If ``tasks`` is neither a tuple nor a sequence (raised on
            the first ``next`` call, since this is a generator).
    """
    if not isinstance(tasks, (tuple, Sequence)):
        raise TypeError(
            '"tasks" must be a tuple object or a sequence object, but got '
            f'{type(tasks)}')

    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        items, total = tasks
    else:
        items, total = tasks, len(tasks)

    bar = ProgressBar(total, bar_width, file=file)
    for item in items:
        yield item
        bar.update()
    bar.file.write('\n')