|
import dataclasses |
|
import logging |
|
import os |
|
from typing import Any, Dict, List, Optional |
|
|
|
import numpy as np |
|
from sqlitedict import SqliteDict |
|
|
|
__all__ = ["Loggers"] |
|
|
|
from llm_studio.src.utils.plot_utils import PLOT_ENCODINGS |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def get_cfg(cfg: Any) -> Dict:
    """Flatten a configuration object into a plain dictionary.

    Attributes are visited in the order reported by
    ``cfg._get_order(warn_if_unset=False)``. Values that are dataclasses are
    flattened recursively and merged into the same dictionary, so on a name
    clash the entry merged later wins.

    The following attributes are left out:
        * names starting with an underscore,
        * names for which ``cfg._get_visibility(name)`` is negative,
        * names containing the substring ``"api"``.

    A value whose annotation (taken from ``cfg.get_annotations()``) is exactly
    ``float`` is converted with ``float()``; any other value is stored as is.

    Args:
        cfg: configuration

    Returns:
        Dict of config elements
    """
    annotations = cfg.get_annotations()
    attributes = cfg.__dict__
    ordered = {
        name: attributes[name] for name in cfg._get_order(warn_if_unset=False)
    }

    flat: Dict = {}
    for name, value in ordered.items():
        hidden = name.startswith("_") or cfg._get_visibility(name) < 0
        if hidden or "api" in name:
            continue

        if dataclasses.is_dataclass(value):
            # Nested sections share one flat namespace with their parent.
            flat.update(get_cfg(cfg=value))
            continue

        flat[name] = float(value) if annotations[name] == float else value

    return flat
|
|
|
|
|
class NeptuneLogger:
    """Logger that sends values to a Neptune run."""

    def __init__(self, cfg: Any):
        """Start a Neptune run and attach the flattened configuration to it.

        The run is opened in ``"debug"`` mode when
        ``cfg.logging._neptune_debug`` is truthy and in ``"async"`` mode
        otherwise. The API token is read from the ``NEPTUNE_API_TOKEN``
        environment variable (empty string when it is not set).

        Args:
            cfg: configuration; ``cfg.logging.neptune_project`` and
                ``cfg.experiment_name`` are passed to ``neptune.init_run``.
        """
        # Imported inside the constructor, so the neptune package is only
        # loaded when a NeptuneLogger is actually created.
        import neptune
        from neptune.utils import stringify_unsupported

        run_mode = "debug" if cfg.logging._neptune_debug else "async"

        self.logger = neptune.init_run(
            project=cfg.logging.neptune_project,
            api_token=os.getenv("NEPTUNE_API_TOKEN", ""),
            name=cfg.experiment_name,
            mode=run_mode,
            capture_stdout=False,
            capture_stderr=False,
            source_files=[],
        )

        # The flattened config is stored under the "cfg" field of the run.
        self.logger["cfg"] = stringify_unsupported(get_cfg(cfg))

    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
        """Append ``value`` to the run field ``"<subset>/<name>"``.

        Args:
            subset: prefix of the field path
            name: suffix of the field path
            value: value to append
            step: step forwarded to ``append``
        """
        self.logger[f"{subset}/{name}"].append(value, step=step)
|
|
|
|
|
class LocalLogger:
    """Logger that persists values in a ``charts.db`` SqliteDict file."""

    def __init__(self, cfg: Any):
        """Create the chart database and store the flattened config in it.

        Args:
            cfg: configuration; the database is placed at
                ``<cfg.output_directory>/charts.db``.
        """
        # Only let sqlitedict emit records of level ERROR and above.
        logging.getLogger("sqlitedict").setLevel(logging.ERROR)

        self.logs = f"{cfg.output_directory}/charts.db"

        # Flatten the config before the database is opened.
        flat_cfg = get_cfg(cfg)

        with SqliteDict(self.logs) as db:
            db["cfg"] = flat_cfg
            db.commit()

    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
        """Store ``value`` under ``subset`` / ``name`` in the database.

        For a ``subset`` listed in ``PLOT_ENCODINGS`` the value is stored
        unchanged and replaces whatever was stored under ``name`` before.

        For every other subset the value is treated as a number: NaN becomes
        ``None``, anything else is converted with ``float()``. The result and
        ``step`` are appended to the ``"values"`` and ``"steps"`` lists kept
        for ``name``.

        Args:
            subset: top-level key in the database
            name: key inside the subset entry
            value: value to store
            step: step recorded alongside numeric values
        """
        is_plot = subset in PLOT_ENCODINGS

        if not is_plot:
            # Numeric values are normalised before the database is opened.
            value = None if np.isnan(value) else float(value)

        with SqliteDict(self.logs) as db:
            entries = db[subset] if subset in db else {}

            if is_plot:
                entries[name] = value
            else:
                series = entries.setdefault(name, {"steps": [], "values": []})
                series["steps"].append(step)
                series["values"].append(value)

            # The entry is a detached copy, so it is written back explicitly.
            db[subset] = entries
            db.commit()
|
|
|
|
|
class DummyLogger:
    """No-op logger: accepts the same calls as the other loggers, does nothing."""

    def __init__(self, cfg: Optional[Any] = None):
        """Accept an optional configuration; it is not used."""

    def log(self, subset: str, name: str, value: Any, step: Optional[int] = None):
        """Ignore the given value and return ``None``."""
        return None
|
|
|
|
|
class MainLogger: |
|
"""Main logger""" |
|
|
|
def __init__(self, cfg: Any): |
|
self.loggers = { |
|
"local": LocalLogger(cfg), |
|
"external": Loggers.get(cfg.logging.logger), |
|
} |
|
|
|
try: |
|
self.loggers["external"] = self.loggers["external"](cfg) |
|
except Exception as e: |
|
logger.warning( |
|
f"Error when initializing logger. " |
|
f"Disabling custom logging functionality. " |
|
f"Please ensure logger configuration is correct and " |
|
f"you have a stable Internet connection: {e}" |
|
) |
|
self.loggers["external"] = DummyLogger(cfg) |
|
|
|
def reset_external(self): |
|
self.loggers["external"] = DummyLogger() |
|
|
|
def log(self, subset: str, name: str, value: str | float, step: float = None): |
|
for k, logger in self.loggers.items(): |
|
if "validation_predictions" in name and k == "external": |
|
continue |
|
if subset == "internal" and not isinstance(logger, LocalLogger): |
|
continue |
|
logger.log(subset=subset, name=name, value=value, step=step) |
|
|
|
|
|
class Loggers:
    """Loggers factory."""

    # Registry mapping a logger name to the class that implements it.
    _loggers = {"None": DummyLogger, "Neptune": NeptuneLogger}

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered logger names in sorted order."""
        return sorted(cls._loggers)

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Loggers.

        Args:
            name: loggers name
        Returns:
            The class registered under ``name``, or ``DummyLogger`` when the
            name is not registered
        """
        try:
            return cls._loggers[name]
        except KeyError:
            return DummyLogger
|
|