|
import sys |
|
import threading |
|
from typing import List, Union |
|
import tqdm |
|
|
|
from src.hooks.progressListener import ProgressListener |
|
|
|
class ProgressListenerHandle:
    """Context manager that scopes a progress listener to the current thread.

    On entry the listener is appended to the calling thread's listener list
    (which also installs the whisper progress hook on first use); on exit it
    is removed again. ``listener.on_finished()`` is invoked only when the
    ``with`` body completes without raising.
    """

    def __init__(self, listener: ProgressListener):
        # Listener that is registered for the duration of the ``with`` block.
        self.listener = listener

    def __enter__(self):
        """Register the listener on the current thread and return it.

        Returning the listener (instead of ``None``) makes
        ``with handle as listener:`` bind something useful.
        """
        register_thread_local_progress_listener(self.listener)
        return self.listener

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Unregister the listener and, on a clean exit, report completion.

        Returns ``None`` (falsy), so an exception raised inside the ``with``
        body is never suppressed.
        """
        unregister_thread_local_progress_listener(self.listener)

        # Only signal completion when the body finished successfully.
        if exc_type is None:
            self.listener.on_finished()
|
|
|
class _CustomProgressBar(tqdm.tqdm):
    """tqdm subclass that forwards every update to the current thread's listeners.

    ``init_progress_hook`` installs this class as ``tqdm.tqdm`` so that the
    progress bar created by whisper's transcribe module reports to listeners
    registered via ``register_thread_local_progress_listener``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Progress is counted separately from ``self.n``; presumably because
        # tqdm's ``update`` returns early without advancing ``self.n`` when the
        # bar is disabled, yet listeners must still see progress. Verify against
        # the installed tqdm version.
        self._current = self.n

    def update(self, n=1):
        """Advance the bar by ``n`` and notify this thread's listeners.

        ``n`` defaults to 1, matching ``tqdm.tqdm.update``. Whatever tqdm's
        ``update`` returns is passed back to the caller unchanged.
        """
        displayed = super().update(n)

        self._current += n

        # Iterate over a snapshot so that a listener which (un)registers
        # listeners from inside ``on_progress`` cannot cause others to be skipped.
        for listener in list(_get_thread_local_listeners()):
            listener.on_progress(self._current, self.total)

        return displayed
|
|
|
# Per-thread storage: every thread lazily gets its own ``listeners`` list.
_thread_local = threading.local()


def _get_thread_local_listeners():
    """Return the calling thread's list of progress listeners.

    The list is created empty the first time a given thread asks for it. Every
    later call from that thread returns the very same list object, so callers
    may append to or remove from it in place.
    """
    try:
        return _thread_local.listeners
    except AttributeError:
        # First access from this thread: start with no listeners.
        _thread_local.listeners = []
        return _thread_local.listeners
|
|
|
# True once whisper's progress bar class has been replaced; prevents re-patching.
_hooked = False


def init_progress_hook():
    """Swap ``_CustomProgressBar`` in as ``tqdm.tqdm`` for whisper's transcribe module.

    Only the first call patches anything; subsequent calls do nothing.

    NOTE(review): ``transcribe_module.tqdm`` is presumably the ``tqdm`` package
    itself (whisper.transcribe doing ``import tqdm``), in which case this
    replaces ``tqdm.tqdm`` process-wide, not only inside whisper. Confirm.
    """
    global _hooked

    if not _hooked:
        # Imported here rather than at module level, so importing this file
        # does not pull in whisper until a listener is actually registered.
        import whisper.transcribe

        # The submodule is fetched from sys.modules; presumably because the
        # ``whisper`` package re-exports a ``transcribe`` function that shadows
        # the submodule attribute. Confirm against the installed whisper.
        sys.modules['whisper.transcribe'].tqdm.tqdm = _CustomProgressBar
        _hooked = True
|
|
|
def register_thread_local_progress_listener(progress_listener: ProgressListener):
    """Append ``progress_listener`` to the calling thread's listener list.

    Makes sure the whisper progress hook is installed first (a no-op after the
    first call), so the listener receives updates from whisper's progress bar.
    """
    init_progress_hook()
    _get_thread_local_listeners().append(progress_listener)
|
|
|
def unregister_thread_local_progress_listener(progress_listener: ProgressListener):
    """Remove one occurrence of ``progress_listener`` from this thread's listeners.

    Removing a listener that is not registered on the current thread is a
    silent no-op.
    """
    try:
        _get_thread_local_listeners().remove(progress_listener)
    except ValueError:
        # Not registered on this thread: nothing to remove.
        pass
|
|
|
def create_progress_listener_handle(progress_listener: ProgressListener):
    """Wrap ``progress_listener`` in a ``ProgressListenerHandle`` for a ``with`` block.

    Example::

        with create_progress_listener_handle(my_listener):
            model.transcribe(...)
    """
    handle = ProgressListenerHandle(progress_listener)
    return handle
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: transcribe an audio file and print the progress events
    # delivered through the hook.
    # Usage: python <this file> [AUDIO_PATH] [MODEL_NAME] [LANGUAGE]
    # Any omitted argument falls back to the original hard-coded default.
    class PrintingProgressListener:
        """Listener that prints every progress and completion event to stdout."""

        def on_progress(self, current: Union[int, float], total: Union[int, float]):
            print(f"Progress: {current}/{total}")

        def on_finished(self):
            print("Finished")

    import whisper

    _cli_args = sys.argv[1:]
    audio_path = _cli_args[0] if len(_cli_args) > 0 else "J:\\Dev\\OpenAI\\whisper\\tests\\Noriko\\out.mka"
    model_name = _cli_args[1] if len(_cli_args) > 1 else "medium"
    language = _cli_args[2] if len(_cli_args) > 2 else "Japanese"

    model = whisper.load_model(model_name)

    with create_progress_listener_handle(PrintingProgressListener()):
        # NOTE(review): verbose=None presumably stops whisper from printing its
        # own output/progress bar, so only the listener reports progress. Confirm.
        result = model.transcribe(audio_path, language=language, fp16=False, verbose=None)
        print(result)

    print("Done")