|
|
from __future__ import annotations |
|
|
|
|
|
import time |
|
|
from dataclasses import dataclass, field |
|
|
from functools import wraps |
|
|
from inspect import iscoroutinefunction |
|
|
from typing import Any, Callable, ClassVar, TypeVar |
|
|
|
|
|
from hibiapi.utils.log import logger |
|
|
|
|
|
# Type variable for decorators that return a callable of the same type they
# receive (used by Timer.__call__ so decorated functions keep their signature
# for type checkers).
Callable_T = TypeVar("Callable_T", bound=Callable)
|
|
|
|
|
|
|
|
class TimerError(Exception):
    """Signal misuse of a Timer: starting one that is already running, or
    stopping one that was never started."""
|
|
|
|
|
|
|
|
@dataclass
class Timer:
    """Time your code using a class, context manager, or decorator.

    Durations are measured with ``time.perf_counter``. When ``logger_func`` is
    set, :meth:`stop` reports by calling
    ``logger_func(text.format(elapsed_milliseconds))``, so ``text`` must hold
    exactly one ``str.format`` placeholder, which receives milliseconds.

    Named timers add their elapsed time (in seconds) to the class-level
    ``Timer.timers`` mapping, which is shared by every instance.
    """

    # Accumulated elapsed seconds per timer name, shared across all instances.
    timers: ClassVar[dict[str, float]] = dict()

    # Optional key into ``timers``; unnamed timers do not accumulate totals.
    name: str | None = None
    # Report template; its single placeholder receives the elapsed milliseconds.
    text: str = "Elapsed time: {:0.3f} ms"
    # Receives the formatted report; ``None`` (or any falsy value) disables it.
    logger_func: Callable[[str], None] | None = print
    # perf_counter() value captured by start(); None while the timer is idle.
    _start_time: float | None = field(default=None, init=False, repr=False)

    def __post_init__(self) -> None:
        """Register a named timer in ``timers`` with a zero total."""
        if self.name:
            self.timers.setdefault(self.name, 0.0)

    def start(self) -> None:
        """Start a new timer.

        Raises:
            TimerError: if the timer is already running.
        """
        if self._start_time is not None:
            raise TimerError("Timer is running. Use .stop() to stop it")

        self._start_time = time.perf_counter()

    def stop(self) -> float:
        """Stop the timer and report the elapsed time.

        Returns:
            The elapsed time in seconds.

        Raises:
            TimerError: if the timer is not running.
        """
        if self._start_time is None:
            raise TimerError("Timer is not running. Use .start() to start it")

        elapsed_time = time.perf_counter() - self._start_time
        self._start_time = None

        if self.logger_func:
            # The report is formatted in milliseconds; the return value and
            # the ``timers`` totals stay in seconds.
            self.logger_func(self.text.format(elapsed_time * 1000))
        if self.name:
            self.timers[self.name] += elapsed_time

        return elapsed_time

    def __enter__(self) -> Timer:
        """Start a new timer as a context manager."""
        self.start()
        return self

    def __exit__(self, *exc_info: Any) -> None:
        """Stop the context manager timer.

        Runs even when the body raised; returning None lets the exception
        propagate.
        """
        self.stop()

    def _recreate_cm(self, text: str | None = None) -> Timer:
        """Return a fresh, idle timer with this instance's configuration.

        Args:
            text: optional report template overriding ``self.text``.

        A separate instance per decorated call keeps overlapping (concurrent
        or recursive) calls from tripping the "already running" check in
        :meth:`start`.
        """
        return self.__class__(
            self.name,
            self.text if text is None else text,
            self.logger_func,
        )

    def __call__(self, function: Callable_T) -> Callable_T:
        """Decorate ``function`` so that every call is timed and reported.

        The report template is built once, at decoration time, and handed to
        a per-call timer. This instance's ``text`` is never modified, so one
        shared decorator instance can safely wrap many functions and still be
        used as a context manager.
        """
        qualname = function.__qualname__

        if iscoroutinefunction(function):
            async_text = (
                f"<g>Async</g> function <y>{qualname}</y> "
                "cost <e>{:.3f}ms</e>"
            )

            @wraps(function)
            async def async_wrapper(*args: Any, **kwargs: Any):
                with self._recreate_cm(async_text):
                    return await function(*args, **kwargs)

            return async_wrapper

        sync_text = (
            f"<g>sync</g> function <y>{qualname}</y> "
            "cost <e>{:.3f}ms</e>"
        )

        @wraps(function)
        def sync_wrapper(*args: Any, **kwargs: Any):
            with self._recreate_cm(sync_text):
                return function(*args, **kwargs)

        return sync_wrapper
|
|
|
|
|
|
|
|
# Module-level shared Timer instance that reports through ``logger.trace``.
# It is unnamed, so it does not accumulate totals in ``Timer.timers``. It can
# be used as a context manager or applied as a decorator (see Timer.__call__).
TimeIt = Timer(logger_func=logger.trace)
|
|
|