File size: 1,915 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
from chainer.training.extension import Extension
class TensorboardLogger(Extension):
    """A tensorboard logger extension

    Each call writes the trainer's current observation to the logger as
    scalars. The first call after the "main" iterator's epoch has advanced
    past the last recorded epoch additionally asks every configured reporter
    to log.
    """

    default_name = "espnet_tensorboard_logger"

    def __init__(
        self, logger, att_reporter=None, ctc_reporter=None, entries=None, epoch=0
    ):
        """Init the extension

        :param SummaryWriter logger: The logger to use
        :param PlotAttentionReporter att_reporter: The (optional) PlotAttentionReporter
        :param ctc_reporter: The (optional) CTC reporter; it must provide
            ``log_ctc_probs(logger, iteration)``
        :param entries: The entries to watch; ``None`` means every entry of
            the observation is logged
        :param int epoch: The starting epoch
        """
        self._entries = entries
        self._att_reporter = att_reporter
        self._ctc_reporter = ctc_reporter
        self._logger = logger
        self._epoch = epoch

    def __call__(self, trainer):
        """Updates the events file with the new values

        :param trainer: The trainer
        """
        for k, v in trainer.observation.items():
            if (self._entries is not None) and (k not in self._entries):
                continue
            if k is None or v is None:
                continue
            # Detect cupy objects by type name so cupy does not have to be
            # imported here; ``get()`` is called to obtain a value the logger
            # can consume.
            if "cupy" in str(type(v)):
                v = v.get()
            if "cupy" in str(type(k)):
                k = k.get()
            self._logger.add_scalar(k, v, trainer.updater.iteration)

        # Without any reporter there is nothing epoch-dependent to do, and the
        # updater's iterator is deliberately left untouched.
        if self._att_reporter is None and self._ctc_reporter is None:
            return

        current_epoch = trainer.updater.get_iterator("main").epoch
        if current_epoch <= self._epoch:
            return

        # Decide once whether a new epoch has started and only then notify the
        # reporters. Updating ``self._epoch`` between the two reporters would
        # make the second one believe the epoch had not changed, so it would
        # never log when both reporters are configured.
        self._epoch = current_epoch
        if self._att_reporter is not None:
            self._att_reporter.log_attentions(self._logger, trainer.updater.iteration)
        if self._ctc_reporter is not None:
            self._ctc_reporter.log_ctc_probs(self._logger, trainer.updater.iteration)
|