# Source: ReactSeq / onmt/utils/report_manager.py
# Uploaded by Oopstom ("Upload 313 files", commit c668e80, verified)
""" Report manager utility """
import time
from datetime import datetime
import onmt
from onmt.utils.logging import logger
def build_report_manager(opt, gpu_rank):
    """Build the `ReportMgr` used to report training progress.

    A TensorBoard ``SummaryWriter`` is attached only when ``opt.tensorboard``
    is truthy and ``gpu_rank <= 0``; in every other case the manager is
    created without a writer.

    Args:
        opt: options object. Reads ``tensorboard``, ``tensorboard_log_dir``
            and ``report_every``. When a writer is created and the attribute
            is missing, ``tensorboard_log_dir_dated`` is set on it as a side
            effect.
        gpu_rank(int): rank of the calling process; a writer is only created
            for ``gpu_rank <= 0`` (presumably so a single process writes the
            event files -- confirm against the caller).

    Returns:
        ReportMgr: a manager that has not been started (``start_time=-1``).
    """
    writer = None
    if opt.tensorboard and gpu_rank <= 0:
        # Imported lazily: TensorBoard support is only needed when enabled.
        from torch.utils.tensorboard import SummaryWriter

        # Reuse an already-computed dated directory if one is present on
        # `opt`; otherwise derive it from the base log dir and current time.
        if not hasattr(opt, "tensorboard_log_dir_dated"):
            opt.tensorboard_log_dir_dated = (
                opt.tensorboard_log_dir + datetime.now().strftime("/%b-%d_%H-%M-%S")
            )

        writer = SummaryWriter(opt.tensorboard_log_dir_dated, comment="Unmt")

    report_mgr = ReportMgr(opt.report_every, start_time=-1, tensorboard_writer=writer)
    return report_mgr
class ReportMgrBase(object):
    """
    Report Manager Base class
    Inherited classes should override:
        * `_report_training`
        * `_report_step`
    """

    def __init__(self, report_every, start_time=-1.0):
        """
        Args:
            report_every(int): Report status every this many steps
                (`report_training` reports when `step % report_every == 0`,
                so the value must be non-zero).
            start_time(float): manually set report start time. Negative values
                mean that you will need to set it later or use `start()`
        """
        self.report_every = report_every
        self.start_time = start_time

    def start(self):
        """Set the report start time to the current wall-clock time."""
        self.start_time = time.time()

    def log(self, *args, **kwargs):
        """Forward all arguments to `logger.info`."""
        logger.info(*args, **kwargs)

    def report_training(
        self, step, num_steps, learning_rate, patience, report_stats, multigpu=False
    ):
        """
        This is the user-defined batch-level training progress
        report function.

        Args:
            step(int): current step count.
            num_steps(int): total number of batches.
            learning_rate(float): current learning rate.
            patience: current patience value, forwarded unchanged to
                `_report_training`.
            report_stats(Statistics): old Statistics instance.
            multigpu(bool): when true, `report_stats` is first passed through
                `onmt.utils.Statistics.all_gather_stats` before reporting.

        Returns:
            report_stats(Statistics): a fresh `onmt.utils.Statistics()` when a
            report was emitted at this step, otherwise the `report_stats`
            argument unchanged.

        Raises:
            ValueError: if the manager has not been started
                (`start_time` is negative).
        """
        if self.start_time < 0:
            raise ValueError(
                "ReportMgr needs to be started "
                "(set 'start_time' or use 'start()')"
            )

        # Not a reporting step: keep accumulating into the same statistics.
        if step % self.report_every != 0:
            return report_stats

        if multigpu:
            report_stats = onmt.utils.Statistics.all_gather_stats(report_stats)
        self._report_training(step, num_steps, learning_rate, patience, report_stats)
        # The return value of `_report_training` is not used; the caller
        # always receives a new, empty Statistics after a report.
        return onmt.utils.Statistics()

    def _report_training(self, *args, **kwargs):
        """To be overridden"""
        raise NotImplementedError()

    def report_step(self, lr, patience, step, train_stats=None, valid_stats=None):
        """
        Report stats of a step

        Args:
            lr(float): current learning rate
            patience(int): current patience
            step(int): current step
            train_stats(Statistics): training stats
            valid_stats(Statistics): validation stats
        """
        self._report_step(
            lr, patience, step, valid_stats=valid_stats, train_stats=train_stats
        )

    def _report_step(self, *args, **kwargs):
        """To be overridden"""
        raise NotImplementedError()
class ReportMgr(ReportMgrBase):
    """Report manager that logs statistics and optionally writes TensorBoard."""

    def __init__(self, report_every, start_time=-1.0, tensorboard_writer=None):
        """
        A report manager that writes statistics on standard output as well as
        (optionally) TensorBoard

        Args:
            report_every(int): Report status every this many steps
            start_time(float): report start time, passed to the base class
            tensorboard_writer(:obj:`tensorboard.SummaryWriter`):
                The TensorBoard Summary writer to use or None
        """
        super(ReportMgr, self).__init__(report_every, start_time)
        self.tensorboard_writer = tensorboard_writer

    def maybe_log_tensorboard(self, stats, prefix, learning_rate, patience, step):
        """Call `stats.log_tensorboard` if a TensorBoard writer is configured.

        Does nothing when `self.tensorboard_writer` is None.
        """
        if self.tensorboard_writer is not None:
            stats.log_tensorboard(
                prefix, self.tensorboard_writer, learning_rate, patience, step
            )

    def _report_training(self, step, num_steps, learning_rate, patience, report_stats):
        """
        See base class method `ReportMgrBase.report_training`.

        Outputs `report_stats` via its `output` method, logs it to
        TensorBoard under the "progress" prefix when a writer is set, and
        returns a new `onmt.utils.Statistics()`.
        """
        report_stats.output(step, num_steps, learning_rate, self.start_time)

        self.maybe_log_tensorboard(
            report_stats, "progress", learning_rate, patience, step
        )
        return onmt.utils.Statistics()

    def _report_step(self, lr, patience, step, valid_stats=None, train_stats=None):
        """
        See base class method `ReportMgrBase.report_step`.

        Each of `train_stats` / `valid_stats` is reported only when it is
        not None.
        """
        if train_stats is not None:
            self.log("Train perplexity: %g" % train_stats.ppl())
            self.log("Train accuracy: %g" % train_stats.accuracy())
            self.log("Sentences processed: %g" % train_stats.n_sents)
            # Per-batch averages: source words / words / sentences.
            self.log(
                "Average bsz: %4.0f/%4.0f/%2.0f"
                % (
                    train_stats.n_src_words / train_stats.n_batchs,
                    train_stats.n_words / train_stats.n_batchs,
                    train_stats.n_sents / train_stats.n_batchs,
                )
            )
            self.maybe_log_tensorboard(train_stats, "train", lr, patience, step)

        if valid_stats is not None:
            self.log("Validation perplexity: %g" % valid_stats.ppl())
            self.log("Validation accuracy: %g" % valid_stats.accuracy())
            self.maybe_log_tensorboard(valid_stats, "valid", lr, patience, step)