import sys
import threading
from typing import List, Union

import tqdm


class ProgressListener:
    """Base receiver for progress notifications.

    Subclass and override the callbacks to react to progress updates.
    """

    def on_progress(self, current: Union[int, float], total: Union[int, float]) -> None:
        """Record the most recently reported total.

        Parameters
        ----------
        current : int or float
            The progress reached so far (unused by the base implementation).
        total : int or float
            The value of ``current`` that corresponds to completion.
        """
        self.total = total

    def on_finished(self) -> None:
        """Called when the monitored work completed without an exception."""
        pass


class ProgressListenerHandle:
    """Context manager that keeps a listener registered for the ``with`` body.

    Parameters
    ----------
    listener : ProgressListener
        The listener to register for the current thread on entry and to
        unregister again on exit.
    """

    def __init__(self, listener: ProgressListener):
        self.listener = listener

    def __enter__(self) -> ProgressListener:
        register_thread_local_progress_listener(self.listener)
        # Return the listener so that ``with handle as listener`` binds the
        # registered listener instead of None.
        return self.listener

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        unregister_thread_local_progress_listener(self.listener)

        # Only signal completion when the body did not raise; the exception
        # itself is never suppressed (this method returns None).
        if exc_type is None:
            self.listener.on_finished()


class SubTaskProgressListener(ProgressListener):
    """
    A sub task listener that reports the progress of a sub task to a base task listener

    Parameters
    ----------
    base_task_listener : ProgressListener
        The base progress listener to accumulate overall progress in.
    base_task_total : float
        The maximum total progress that will be reported to the base progress listener.
    sub_task_start : float
        The starting progress of a sub task, in respect to the base progress listener.
    sub_task_total : float
        The total amount of progress a sub task will report to the base progress listener.
    """

    def __init__(
        self,
        base_task_listener: ProgressListener,
        base_task_total: float,
        sub_task_start: float,
        sub_task_total: float,
    ):
        self.base_task_listener = base_task_listener
        self.base_task_total = base_task_total
        self.sub_task_start = sub_task_start
        self.sub_task_total = sub_task_total

    def on_progress(self, current: Union[int, float], total: Union[int, float]) -> None:
        """Map sub task progress into the base task's range and forward it.

        A ``total`` of zero or ``None`` carries no usable fraction, so it is
        reported as no progress within the sub task rather than raising.
        """
        sub_task_progress_frac = current / total if total else 0.0
        sub_task_progress = self.sub_task_start + self.sub_task_total * sub_task_progress_frac
        self.base_task_listener.on_progress(sub_task_progress, self.base_task_total)

    def on_finished(self) -> None:
        """Report the end of this sub task's range to the base listener."""
        self.base_task_listener.on_progress(
            self.sub_task_start + self.sub_task_total, self.base_task_total
        )


class _CustomProgressBar(tqdm.tqdm):
    """tqdm progress bar that forwards every update to the registered listeners."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._current = self.n  # Set the initial value

    def update(self, n=1):
        """Advance the bar by ``n`` and notify the current thread's listeners.

        The default of 1 and the returned value mirror ``tqdm.tqdm.update`` so
        this class stays a drop-in replacement.
        """
        result = super().update(n)
        # Because the progress bar might be disabled, we need to manually update the progress
        self._current += n

        # Inform listeners. Iterate over a snapshot so that a callback which
        # registers or unregisters a listener does not disturb the iteration.
        for listener in tuple(_get_thread_local_listeners()):
            listener.on_progress(self._current, self.total)

        return result


_thread_local = threading.local()


def _get_thread_local_listeners():
    """Return the calling thread's listener list, creating it on first use."""
    if not hasattr(_thread_local, 'listeners'):
        _thread_local.listeners = []
    return _thread_local.listeners


_hooked = False


def init_progress_hook() -> None:
    """Install ``_CustomProgressBar`` into ``whisper.transcribe`` (once).

    Subsequent calls are no-ops.
    """
    global _hooked

    if _hooked:
        return

    # Inject into tqdm.tqdm of Whisper, so we can see progress
    import whisper.transcribe
    transcribe_module = sys.modules['whisper.transcribe']
    # NOTE(review): this assigns on the ``tqdm`` object referenced by
    # whisper.transcribe; if that is the shared ``tqdm`` module, other users
    # of ``tqdm.tqdm`` in this process see the replacement too - confirm.
    transcribe_module.tqdm.tqdm = _CustomProgressBar
    _hooked = True


def register_thread_local_progress_listener(progress_listener: ProgressListener) -> None:
    """Add ``progress_listener`` to the calling thread's listeners."""
    # This is a workaround for the fact that the progress bar is not exposed in the API
    init_progress_hook()

    listeners = _get_thread_local_listeners()
    listeners.append(progress_listener)


def unregister_thread_local_progress_listener(progress_listener: ProgressListener) -> None:
    """Remove ``progress_listener`` from the calling thread's listeners.

    Does nothing when the listener is not registered on this thread.
    """
    listeners = _get_thread_local_listeners()

    if progress_listener in listeners:
        listeners.remove(progress_listener)


def create_progress_listener_handle(progress_listener: ProgressListener) -> ProgressListenerHandle:
    """Wrap ``progress_listener`` in a context manager that (un)registers it."""
    return ProgressListenerHandle(progress_listener)


if __name__ == '__main__':
    with create_progress_listener_handle(ProgressListener()) as listener:
        # Call model.transcribe here
        pass

    print("Done")