|
|
import sys |
|
|
|
|
|
from overdue import timeout_set_to |
|
|
import threading |
|
|
import contextvars |
|
|
from typing import Union |
|
|
from contextlib import contextmanager |
|
|
from .logging import logger, get_log_file |
|
|
|
|
|
class Callback:
    """Base class for callbacks.

    Subclasses implement :meth:`run`. Calling the instance invokes ``run``
    with the same arguments; if ``run`` raises, :meth:`on_error` is notified
    and the exception is then re-raised to the caller.
    """

    def on_error(self, exception, *args, **kwargs):
        """Hook invoked when :meth:`run` raises; the default does nothing.

        Args:
            exception: The exception raised by ``run``.
            *args: Positional arguments the callback was invoked with.
            **kwargs: Keyword arguments the callback was invoked with.
        """
        pass

    def __call__(self, *args, **kwargs):
        """Run the callback and return its result.

        Raises:
            Exception: Whatever ``run`` raised, after ``on_error`` has been
                called with the exception and the original arguments.
        """
        try:
            return self.run(*args, **kwargs)
        except Exception as e:
            # Unpack kwargs so on_error receives them as keywords, mirroring
            # the original call, rather than as one trailing positional dict.
            self.on_error(e, *args, **kwargs)
            # Bare raise keeps the original traceback intact.
            raise

    def run(self, *args, **kwargs):
        """Perform the callback's work; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f"run is not implemented for {type(self).__name__}!")
|
|
|
|
|
|
|
|
class CallbackManager:
    """Registry mapping callback type names to callbacks.

    The mapping lives on a ``threading.local`` object, so each thread sees
    and modifies only its own set of registered callbacks.
    """

    def __init__(self):
        self.local_data = threading.local()

    def _ensure_callbacks(self):
        """Create the calling thread's callback dict if it does not exist yet."""
        try:
            self.local_data.callbacks
        except AttributeError:
            self.local_data.callbacks = {}

    def set_callback(self, callback_type: str, callback: Callback):
        """Register ``callback`` under ``callback_type``, replacing any previous entry."""
        self._ensure_callbacks()
        registry = self.local_data.callbacks
        registry[callback_type] = callback

    def get_callback(self, callback_type: str):
        """Return the callback registered under ``callback_type``, or ``None``."""
        self._ensure_callbacks()
        registry = self.local_data.callbacks
        return registry.get(callback_type)

    def has_callback(self, callback_type: str):
        """Return ``True`` if a callback is registered under ``callback_type``."""
        self._ensure_callbacks()
        registry = self.local_data.callbacks
        return callback_type in registry

    def clear_callback(self, callback_type: str):
        """Remove the entry for ``callback_type``; do nothing if it is absent."""
        self._ensure_callbacks()
        self.local_data.callbacks.pop(callback_type, None)

    def clear_all(self):
        """Remove every callback registered in the calling thread."""
        self._ensure_callbacks()
        registry = self.local_data.callbacks
        registry.clear()


# Module-level manager instance used by the helpers in this module.
callback_manager = CallbackManager()
|
|
|
|
|
|
|
|
class DeferredExceptionHandler(Callback):
    """Callback that collects exceptions in a list instead of acting on them."""

    def __init__(self):
        # Exceptions recorded through add(), kept in insertion order.
        self.exceptions = list()

    def add(self, exception):
        """Append ``exception`` to the collected list."""
        collected = self.exceptions
        collected.append(exception)
|
|
|
|
|
|
|
|
@contextmanager
def exception_buffer():
    """Provide a ``DeferredExceptionHandler`` registered as "exception_buffer".

    If no handler is registered with ``callback_manager`` under that key, a
    new one is created and registered for the duration of the ``with`` block.
    If one is already registered (a nested use), that handler is reused.

    Yields:
        The active ``DeferredExceptionHandler``.
    """
    # Only the invocation that registered the handler may remove it; a nested
    # use must not deregister the buffer while the outer context is active.
    owns_handler = not callback_manager.has_callback("exception_buffer")
    if owns_handler:
        exception_handler = DeferredExceptionHandler()
        callback_manager.set_callback("exception_buffer", exception_handler)
    else:
        exception_handler = callback_manager.get_callback("exception_buffer")
    try:
        yield exception_handler
    finally:
        if owns_handler:
            callback_manager.clear_callback("exception_buffer")
