"""LM training in pytorch.""" |
|
|
|
import copy |
|
import json |
|
import logging |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn.parallel import data_parallel |
|
|
|
from chainer import Chain |
|
from chainer.dataset import convert |
|
from chainer import reporter |
|
from chainer import training |
|
from chainer.training import extensions |
|
|
|
from espnet.lm.lm_utils import count_tokens |
|
from espnet.lm.lm_utils import load_dataset |
|
from espnet.lm.lm_utils import MakeSymlinkToBestModel |
|
from espnet.lm.lm_utils import ParallelSentenceIterator |
|
from espnet.lm.lm_utils import read_tokens |
|
from espnet.nets.lm_interface import dynamic_import_lm |
|
from espnet.nets.lm_interface import LMInterface |
|
from espnet.optimizer.factory import dynamic_import_optimizer |
|
from espnet.scheduler.pytorch import PyTorchScheduler |
|
from espnet.scheduler.scheduler import dynamic_import_scheduler |
|
|
|
from espnet.asr.asr_utils import snapshot_object |
|
from espnet.asr.asr_utils import torch_load |
|
from espnet.asr.asr_utils import torch_resume |
|
from espnet.asr.asr_utils import torch_snapshot |
|
|
|
from espnet.utils.training.tensorboard_logger import TensorboardLogger |
|
from tensorboardX import SummaryWriter |
|
|
|
from espnet.utils.deterministic_utils import set_deterministic_pytorch |
|
from espnet.utils.training.evaluator import BaseEvaluator |
|
from espnet.utils.training.iterators import ShufflingEnabler |
|
from espnet.utils.training.train_utils import check_early_stop |
|
from espnet.utils.training.train_utils import set_early_stop |
|
|
|
|
|
def compute_perplexity(result): |
|
"""Compute and add the perplexity to the LogReport. |
|
|
|
:param dict result: The current observations |
|
""" |
    result["perplexity"] = np.exp(result["main/nll"] / result["main/count"])
    if "validation/main/nll" in result:
        result["val_perplexity"] = np.exp(
            result["validation/main/nll"] / result["validation/main/count"]
        )


class Reporter(Chain):
    """Dummy module to use chainer's trainer."""

    def report(self, loss):
        """Report nothing."""
        pass


def concat_examples(batch, device=None, padding=None):
    """Concat examples in minibatch.

    :param list batch: The list of examples to concatenate
    :param int device: The device to send to
    :param Tuple[int,int] padding: The padding values for inputs and targets
    :return: (inputs, targets)
    :rtype: (torch.Tensor, torch.Tensor)
    """
    x, t = convert.concat_examples(batch, padding=padding)
    x = torch.from_numpy(x)
    t = torch.from_numpy(t)
    if device is not None and device >= 0:
        x = x.cuda(device)
        t = t.cuda(device)
    return x, t


class BPTTUpdater(training.StandardUpdater):
    """An updater for a pytorch LM."""

    def __init__(
        self,
        train_iter,
        model,
        optimizer,
        schedulers,
        device,
        gradclip=None,
        use_apex=False,
        accum_grad=1,
    ):
        """Initialize class.

        Args:
            train_iter (chainer.dataset.Iterator): The train iterator
            model (LMInterface): The model to update
            optimizer (torch.optim.Optimizer): The optimizer for training
            schedulers (espnet.scheduler.scheduler.SchedulerInterface):
                The schedulers of `optimizer`
            device (list[int]): The GPU device ids ([-1] for CPU)
            gradclip (float): The gradient clipping value to use
            use_apex (bool): The flag to use Apex in backprop.
            accum_grad (int): The number of gradient accumulation steps.

        """
        super(BPTTUpdater, self).__init__(train_iter, optimizer)
        self.model = model
        self.device = device
        self.gradclip = gradclip
        self.use_apex = use_apex
        self.scheduler = PyTorchScheduler(schedulers, optimizer)
        self.accum_grad = accum_grad

    def update_core(self):
        """Update the model."""
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")

        self.model.zero_grad()
        accum = {"loss": 0.0, "nll": 0.0, "count": 0}
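        # Accumulate gradients over `accum_grad` minibatches before a single
        # optimizer step; each loss is divided by accum_grad below so the summed
        # gradients correspond to the average over the accumulated minibatches.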
        for _ in range(self.accum_grad):
            batch = train_iter.__next__()
            x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
            if self.device[0] == -1:
                loss, nll, count = self.model(x, t)
            else:
                loss, nll, count = data_parallel(self.model, (x, t), self.device)
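                # data_parallel replicated the model over self.device and split the
                # minibatch along dim 0, so loss/nll/count come back with one entry
                # per GPU and are reduced with mean()/sum() below.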

            loss = loss.mean() / self.accum_grad
            if self.use_apex:
                from apex import amp
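                # amp.scale_loss scales the loss to avoid gradient underflow when
                # training in reduced precision (apex opt levels O1-O3).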
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            accum["loss"] += float(loss)
            accum["nll"] += float(nll.sum())
            accum["count"] += int(count.sum())

        for k, v in accum.items():
            reporter.report({k: v}, optimizer.target)
        if self.gradclip is not None:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.gradclip)
        optimizer.step()
        self.scheduler.step(n_iter=self.iteration)


class LMEvaluator(BaseEvaluator):
    """A custom evaluator for a pytorch LM."""

    def __init__(self, val_iter, eval_model, reporter, device):
        """Initialize class.

        :param chainer.dataset.Iterator val_iter: The validation iterator
        :param LMInterface eval_model: The model to evaluate
        :param chainer.Reporter reporter: The observations reporter
        :param list[int] device: The GPU device ids to use ([-1] for CPU)

        """
        super(LMEvaluator, self).__init__(val_iter, reporter, device=-1)
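        # device=-1 keeps chainer's own converter on CPU; GPU execution is
        # handled manually in evaluate() via data_parallel instead.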
        self.model = eval_model
        self.device = device

    def evaluate(self):
        """Evaluate the model."""
        val_iter = self.get_iterator("main")
        loss = 0
        nll = 0
        count = 0
        self.model.eval()
        with torch.no_grad():
            for batch in copy.copy(val_iter):
                x, t = concat_examples(batch, device=self.device[0], padding=(0, -100))
                if self.device[0] == -1:
                    l, n, c = self.model(x, t)
                else:
                    l, n, c = data_parallel(self.model, (x, t), self.device)
                loss += float(l.sum())
                nll += float(n.sum())
                count += int(c.sum())
        self.model.train()

        observation = {}
        with reporter.report_scope(observation):
            reporter.report({"loss": loss}, self.model.reporter)
            reporter.report({"nll": nll}, self.model.reporter)
            reporter.report({"count": count}, self.model.reporter)
        return observation


def train(args):
    """Train with the given args.

    :param Namespace args: The program arguments
    """
    model_class = dynamic_import_lm(args.model_module, args.backend)
    assert issubclass(model_class, LMInterface), "model should implement LMInterface"

    logging.info("torch version = " + torch.__version__)

    set_deterministic_pytorch(args)

    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    unk = args.char_list_dict["<unk>"]
    eos = args.char_list_dict["<eos>"]

    val, n_val_tokens, n_val_oovs = load_dataset(
        args.valid_label, args.char_list_dict, args.dump_hdf5_path
    )
    train, n_train_tokens, n_train_oovs = load_dataset(
        args.train_label, args.char_list_dict, args.dump_hdf5_path
    )
    logging.info("#vocab = " + str(args.n_vocab))
    logging.info("#sentences in the training data = " + str(len(train)))
    logging.info("#tokens in the training data = " + str(n_train_tokens))
    logging.info(
        "oov rate in the training data = %.2f %%"
        % (n_train_oovs / n_train_tokens * 100)
    )
    logging.info("#sentences in the validation data = " + str(len(val)))
    logging.info("#tokens in the validation data = " + str(n_val_tokens))
    logging.info(
        "oov rate in the validation data = %.2f %%" % (n_val_oovs / n_val_tokens * 100)
    )

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
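    # SortaGrad: args.sortagrad == -1 keeps length-sorted batches for the whole
    # run, > 0 for that many epochs; the ShufflingEnabler extension registered
    # below turns shuffling back on once its trigger fires.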

    batch_size = args.batchsize * max(args.ngpu, 1)
    if batch_size * args.accum_grad > args.batchsize:
        logging.info(
            "batch size is automatically increased "
            f"({args.batchsize} -> {batch_size * args.accum_grad})"
        )
    train_iter = ParallelSentenceIterator(
        train,
        batch_size,
        max_length=args.maxlen,
        sos=eos,
        eos=eos,
        shuffle=not use_sortagrad,
    )
    val_iter = ParallelSentenceIterator(
        val, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
    )
    epoch_iters = int(len(train_iter.batch_indices) / args.accum_grad)
    logging.info("#iterations per epoch = %d" % epoch_iters)
    logging.info("#total iterations = " + str(args.epoch * epoch_iters))
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model_class(args.n_vocab, args).to(dtype=dtype)
    if args.ngpu > 0:
        model.to("cuda")
        gpu_id = list(range(args.ngpu))
    else:
        gpu_id = [-1]

    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(vars(args), indent=4, ensure_ascii=False, sort_keys=True).encode(
                "utf_8"
            )
        )

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )

    opt_class = dynamic_import_optimizer(args.opt, args.backend)
    optimizer = opt_class.from_args(model.parameters(), args)
    if args.schedulers is None:
        schedulers = []
    else:
        schedulers = [dynamic_import_scheduler(v)(k, args) for k, v in args.schedulers]

    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux"
            )
            raise e
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    reporter = Reporter()
    setattr(model, "reporter", reporter)
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    updater = BPTTUpdater(
        train_iter,
        model,
        optimizer,
        schedulers,
        gpu_id,
        gradclip=args.gradclip,
        use_apex=use_apex,
        accum_grad=args.accum_grad,
    )
    trainer = training.Trainer(updater, (args.epoch, "epoch"), out=args.outdir)
    trainer.extend(LMEvaluator(val_iter, model, reporter, device=gpu_id))
    trainer.extend(
        extensions.LogReport(
            postprocess=compute_perplexity,
            trigger=(args.report_interval_iters, "iteration"),
        )
    )
    trainer.extend(
        extensions.PrintReport(
            [
                "epoch",
                "iteration",
                "main/loss",
                "perplexity",
                "val_perplexity",
                "elapsed_time",
            ]
        ),
        trigger=(args.report_interval_iters, "iteration"),
    )
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))

    trainer.extend(torch_snapshot(filename="snapshot.ep.{.updater.epoch}"))
    trainer.extend(snapshot_object(model, "rnnlm.model.{.updater.epoch}"))

    trainer.extend(MakeSymlinkToBestModel("validation/main/loss", "rnnlm.model"))
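    # Keeps a "rnnlm.model.best" symlink pointing at the per-epoch snapshot with
    # the lowest validation/main/loss; the test section below loads it.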

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epoch, "epoch"),
        )
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    set_early_stop(trainer, args, is_lm=True)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(
            TensorboardLogger(writer), trigger=(args.report_interval_iters, "iteration")
        )

    trainer.run()
    check_early_stop(trainer, args.epoch)

    if args.test_label:
        logging.info("test the best model")
        torch_load(args.outdir + "/rnnlm.model.best", model)
        test = read_tokens(args.test_label, args.char_list_dict)
        n_test_tokens, n_test_oovs = count_tokens(test, unk)
        logging.info("#sentences in the test data = " + str(len(test)))
        logging.info("#tokens in the test data = " + str(n_test_tokens))
        logging.info(
            "oov rate in the test data = %.2f %%" % (n_test_oovs / n_test_tokens * 100)
        )
        test_iter = ParallelSentenceIterator(
            test, batch_size, max_length=args.maxlen, sos=eos, eos=eos, repeat=False
        )
        evaluator = LMEvaluator(test_iter, model, reporter, device=gpu_id)
        result = evaluator()
        compute_perplexity(result)
        logging.info(f"test perplexity: {result['perplexity']}")