# python3.7
"""Contains the base class for runners.

This runner can be used for both training and inference with multiple threads.
"""

import os
import json
from copy import deepcopy

import torch
import torch.distributed as dist

from datasets import BaseDataset
from datasets import IterDataLoader
from models import build_model
from . import controllers
from . import losses
from . import misc
from .optimizer import build_optimizers
from .running_stats import RunningStats


def _strip_state_dict_prefix(state_dict, prefix='module.'):
    """Removes the name prefix in checkpoint.

    Basically, when the model is deployed in parallel, the prefix `module.`
    will be added to the saved checkpoint. This function is used to remove the
    prefix, which is friendly to checkpoint loading.

    Args:
        state_dict: The state dict where the variable names are processed.
        prefix: The prefix to remove. (default: `module.`)
    """
    if not all(key.startswith(prefix) for key in state_dict.keys()):
        return state_dict

    stripped_state_dict = dict()
    for key in state_dict:
        stripped_state_dict[key.replace(prefix, '')] = state_dict[key]
    return stripped_state_dict


class BaseRunner(object):
    """Defines the base runner class."""

    def __init__(self, config, logger):
        self._name = self.__class__.__name__
        self._config = deepcopy(config)
        self.logger = logger
        self.work_dir = self.config.work_dir
        os.makedirs(self.work_dir, exist_ok=True)

        self.logger.info('Running Configuration:')
        config_str = json.dumps(self.config, indent=4).replace('"', '\'')
        self.logger.print(config_str + '\n')
        with open(os.path.join(self.work_dir, 'config.json'), 'w') as f:
            json.dump(self.config, f, indent=4)

        self._rank = dist.get_rank()
        self._world_size = dist.get_world_size()

        self.batch_size = self.config.batch_size
        self.val_batch_size = self.config.get('val_batch_size',
                                              self.batch_size)
        self._iter = 0
        self._start_iter = 0
        self.seen_img = 0
        self.total_iters = self.config.get('total_iters', 0)
        if self.total_iters == 0 and self.config.get('total_img', 0) > 0:
            total_image = self.config.get('total_img')
            total_batch = self.world_size * self.batch_size
            self.total_iters = int(total_image / total_batch + 0.5)

        self.mode = None
        self.train_loader = None
        self.val_loader = None

        self.models = dict()
        self.optimizers = dict()
        self.lr_schedulers = dict()
        self.controllers = []
        self.loss = None

        self.running_stats = RunningStats()
        self.start_time = 0
        self.end_time = 0
        self.timer = controllers.Timer()
        self.timer.start(self)

        self.build_models()
        self.build_controllers()

    def finish(self):
        """Finishes runner by ending controllers and timer."""
        for controller in self.controllers:
            controller.end(self)
        self.timer.end(self)
        self.logger.info(f'Finish runner in '
                         f'{misc.format_time(self.end_time - self.start_time)}')

    @property
    def name(self):
        """Returns the name of the runner."""
        return self._name

    @property
    def config(self):
        """Returns the configuration of the runner."""
        return self._config

    @property
    def rank(self):
        """Returns the rank of the current runner."""
        return self._rank

    @property
    def world_size(self):
        """Returns the world size."""
        return self._world_size

    @property
    def iter(self):
        """Returns the current iteration."""
        return self._iter

    @property
    def start_iter(self):
        """Returns the start iteration."""
        return self._start_iter

    def convert_epoch_to_iter(self, epoch):
        """Converts number of epochs to number of iterations."""
        return int(epoch * len(self.train_loader) + 0.5)

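    # Note on the iteration bookkeeping above: `total_iters` can be given
    # directly, derived from `total_img` in `__init__()` (e.g., 25_000_000
    # images at world_size 8 x batch_size 4 give int(25e6 / 32 + 0.5), i.e.,
    # 781250 iterations), or derived from `total_epochs` through
    # `convert_epoch_to_iter()` once the train loader has been built.
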
    def build_dataset(self, mode):
        """Builds train/val dataset."""
        if not hasattr(self.config, 'data'):
            return
        assert isinstance(mode, str)
        mode = mode.lower()
        self.logger.info(f'Building `{mode}` dataset ...')
        if mode not in ['train', 'val']:
            raise ValueError(f'Invalid dataset mode `{mode}`!')
        dataset = BaseDataset(**self.config.data[mode])
        if mode == 'train':
            self.train_loader = IterDataLoader(
                dataset=dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.config.data.get('num_workers', 2),
                current_iter=self.iter,
                repeat=self.config.data.get('repeat', 1))
        elif mode == 'val':
            self.val_loader = IterDataLoader(
                dataset=dataset,
                batch_size=self.val_batch_size,
                shuffle=False,
                num_workers=self.config.data.get('num_workers', 2),
                current_iter=0,
                repeat=1)
        else:
            raise NotImplementedError(f'Not implemented dataset mode `{mode}`!')
        self.logger.info(f'Finish building `{mode}` dataset.')

    def build_models(self):
        """Builds models, optimizers, and learning rate schedulers."""
        self.logger.info(f'Building models ...')
        lr_config = dict()
        opt_config = dict()
        for module, module_config in self.config.modules.items():
            model_config = module_config['model']
            self.models[module] = build_model(module=module, **model_config)
            self.models[module].cuda()
            opt_config[module] = module_config.get('opt', None)
            lr_config[module] = module_config.get('lr', None)
        build_optimizers(opt_config, self)
        self.controllers.append(controllers.LRScheduler(lr_config))
        self.logger.info(f'Finish building models.')

        model_info = 'Model structures:\n'
        model_info += '==============================================\n'
        for module in self.models:
            model_info += f'{module}\n'
            model_info += '----------------------------------------------\n'
            model_info += str(self.models[module])
            model_info += '\n'
            model_info += '==============================================\n'
        self.logger.info(model_info)

    def distribute(self):
        """Wraps the models with `torch.nn.parallel.DistributedDataParallel`."""
        for name in self.models:
            self.models[name] = torch.nn.parallel.DistributedDataParallel(
                module=self.models[name],
                device_ids=[torch.cuda.current_device()],
                broadcast_buffers=False,
                find_unused_parameters=True)

    @staticmethod
    def get_module(model):
        """Returns the underlying module of a (possibly distributed) model."""
        if hasattr(model, 'module'):
            return model.module
        return model

    def build_controllers(self):
        """Builds additional controllers besides LRScheduler."""
        if not hasattr(self.config, 'controllers'):
            return
        self.logger.info(f'Building controllers ...')
        for key, ctrl_config in self.config.controllers.items():
            self.controllers.append(getattr(controllers, key)(ctrl_config))
        self.controllers.sort(key=lambda x: x.priority)
        for controller in self.controllers:
            controller.start(self)
        self.logger.info(f'Finish building controllers.')

    def build_loss(self):
        """Builds loss functions."""
        if not hasattr(self.config, 'loss'):
            return
        self.logger.info(f'Building loss function ...')
        loss_config = deepcopy(self.config.loss)
        loss_type = loss_config.pop('type')
        self.loss = getattr(losses, loss_type)(self, **loss_config)
        self.logger.info(f'Finish building loss function.')

    def pre_execute_controllers(self):
        """Pre-executes all controllers in order of priority."""
        for controller in self.controllers:
            controller.pre_execute(self)

    def post_execute_controllers(self):
        """Post-executes all controllers in order of priority."""
        for controller in self.controllers:
            controller.post_execute(self)

    def cpu(self):
        """Puts models to CPU."""
        for name in self.models:
            self.models[name].cpu()

    def cuda(self):
        """Puts models to CUDA."""
        for name in self.models:
            self.models[name].cuda()

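    # The helpers below toggle gradients and train/eval mode per module or for
    # all modules at once. In GAN-style runners these are typically used to
    # freeze one module (e.g., the discriminator) while another one is being
    # optimized; the exact usage is left to the derived class's `train_step()`.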
for a particular model.""" for param in self.models[name].parameters(): param.requires_grad = requires_grad def set_models_requires_grad(self, requires_grad): """Sets the `requires_grad` configuration for all models.""" for name in self.models: self.set_model_requires_grad(name, requires_grad) def set_model_mode(self, name, mode): """Sets the `train/val` mode for a particular model.""" if isinstance(mode, str): mode = mode.lower() if mode == 'train' or mode is True: self.models[name].train() elif mode in ['val', 'test', 'eval'] or mode is False: self.models[name].eval() else: raise ValueError(f'Invalid model mode `{mode}`!') def set_mode(self, mode): """Sets the `train/val` mode for all models.""" self.mode = mode for name in self.models: self.set_model_mode(name, mode) def train_step(self, data, **train_kwargs): """Executes one training step.""" raise NotImplementedError('Should be implemented in derived class.') def train(self, **train_kwargs): """Training function.""" self.set_mode('train') self.distribute() self.build_dataset('train') self.build_loss() self.logger.print() self.logger.info(f'Start training.') if self.total_iters == 0: total_epochs = self.config.get('total_epochs', 0) self.total_iters = self.convert_epoch_to_iter(total_epochs) assert self.total_iters > 0 while self.iter < self.total_iters: self._iter += 1 self.pre_execute_controllers() data_batch = next(self.train_loader) self.timer.pre_execute(self) for key in data_batch: assert data_batch[key].shape[0] == self.batch_size data_batch[key] = data_batch[key].cuda( torch.cuda.current_device(), non_blocking=True) self.train_step(data_batch, **train_kwargs) self.seen_img += self.batch_size * self.world_size self.timer.post_execute(self) self.post_execute_controllers() self.finish() def val(self, **val_kwargs): """Validation function.""" raise NotImplementedError('Should be implemented in derived class.') def save(self, filepath, running_metadata=True, learning_rate=True, optimizer=True, running_stats=False): """Saves the current running status. Args: filepath: File path to save the checkpoint. running_metadata: Whether to save the running metadata, such as batch size, current iteration, etc. (default: True) learning_rate: Whether to save the learning rate. (default: True) optimizer: Whether to save the optimizer. (default: True) running_stats: Whether to save the running stats. (default: False) """ checkpoint = dict() # Models. checkpoint['models'] = dict() for name, model in self.models.items(): checkpoint['models'][name] = self.get_module(model).state_dict() # Running metadata. if running_metadata: checkpoint['running_metadata'] = { 'iter': self.iter, 'seen_img': self.seen_img, } # Optimizers. if optimizer: checkpoint['optimizers'] = dict() for opt_name, opt in self.optimizers.items(): checkpoint['optimizers'][opt_name] = opt.state_dict() # Learning rates. if learning_rate: checkpoint['learning_rates'] = dict() for lr_name, lr in self.lr_schedulers.items(): checkpoint['learning_rates'][lr_name] = lr.state_dict() # Running stats. # TODO: Test saving and loading running stats. if running_stats: checkpoint['running_stats'] = self.running_stats # Save checkpoint. os.makedirs(os.path.dirname(filepath), exist_ok=True) torch.save(checkpoint, filepath) self.logger.info(f'Successfully saved checkpoint to `{filepath}`.') def load(self, filepath, running_metadata=True, learning_rate=True, optimizer=True, running_stats=False, map_location='cpu'): """Loads previous running status. Args: filepath: File path to load the checkpoint. 
    def load(self,
             filepath,
             running_metadata=True,
             learning_rate=True,
             optimizer=True,
             running_stats=False,
             map_location='cpu'):
        """Loads previous running status.

        Args:
            filepath: File path to load the checkpoint.
            running_metadata: Whether to load the running metadata, such as
                batch size, current iteration, etc. (default: True)
            learning_rate: Whether to load the learning rate. (default: True)
            optimizer: Whether to load the optimizer. (default: True)
            running_stats: Whether to load the running stats. (default: False)
            map_location: Map location used for model loading. (default: `cpu`)
        """
        self.logger.info(f'Resuming from checkpoint `{filepath}` ...')
        if not os.path.isfile(filepath):
            raise IOError(f'Checkpoint `{filepath}` does not exist!')
        map_location = map_location.lower()
        assert map_location in ['cpu', 'gpu']
        if map_location == 'gpu':
            device = torch.cuda.current_device()
            map_location = lambda storage, location: storage.cuda(device)
        checkpoint = torch.load(filepath, map_location=map_location)
        # Load models.
        if 'models' not in checkpoint:
            checkpoint = {'models': checkpoint}
        for model_name, model in self.models.items():
            if model_name not in checkpoint['models']:
                self.logger.warning(f'Model `{model_name}` is not included in '
                                    f'the checkpoint, and hence will NOT be '
                                    f'loaded!')
                continue
            state_dict = _strip_state_dict_prefix(
                checkpoint['models'][model_name])
            model.load_state_dict(state_dict)
            self.logger.info(f' Successfully loaded model `{model_name}`.')
        # Load running metadata.
        if running_metadata:
            if 'running_metadata' not in checkpoint:
                self.logger.warning(f'Running metadata is not included in the '
                                    f'checkpoint, and hence will NOT be '
                                    f'loaded!')
            else:
                self._iter = checkpoint['running_metadata']['iter']
                self._start_iter = self._iter
                self.seen_img = checkpoint['running_metadata']['seen_img']
        # Load optimizers.
        if optimizer:
            if 'optimizers' not in checkpoint:
                self.logger.warning(f'Optimizers are not included in the '
                                    f'checkpoint, and hence will NOT be '
                                    f'loaded!')
            else:
                for opt_name, opt in self.optimizers.items():
                    if opt_name not in checkpoint['optimizers']:
                        self.logger.warning(f'Optimizer `{opt_name}` is not '
                                            f'included in the checkpoint, and '
                                            f'hence will NOT be loaded!')
                        continue
                    opt.load_state_dict(checkpoint['optimizers'][opt_name])
                    self.logger.info(f' Successfully loaded optimizer '
                                     f'`{opt_name}`.')
        # Load learning rates.
        if learning_rate:
            if 'learning_rates' not in checkpoint:
                self.logger.warning(f'Learning rates are not included in the '
                                    f'checkpoint, and hence will NOT be '
                                    f'loaded!')
            else:
                for lr_name, lr in self.lr_schedulers.items():
                    if lr_name not in checkpoint['learning_rates']:
                        self.logger.warning(f'Learning rate `{lr_name}` is not '
                                            f'included in the checkpoint, and '
                                            f'hence will NOT be loaded!')
                        continue
                    lr.load_state_dict(checkpoint['learning_rates'][lr_name])
                    self.logger.info(f' Successfully loaded learning rate '
                                     f'`{lr_name}`.')
        # Load running stats.
        if running_stats:
            if 'running_stats' not in checkpoint:
                self.logger.warning(f'Running stats is not included in the '
                                    f'checkpoint, and hence will NOT be '
                                    f'loaded!')
            else:
                self.running_stats = deepcopy(checkpoint['running_stats'])
                self.logger.info(f' Successfully loaded running stats.')
        # Log message.
        tailing_message = ''
        if running_metadata and 'running_metadata' in checkpoint:
            tailing_message = f' (iteration {self.iter})'
        self.logger.info(f'Successfully resumed from checkpoint `{filepath}`.'
                         f'{tailing_message}')
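
# A minimal usage sketch for this base class (illustrative only; `MyRunner`,
# the module name 'generator', and the `generator_loss()` call are assumptions
# for demonstration rather than part of this repository's API):
#
#     class MyRunner(BaseRunner):
#         def train_step(self, data, **train_kwargs):
#             self.optimizers['generator'].zero_grad()
#             loss = self.loss.generator_loss(self, data)  # assumed interface
#             loss.backward()
#             self.optimizers['generator'].step()
#
#     runner = MyRunner(config, logger)  # config and logger built by caller
#     runner.load('checkpoint.pth')      # optional: resume from a checkpoint
#     runner.train()
#     runner.save(os.path.join(runner.work_dir, 'final.pth'))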