Spaces:
Runtime error
Runtime error
import re | |
import pprint | |
import logging | |
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment | |
from pytorch_lightning.utilities.rank_zero import rank_zero_only | |
class ConsoleLogger(LightningLoggerBase):
    """Lightning logger that prints selected metrics through Python ``logging``.

    Output goes to the ``logging.Logger`` named ``'pytorch_lightning'`` at INFO
    level. Which metrics are printed is controlled by ``log_keys``: a metric is
    printed when at least one pattern matches anywhere in its name
    (``re.search``). With no patterns configured, every metric is printed.
    Hyperparameters are ignored.

    Args:
        log_keys: Iterable of regular-expression strings selecting metric
            names to print. ``None`` (the default) or an empty iterable means
            "print every metric".
    """

    def __init__(self, log_keys=None):
        super().__init__()
        # ``None`` default avoids a shared mutable ``[]``; passing a list still works.
        self.log_keys = [re.compile(k) for k in (log_keys or ())]
        # Multi-line pretty-printer used to render the selected metrics dict.
        self.dict_printer = pprint.PrettyPrinter(indent=2, compact=False).pformat

    def match_log_keys(self, s):
        """Return True if the metric name ``s`` should be printed.

        Always True when no patterns were configured; otherwise True iff any
        compiled pattern matches somewhere in ``s``.
        """
        if not self.log_keys:
            return True
        return any(pattern.search(s) for pattern in self.log_keys)

    @property
    def name(self):
        """Name of this logger."""
        return 'console'

    @property
    def version(self):
        """Version of this logger (constant)."""
        return '0'

    @property
    @rank_zero_experiment
    def experiment(self):
        """The ``logging.Logger`` named ``'pytorch_lightning'`` that receives output."""
        # Must be a property: ``log_metrics`` uses ``self.experiment.info(...)``
        # as an attribute, which fails if this is a plain method.
        return logging.getLogger('pytorch_lightning')

    @rank_zero_only
    def log_hyperparams(self, params):
        """Ignore hyperparameters; this logger only prints metrics."""

    @rank_zero_only
    def log_metrics(self, metrics, step):
        """Print the metrics whose names pass ``match_log_keys``.

        Args:
            metrics: Mapping of metric name to value.
            step: Step number included in the printed header.

        Nothing is printed when no metric name passes the filter.
        """
        selected = {k: v for k, v in metrics.items() if self.match_log_keys(k)}
        if not selected:
            return
        # 'epoch' is not guaranteed to be among the logged metrics; avoid a KeyError.
        epoch = metrics.get('epoch', '?')
        self.experiment.info(f"\nEpoch{epoch} Step{step}\n{self.dict_printer(selected)}")