import collections
import glob
import logging
import os
from typing import List, Tuple

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.serialization import default_restore_location

logger = logging.getLogger()

CheckpointState = collections.namedtuple(
    "CheckpointState",
    [
        "model_dict",
        "optimizer_dict",
        "scheduler_dict",
        "offset",
        "epoch",
        "encoder_params",
    ],
)


def setup_for_distributed_mode(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: object,
    n_gpu: int = 1,
    local_rank: int = -1,
    fp16: bool = False,
    fp16_opt_level: str = "O1",
) -> Tuple[nn.Module, torch.optim.Optimizer]:
    model.to(device)
    if fp16:
        try:
            import apex
            from apex import amp

            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
        )
    return model, optimizer


def move_to_cuda(sample):
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda()
        elif isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        elif isinstance(maybe_tensor, list):
            return [_move_to_cuda(x) for x in maybe_tensor]
        elif isinstance(maybe_tensor, tuple):
            return tuple(_move_to_cuda(x) for x in maybe_tensor)
        else:
            return maybe_tensor

    return _move_to_cuda(sample)


def move_to_device(sample, device):
    if len(sample) == 0:
        return {}

    def _move_to_device(maybe_tensor, device):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(device)
        elif isinstance(maybe_tensor, dict):
            return {
                key: _move_to_device(value, device)
                for key, value in maybe_tensor.items()
            }
        elif isinstance(maybe_tensor, list):
            return [_move_to_device(x, device) for x in maybe_tensor]
        elif isinstance(maybe_tensor, tuple):
            return tuple(_move_to_device(x, device) for x in maybe_tensor)
        else:
            return maybe_tensor

    return _move_to_device(sample, device)


def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1):
    """Create a schedule with a learning rate that increases linearly during a warmup
    period and then decreases linearly to zero over the remaining training steps.
    """

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return max(
            0.0,
            float(training_steps - current_step)
            / float(max(1, training_steps - warmup_steps)),
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def init_weights(modules: List):
    for module in modules:
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()


def get_model_obj(model: nn.Module):
    return model.module if hasattr(model, "module") else model


def get_model_file(args, file_prefix) -> str:
    if args.model_file and os.path.exists(args.model_file):
        return args.model_file

    out_cp_files = (
        glob.glob(os.path.join(args.output_dir, file_prefix + "*"))
        if args.output_dir
        else []
    )
    logger.info("Checkpoint files %s", out_cp_files)
    model_file = None

    if len(out_cp_files) > 0:
        model_file = max(out_cp_files, key=os.path.getctime)
    return model_file


def load_states_from_checkpoint(model_file: str) -> CheckpointState:
    logger.info("Reading saved model from %s", model_file)
    if isinstance(model_file, tuple):
        model_file = model_file[0]
    state_dict = torch.load(
        model_file, map_location=lambda s, l: default_restore_location(s, "cpu")
    )
    logger.info("model_state_dict keys %s", state_dict.keys())
    return CheckpointState(**state_dict)