import re
import pprint
import logging

from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities.rank_zero import rank_zero_only


class ConsoleLogger(LightningLoggerBase):
    """Lightning logger that pretty-prints selected metrics via Python logging.

    Messages are emitted with ``info`` on the ``'pytorch_lightning'`` logger
    returned by :attr:`experiment`.

    Args:
        log_keys: Regular-expression patterns, or a single pattern string.
            A metric is printed when any pattern matches anywhere in its name
            (``re.search`` semantics). When empty or ``None``, every metric is
            printed.
    """

    def __init__(self, log_keys=None):
        super().__init__()
        if log_keys is None:
            log_keys = ()
        elif isinstance(log_keys, str):
            # A bare string would otherwise be iterated character by character,
            # compiling each character as its own pattern.
            log_keys = (log_keys,)
        self.log_keys = [re.compile(k) for k in log_keys]
        # Callable rendering a dict as an indented, one-entry-per-line string.
        self.dict_printer = pprint.PrettyPrinter(indent=2, compact=False).pformat

    def match_log_keys(self, s):
        """Return True if the metric named *s* should be printed.

        With no patterns configured, every name matches.
        """
        if not self.log_keys:
            return True
        return any(r.search(s) for r in self.log_keys)

    @property
    def name(self):
        """Logger name reported to Lightning."""
        return 'console'

    @property
    def version(self):
        """Logger version reported to Lightning (fixed)."""
        return '0'

    @property
    @rank_zero_experiment
    def experiment(self):
        """The ``'pytorch_lightning'`` Python logger used for output."""
        return logging.getLogger('pytorch_lightning')

    @rank_zero_only
    def log_hyperparams(self, params, *args, **kwargs):
        """Ignore hyperparameters; this logger only prints metrics."""
        pass

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Pretty-print the metrics whose names match ``log_keys``.

        Args:
            metrics: Mapping of metric name to value.
            step: Step number shown in the header line.

        Nothing is logged when no metric name matches.
        """
        selected = {k: v for k, v in metrics.items() if self.match_log_keys(k)}
        if not selected:
            return
        # 'epoch' is not guaranteed to be present in *metrics* (e.g. when
        # log_metrics is called directly), so only include it when available
        # instead of raising KeyError.
        header = f"Step{step}"
        if 'epoch' in metrics:
            header = f"Epoch{metrics['epoch']} {header}"
        self.experiment.info(f"\n{header}\n{self.dict_printer(selected)}")