|
|
import contextlib |
|
|
import copy |
|
|
import pathlib |
|
|
from enum import Enum |
|
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
|
|
from .logging import get_logger |
|
|
from .utils import Timer, TimerDevice |
|
|
|
|
|
|
|
|
# Module-level logger shared by everything defined in this file.
logger = get_logger()
|
|
|
|
|
|
|
|
class BaseTracker:
    r"""Base class for loggers. Does nothing by default, so it is useful when you want to disable logging."""

    def __init__(self):
        # Maps an operation name to the total time accumulated by `timed` blocks run under that name.
        self._timed_metrics = {}

    @contextlib.contextmanager
    def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False):
        r"""
        Context manager to track time for a specific operation.

        Args:
            name: Key under which the elapsed time is accumulated in `self._timed_metrics`.
            device: Forwarded unchanged to `Timer`.
            device_sync: Forwarded unchanged to `Timer`.

        Yields:
            The `Timer` instance, already started.
        """
        timer = Timer(name, device, device_sync)
        timer.start()
        try:
            yield timer
        finally:
            # Runs even when the timed body raises, so the timer is always stopped and the
            # time spent so far is still recorded. The exception itself keeps propagating.
            timer.end()
            elapsed_time = timer.elapsed_time
            if name in self._timed_metrics:
                # Repeated measurements under the same name are summed.
                self._timed_metrics[name] += elapsed_time
            else:
                self._timed_metrics[name] = elapsed_time

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Accept `metrics` for `step` and discard them; subclasses override this to actually log."""
        pass

    def finish(self) -> None:
        r"""No-op; subclasses override this to release whatever they hold."""
        pass
|
|
|
|
|
|
|
|
class DummyTracker(BaseTracker):
    r"""Tracker whose `log` and `finish` accept their arguments and do nothing with them."""

    def __init__(self):
        super().__init__()

    def log(self, *args, **kwargs):
        r"""Ignore every positional and keyword argument."""
        return None

    def finish(self) -> None:
        r"""Nothing to finalize."""
        return None
|
|
|
|
|
|
|
|
class WandbTracker(BaseTracker):
    r"""Tracker that sends metrics to a Weights & Biases run."""

    def __init__(self, experiment_name: str, log_dir: str, config: Optional[Dict[str, Any]] = None) -> None:
        r"""
        Start a WandB run.

        Args:
            experiment_name: Passed to `wandb.init` as `project`.
            log_dir: Directory passed to `wandb.init` as `dir`; created if it does not exist.
            config: Passed to `wandb.init` as `config`.
        """
        super().__init__()

        # Imported here rather than at module level, so wandb is only imported when
        # a WandbTracker is actually constructed.
        import wandb

        self.wandb = wandb

        # Make sure the log directory (and any missing parents) exists before starting the run.
        log_path = pathlib.Path(log_dir)
        log_path.mkdir(parents=True, exist_ok=True)

        self.run = wandb.init(project=experiment_name, dir=log_dir, config=config)
        logger.info("WandB logging enabled")

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Log `metrics` together with any accumulated timings at `step`, then clear the timings."""
        # Start from the timings so that an explicitly passed metric wins on a key clash.
        payload = dict(self._timed_metrics)
        payload.update(metrics)
        self.run.log(payload, step=step)
        self._timed_metrics = {}

    def finish(self) -> None:
        r"""Finish the underlying WandB run."""
        self.run.finish()
