#!/usr/bin/env python
"""Training on a single process."""
import torch

from onmt.inputters.inputter import IterOnDevice
from onmt.model_builder import build_model
from onmt.utils.optimizers import Optimizer
from onmt.utils.misc import set_random_seed
from onmt.trainer import build_trainer
from onmt.models import build_model_saver
from onmt.utils.logging import init_logger, logger
from onmt.utils.parse import ArgumentParser
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter


def configure_process(opt, device_id):
    """Bind this process to `device_id` (if it is a GPU) and seed RNGs."""
    if device_id >= 0:
        torch.cuda.set_device(device_id)
    set_random_seed(opt.seed, device_id >= 0)


def _get_model_opts(opt, checkpoint=None):
    """Get `model_opt` to build the model, loading from `checkpoint` if given."""
    if checkpoint is not None:
        model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        if (opt.tensorboard_log_dir == model_opt.tensorboard_log_dir and
                hasattr(model_opt, 'tensorboard_log_dir_dated')):
            # ensure tensorboard output is written in the directory
            # of previous checkpoints
            opt.tensorboard_log_dir_dated = model_opt.tensorboard_log_dir_dated
        # Override the checkpoint's update_vocab, as it defaults to false
        model_opt.update_vocab = opt.update_vocab
    else:
        model_opt = opt
    return model_opt


def _build_valid_iter(opt, fields, transforms_cls):
    """Build iterator used for validation."""
    valid_iter = build_dynamic_dataset_iter(
        fields, transforms_cls, opt, is_train=False)
    return valid_iter


def _build_train_iter(opt, fields, transforms_cls, stride=1, offset=0):
    """Build training iterator."""
    train_iter = build_dynamic_dataset_iter(
        fields, transforms_cls, opt, is_train=True,
        stride=stride, offset=offset)
    return train_iter


def main(opt, fields, transforms_cls, checkpoint, device_id,
         batch_queue=None, semaphore=None):
    """Start training on `device_id`."""
    # NOTE: It's important that ``opt`` has been validated and updated
    # at this point.
    configure_process(opt, device_id)
    init_logger(opt.log_file)

    model_opt = _get_model_opts(opt, checkpoint=checkpoint)

    # Build model.
    model = build_model(model_opt, opt, fields, checkpoint)
    model.count_parameters(log=logger.info)

    # Build optimizer.
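    # `from_opt` also restores optimizer state from `checkpoint` when
    # resuming training (subject to the `reset_optim` option).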
    optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint)

    # Build model saver.
    model_saver = build_model_saver(model_opt, opt, model, fields, optim)

    trainer = build_trainer(
        opt, device_id, model, fields, optim, model_saver=model_saver)

    if batch_queue is None:
        _train_iter = _build_train_iter(opt, fields, transforms_cls)
        train_iter = IterOnDevice(_train_iter, device_id)
    else:
        assert semaphore is not None, \
            "Using batch_queue requires semaphore as well"

        def _train_iter():
            while True:
                batch = batch_queue.get()
                semaphore.release()
                # Move batch to specified device
                IterOnDevice.batch_to_device(batch, device_id)
                yield batch

        train_iter = _train_iter()

    valid_iter = _build_valid_iter(opt, fields, transforms_cls)
    if valid_iter is not None:
        valid_iter = IterOnDevice(valid_iter, device_id)

    if len(opt.gpu_ranks):
        logger.info('Starting training on GPU: %s' % opt.gpu_ranks)
    else:
        logger.info('Starting training on CPU, could be very slow')
    train_steps = opt.train_steps
    if opt.single_pass and train_steps > 0:
        logger.warning("Option single_pass is enabled, ignoring train_steps.")
        train_steps = 0

    trainer.train(
        train_iter,
        train_steps,
        save_checkpoint_steps=opt.save_checkpoint_steps,
        valid_iter=valid_iter,
        valid_steps=opt.valid_steps)

    if trainer.report_manager.tensorboard_writer is not None:
        trainer.report_manager.tensorboard_writer.close()
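
# ---------------------------------------------------------------------------
# Illustration only: NOT part of the original module. When `main` receives
# `batch_queue`/`semaphore`, it acts as the consumer half of a bounded
# producer/consumer pipeline: the producer acquires the semaphore before
# putting each batch, and `main` releases it right after `get()`, so the
# semaphore's initial value caps the number of in-flight batches. A
# hypothetical minimal producer (the real one lives in OpenNMT-py's
# multi-GPU train launcher) could look like the sketch below, assuming
# `iter_fct` yields ready-to-train batches.
def _example_batch_producer(iter_fct, batch_queue, semaphore):
    """Hypothetical sketch pairing with ``main``'s queue-consumer loop."""
    for batch in iter_fct():
        semaphore.acquire()     # wait until the consumer has capacity
        batch_queue.put(batch)  # consumer: batch_queue.get(); semaphore.release()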