# This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/utils/base_trainer.py from __future__ import division import logging from utils import CheckpointSaver from tensorboardX import SummaryWriter import torch from tqdm import tqdm tqdm.monitor_interval = 0 logger = logging.getLogger(__name__) class BaseTrainer(object): """Base class for Trainer objects. Takes care of checkpointing/logging/resuming training. """ def __init__(self, options): self.options = options if options.multiprocessing_distributed: self.device = torch.device('cuda', options.gpu) else: self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') # override this function to define your model, optimizers etc. self.saver = CheckpointSaver(save_dir=options.checkpoint_dir, overwrite=options.overwrite) if options.rank == 0: self.summary_writer = SummaryWriter(self.options.summary_dir) self.init_fn() self.checkpoint = None if options.resume and self.saver.exists_checkpoint(): self.checkpoint = self.saver.load_checkpoint( self.models_dict, self.optimizers_dict) if self.checkpoint is None: self.epoch_count = 0 self.step_count = 0 else: self.epoch_count = self.checkpoint['epoch'] self.step_count = self.checkpoint['total_step_count'] if self.checkpoint is not None: self.checkpoint_batch_idx = self.checkpoint['batch_idx'] else: self.checkpoint_batch_idx = 0 self.best_performance = float('inf') def load_pretrained(self, checkpoint_file=None): """Load a pretrained checkpoint. This is different from resuming training using --resume. """ if checkpoint_file is not None: checkpoint = torch.load(checkpoint_file) for model in self.models_dict: if model in checkpoint: self.models_dict[model].load_state_dict(checkpoint[model], strict=True) print(f'Checkpoint {model} loaded') def move_dict_to_device(self, dict, device, tensor2float=False): for k, v in dict.items(): if isinstance(v, torch.Tensor): if tensor2float: dict[k] = v.float().to(device) else: dict[k] = v.to(device) # The following methods (with the possible exception of test) have to be implemented in the derived classes def train(self, epoch): raise NotImplementedError('You need to provide an train method') def init_fn(self): raise NotImplementedError('You need to provide an _init_fn method') def train_step(self, input_batch): raise NotImplementedError('You need to provide a _train_step method') def train_summaries(self, input_batch): raise NotImplementedError( 'You need to provide a _train_summaries method') def visualize(self, input_batch): raise NotImplementedError('You need to provide a visualize method') def validate(self): pass def test(self): pass def evaluate(self): pass def fit(self): # Run training for num_epochs epochs for epoch in tqdm(range(self.epoch_count, self.options.num_epochs), total=self.options.num_epochs, initial=self.epoch_count): self.epoch_count = epoch self.train(epoch) return