| from collections import deque |
| from datetime import datetime |
| import io |
| import logging |
| import sys |
| import threading |
|
|
# Module-level capture state. All three names stay None until setup_logger()
# runs: `logs` then becomes the bounded deque of captured output entries, and
# the two interceptors become the LogInterceptor objects installed over
# sys.stdout and sys.stderr.
logs = stdout_interceptor = stderr_interceptor = None
|
|
|
|
class LogInterceptor(io.TextIOWrapper):
    """Text stream that records everything written through it.

    The interceptor wraps the binary ``buffer`` of an existing text stream
    (e.g. ``sys.stdout``), reusing that stream's encoding and line-buffering
    mode, so output still reaches the same destination. Every ``write`` is
    also recorded as an entry ``{"t": <ISO timestamp>, "m": <data>}``:

    * in the module-level ``logs`` deque (when it has been created), and
    * in a per-instance pending list that is handed to every callback
      registered via :meth:`on_flush` when :meth:`flush` is called.
    """

    def __init__(self, stream, *args, **kwargs):
        buffer = stream.buffer
        encoding = stream.encoding
        super().__init__(buffer, *args, **kwargs, encoding=encoding, line_buffering=stream.line_buffering)
        # Serializes this instance's updates to the pending list / `logs`.
        # NOTE(review): the lock is per instance, so the stdout and stderr
        # interceptors do not exclude each other when touching `logs`.
        self._lock = threading.Lock()
        self._flush_callbacks = []
        self._logs_since_flush = []

    def write(self, data):
        """Record *data* in the log buffers, then write it to the wrapped stream."""
        entry = {"t": datetime.now().isoformat(), "m": data}
        with self._lock:
            self._logs_since_flush.append(entry)
            # The shared buffer only exists after setup_logger() has run.
            if logs is not None:
                # Simple carriage-return handling: a write that starts with "\r"
                # replaces the previous entry if that entry did not finish a
                # line, so progress-bar updates don't flood the buffer. The
                # deque may still be empty (e.g. the very first write starts
                # with "\r"), so check it before peeking at logs[-1].
                if (
                    isinstance(data, str)
                    and data.startswith("\r")
                    and logs
                    and not logs[-1]["m"].endswith("\n")
                ):
                    logs.pop()
                logs.append(entry)
        super().write(data)

    def flush(self):
        """Flush the wrapped stream, then pass the pending entries to every callback.

        All registered callbacks receive the same batch of entries, and the
        pending list is reset exactly once per flush (also when no callback is
        registered, so it cannot grow without bound).
        """
        super().flush()
        # Swap the pending list out under the lock: entries written by other
        # threads while callbacks run go into the next batch instead of being
        # discarded by a late reset.
        with self._lock:
            pending = self._logs_since_flush
            self._logs_since_flush = []
        for cb in self._flush_callbacks:
            cb(pending)

    def on_flush(self, callback):
        """Register *callback*; it is called as ``callback(entries)`` on each flush."""
        self._flush_callbacks.append(callback)
|
|
|
|
def get_logs():
    """Return the module's captured-output buffer.

    This is the live ``deque`` created by ``setup_logger`` (not a copy), or
    ``None`` if ``setup_logger`` has not been called yet.
    """
    current_buffer = logs
    return current_buffer
|
|
|
|
def on_flush(callback):
    """Register *callback* with each installed stream interceptor.

    The stdout interceptor is registered first, then the stderr one; an
    interceptor that has not been installed (still ``None``) is skipped.
    """
    for interceptor in (stdout_interceptor, stderr_interceptor):
        if interceptor is None:
            continue
        interceptor.on_flush(callback)
|
|
def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
    """Capture stdout/stderr into a bounded buffer and configure the root logger.

    Only the first call has an effect; later calls return immediately.

    Args:
        log_level: Level passed to ``setLevel`` on the root logger.
        capacity: ``maxlen`` of the shared deque that stores captured output
            entries (older entries are discarded once it is full).
        use_stdout: If true, records below ERROR are written to stdout and
            ERROR/CRITICAL records to stderr; otherwise all records go to
            stderr.
    """
    global logs
    global stdout_interceptor
    global stderr_interceptor

    # Test for None, not truthiness: right after setup `logs` is an empty
    # (falsy) deque, so `if logs:` would let a second call re-wrap the
    # already-wrapped streams and attach duplicate handlers.
    if logs is not None:
        return

    # Override output streams and log to buffer.
    logs = deque(maxlen=capacity)
    stdout_interceptor = sys.stdout = LogInterceptor(sys.stdout)
    stderr_interceptor = sys.stderr = LogInterceptor(sys.stderr)

    # Set up the default global (root) logger.
    logger = logging.getLogger()
    logger.setLevel(log_level)
    formatter = logging.Formatter("%(message)s")

    # With no argument, StreamHandler writes to sys.stderr as it is at
    # construction time, i.e. the interceptor installed above.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)

    if use_stdout:
        # Only errors and critical messages go to stderr...
        stream_handler.addFilter(lambda record: record.levelno >= logging.ERROR)

        # ...lesser levels go to stdout.
        stdout_handler = logging.StreamHandler(sys.stdout)
        stdout_handler.setFormatter(formatter)
        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
        logger.addHandler(stdout_handler)

    logger.addHandler(stream_handler)
|
|
|
|
# Warnings recorded by log_startup_warning(); print_startup_warnings()
# re-logs them and then empties this list.
STARTUP_WARNINGS = list()
|
|
|
|
def log_startup_warning(msg):
    """Log *msg* as a warning now and remember it for a later replay.

    The message is appended to ``STARTUP_WARNINGS`` only after
    ``logging.warning`` has returned.
    """
    logging.warning(msg)
    STARTUP_WARNINGS.extend([msg])
|
|
|
|
def print_startup_warnings():
    """Re-log every recorded startup warning in order, then forget them all.

    The ``STARTUP_WARNINGS`` list object is emptied in place, not replaced.
    """
    for warning_text in STARTUP_WARNINGS:
        logging.warning(warning_text)
    del STARTUP_WARNINGS[:]
|
|