File size: 1,915 Bytes
ad16788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from chainer.training.extension import Extension


class TensorboardLogger(Extension):
    """A tensorboard logger extension.

    On every call, writes each (optionally filtered) entry of
    ``trainer.observation`` to the logger as a scalar. Whenever the epoch of
    the ``"main"`` iterator advances past the last seen epoch, the optional
    attention and CTC reporters are asked to log their outputs as well.
    """

    default_name = "espnet_tensorboard_logger"

    def __init__(
        self, logger, att_reporter=None, ctc_reporter=None, entries=None, epoch=0
    ):
        """Init the extension

        :param SummaryWriter logger: The logger to use
        :param PlotAttentionReporter att_reporter: The (optional) PlotAttentionReporter
        :param ctc_reporter: The (optional) CTC reporter; its ``log_ctc_probs``
            method is called once per new epoch
        :param entries: The entries to watch (``None`` means log every observation)
        :param int epoch: The starting epoch
        """
        self._entries = entries
        self._att_reporter = att_reporter
        self._ctc_reporter = ctc_reporter
        self._logger = logger
        self._epoch = epoch

    def __call__(self, trainer):
        """Updates the events file with the new values

        :param trainer: The trainer
        """
        iteration = trainer.updater.iteration

        for k, v in trainer.observation.items():
            if self._entries is not None and k not in self._entries:
                continue
            if k is None or v is None:
                continue
            # cupy arrays are copied to host memory via ``.get()`` before logging.
            if "cupy" in str(type(v)):
                v = v.get()
            # NOTE(review): observation keys are presumably plain strings; this
            # conversion is kept only for backward compatibility.
            if "cupy" in str(type(k)):
                k = k.get()
            self._logger.add_scalar(k, v, iteration)

        # Without reporters there is nothing epoch-based to do (and the
        # iterator is not queried, as before).
        if self._att_reporter is None and self._ctc_reporter is None:
            return

        epoch = trainer.updater.get_iterator("main").epoch
        if epoch <= self._epoch:
            return
        # Advance the shared epoch marker once, *after* deciding to log, so that
        # both reporters fire on the same new epoch (previously the attention
        # reporter's update prevented the CTC reporter from ever logging).
        self._epoch = epoch
        if self._att_reporter is not None:
            self._att_reporter.log_attentions(self._logger, iteration)
        if self._ctc_reporter is not None:
            self._ctc_reporter.log_ctc_probs(self._logger, iteration)