|
|
|
|
|
|
|
|
# Context-local flag; True while inside suppress_cost_logging().
suppress_cost_logs = contextvars.ContextVar("suppress_cost_logs", default=False)


@contextmanager
def suppress_cost_logging():
    """Set the ``suppress_cost_logs`` flag to ``True`` for the ``with`` block.

    The flag is a ``ContextVar``, so the change is local to the current
    context. The previous value is restored on exit, including when the
    block raises. This function only toggles the flag; code that emits
    cost-related logs is presumably expected to consult it (not visible here).
    """
    previous_state = suppress_cost_logs.set(True)
    try:
        yield
    finally:
        suppress_cost_logs.reset(previous_state)
|
|
|
|
|
|
|
|
# Context-local depth counter of active suppress_logger_info() blocks.
silence_nesting = contextvars.ContextVar("silence_nesting", default=0)


def _reset_logger_sinks(level):
    """Replace the logger's sinks with stdout (and the log file, if any) at ``level``."""
    # NOTE(review): logger.remove() with no arguments presumably drops every
    # registered sink (loguru-style API) — confirm against the project's logger.
    logger.remove()
    logger.add(sys.stdout, level=level)
    log_file = get_log_file()
    if log_file is not None:
        logger.add(
            log_file,
            encoding="utf-8",
            level=level,
            format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"
        )


@contextmanager
def suppress_logger_info():
    """Raise the logger threshold to WARNING for the duration of the block.

    Nested uses are counted through ``silence_nesting``: only the outermost
    block reconfigures the logger on entry (to WARNING) and on exit (back to
    INFO). Inner blocks only adjust the counter.
    """
    current_level = silence_nesting.get()
    token = silence_nesting.set(current_level + 1)
    try:
        if current_level == 0:
            _reset_logger_sinks("WARNING")
        yield
    finally:
        try:
            # Only the outermost block restores INFO-level output.
            if current_level == 0:
                _reset_logger_sinks("INFO")
        finally:
            # Always restore the counter, even if reconfiguring sinks failed.
            silence_nesting.reset(token)
|
|
|
|
|
|
|
|
class TimeoutException(Exception):
    """Raised when an operation exceeds its time limit."""
|
|
|
|
|
class TimeoutContext:
    """
    Context manager that applies a time limit to the enclosed code by
    delegating to ``timeout_set_to`` (imported from ``overdue``).

    On exit, ``TimeoutException`` is raised if the delegate reports that the
    timeout was triggered; otherwise exceptions from the block propagate
    unchanged.

    Usage:
        with TimeoutContext(seconds=5):
            # code that may timeout
            do_something()
    """

    def __init__(self, seconds: Union[int, float]):
        # Time limit, normalised to float.
        self.seconds = float(seconds)

        # Delegate context manager and the object it yields; set in __enter__.
        self._cm = None
        self._result = None

    def __enter__(self):
        """Start the delegate timeout and return this context object."""
        delegate = timeout_set_to(self.seconds)
        self._cm = delegate
        self._result = delegate.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the delegate; raise ``TimeoutException`` if it was triggered."""
        self._cm.__exit__(exc_type, exc_val, exc_tb)
        if not self._result.triggered:
            # No timeout: do not suppress any exception raised in the block.
            return False
        raise TimeoutException("Operation timed out")
|
|
|
|
|
@contextmanager
def timeout(seconds: float):
    """Run the ``with`` block inside a ``TimeoutContext`` of ``seconds`` seconds."""
    guard = TimeoutContext(seconds)
    with guard:
        yield
|
|
|