|
|
|
|
|
|
|
|
class SequentialTracker(BaseTracker):
    r"""Sequential tracker that logs to multiple trackers in sequence."""

    def __init__(self, trackers: List[BaseTracker]) -> None:
        r"""
        Args:
            trackers: Trackers that `log` and `finish` are forwarded to, in list order.
        """
        super().__init__()
        self.trackers = trackers

    @contextlib.contextmanager
    def timed(self, name: str, device: TimerDevice = TimerDevice.CPU, device_sync: bool = False):
        r"""
        Context manager to track time for a specific operation.

        The elapsed time is accumulated on this tracker under `name`, and the accumulated
        timings are then copied onto every wrapped tracker.

        Args:
            name: Key under which the elapsed time is accumulated in `self._timed_metrics`.
            device: Forwarded unchanged to `Timer`.
            device_sync: Forwarded unchanged to `Timer`.

        Yields:
            The `Timer` instance, already started.
        """
        timer = Timer(name, device, device_sync)
        timer.start()
        try:
            yield timer
        finally:
            # Runs even when the timed body raises, so the timer is always stopped and the
            # wrapped trackers stay in sync. The exception itself keeps propagating.
            timer.end()
            elapsed_time = timer.elapsed_time
            if name in self._timed_metrics:
                # Repeated measurements under the same name are summed.
                self._timed_metrics[name] += elapsed_time
            else:
                self._timed_metrics[name] = elapsed_time
            for tracker in self.trackers:
                # Each wrapped tracker gets its own copy, so one of them resetting or mutating
                # its timings does not affect this tracker or its siblings.
                tracker._timed_metrics = copy.deepcopy(self._timed_metrics)

    def log(self, metrics: Dict[str, Any], step: int) -> None:
        r"""Forward `metrics` and `step` to every wrapped tracker, then clear the accumulated timings."""
        for tracker in self.trackers:
            tracker.log(metrics, step)
        self._timed_metrics = {}

    def finish(self) -> None:
        r"""Call `finish` on every wrapped tracker, in order."""
        for tracker in self.trackers:
            tracker.finish()
|
|
|
|
|
|
|
|
class Trackers(str, Enum):
    r"""Names of the tracker backends that can be requested."""

    # No logging backend.
    NONE = "none"
    # Weights & Biases.
    WANDB = "wandb"
|
|
|
|
|
|
|
|
# String values of every member declared on `Trackers`, in declaration order.
_SUPPORTED_TRACKERS = [member.value for member in Trackers.__members__.values()]
|
|
|
|
|
|
|
|
def initialize_trackers(
    trackers: List[str], experiment_name: str, config: Dict[str, Any], log_dir: str
) -> Union[BaseTracker, SequentialTracker]:
    r"""
    Initialize loggers based on the provided configuration.

    Args:
        trackers: Names of the trackers to create. Duplicates are ignored.
        experiment_name: Forwarded to `WandbTracker`.
        config: Forwarded to `WandbTracker`.
        log_dir: Forwarded to `WandbTracker`.

    Returns:
        A plain `BaseTracker` when `trackers` is empty, otherwise a `SequentialTracker`
        wrapping one tracker per distinct requested name, in the order first requested.

    Raises:
        ValueError: If any requested name is not in `_SUPPORTED_TRACKERS`, or is supported
            but has no construction branch below.
    """
    logger.info(f"Initializing trackers: {trackers}. Logging to {log_dir=}")

    if not trackers:
        return BaseTracker()

    # De-duplicate while keeping the caller's order. Iterating a `set` of strings would give
    # an order that can change from one interpreter run to the next.
    requested = list(dict.fromkeys(trackers))

    if any(tracker_name not in _SUPPORTED_TRACKERS for tracker_name in requested):
        raise ValueError(f"Unsupported tracker(s) provided. Supported trackers: {_SUPPORTED_TRACKERS}")

    tracker_instances = []
    for tracker_name in requested:
        if tracker_name == Trackers.NONE:
            tracker = BaseTracker()
        elif tracker_name == Trackers.WANDB:
            tracker = WandbTracker(experiment_name, log_dir, config)
        else:
            # Only reachable if a member is added to `Trackers` without a branch here; fail
            # loudly instead of re-appending the previous iteration's tracker.
            raise ValueError(f"No tracker implementation is registered for {tracker_name!r}")
        tracker_instances.append(tracker)

    return SequentialTracker(tracker_instances)
|
|
|
|
|
|
|
|
# Type alias for the union of BaseTracker, SequentialTracker and WandbTracker.
TrackerType = Union[BaseTracker, SequentialTracker, WandbTracker]
|
|
|