Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import Optional | |
import mmengine.dist as dist | |
import rich.progress as progress | |
from rich.live import Live | |
# Module-wide switch: when set to ``True``, the tracking helpers below hand
# their input sequence straight back without drawing anything.
disable_progress_bar = False

# A single Progress object shared by every ``track`` call.
global_progress = progress.Progress(
    # Columns: description | bar | task progress | time remaining.
    '{task.description}',
    progress.BarColumn(),
    progress.TaskProgressColumn(show_speed=True),
    progress.TimeRemainingColumn(),
)

# Live display that renders ``global_progress``; ``track`` starts it when a
# bar is needed and stops it once every task has finished.
global_live = Live(
    renderable=global_progress,
    refresh_per_second=10,
)
def track(sequence, description: str = '', total: Optional[float] = None):
    """Iterate over ``sequence`` while rendering a rich progress bar.

    All calls share the module-level ``global_progress`` and ``global_live``.
    Once every task in ``global_progress`` is finished, the live display is
    stopped and all tasks are removed so the next loop starts clean.

    If ``disable_progress_bar`` is set, the items are passed through and
    nothing is drawn.

    Args:
        sequence: Any iterable; its items are yielded unchanged.
        description: Label shown in front of the bar.
        total: Expected number of items, or ``None`` if unknown. When the
            total is unknown, or iteration stops before reaching it
            (``break``, exception, generator closed, short sequence), the
            total is pinned to the number of items actually consumed so the
            task counts as finished.

    Yields:
        The items of ``sequence``, in order.
    """
    if disable_progress_bar:
        yield from sequence
        return

    global_live.start()
    task_id = global_progress.add_task(description, total=total)
    # Look the task up through the public ``tasks`` list instead of the
    # private ``_tasks`` mapping.
    task = next(t for t in global_progress.tasks if t.id == task_id)
    try:
        # Forward ``total`` explicitly so that ``Progress.track`` does not
        # replace a caller-supplied total with one derived from the sequence.
        yield from global_progress.track(
            sequence, total=total, task_id=task_id)
    finally:
        if task.total is None or not task.finished:
            # Unknown length, or iteration ended early or short. Without
            # this, the task would never count as finished, so the live
            # display below would never be stopped and stale tasks would
            # accumulate.
            global_progress.update(task_id, total=task.completed)
        if all(t.finished for t in global_progress.tasks):
            global_live.stop()
            for finished_id in global_progress.task_ids:
                global_progress.remove_task(finished_id)
def track_on_main_process(sequence, description='', total=None):
    """Variant of :func:`track` that draws the bar on the main process only.

    Non-main processes (as reported by ``mmengine.dist``) and runs with
    ``disable_progress_bar`` set get the items of ``sequence`` back unchanged,
    with no progress output.

    Args:
        sequence: Iterable whose items are yielded unchanged.
        description: Label forwarded to :func:`track`.
        total: Expected item count forwarded to :func:`track`.

    Yields:
        The items of ``sequence``, in order.
    """
    show_bar = dist.is_main_process() and not disable_progress_bar
    if not show_bar:
        yield from sequence
        return
    yield from track(sequence, total=total, description=description)