"""Module for training utilities. This module contains utility functions for training models. For example, saving model checkpoints. """ import logging import os import tempfile from typing import Any, Union import torch import torch.nn as nn logger = logging.getLogger(__name__) def maybe_unwrap_dist_model(model: nn.Module, use_distributed: bool) -> nn.Module: return model.module if use_distributed else model def get_state_dict(model, drop_untrained_params: bool = True) -> dict[str, Any]: """Get model state dict. Optionally drop untrained parameters to keep only those that require gradient. Args: model: Model to get state dict from drop_untrained_params: Whether to drop untrained parameters Returns: dict: Model state dict """ if not drop_untrained_params: return model.state_dict() param_grad_dict = {k: v.requires_grad for (k, v) in model.named_parameters()} state_dict = model.state_dict() for k in list(state_dict.keys()): if k in param_grad_dict.keys() and not param_grad_dict[k]: # delete parameters that do not require gradient del state_dict[k] return state_dict def save_model_checkpoint( model: nn.Module, save_path: Union[str, os.PathLike], use_distributed: bool = False, drop_untrained_params: bool = False, **objects_to_save, ) -> None: """Save model checkpoint. Args: model (nn.Module): Model to save output_dir (str): Output directory to save checkpoint use_distributed (bool): Whether the model is distributed, if so, unwrap it. Default: False. is_best (bool): Whether the model is the best in the training run. Default: False. drop_untrained_params (bool): Whether to drop untrained parameters to save. Default: True. prefix (str): Prefix to add to the checkpoint file name. Default: "". extention (str): Extension to use for the checkpoint file. Default: "pth". **objects_to_save: Additional objects to save, e.g. optimizer state dict, etc. """ if not os.path.exists(os.path.dirname(save_path)): raise FileNotFoundError(f"Directory {os.path.dirname(save_path)} does not exist.") model_no_ddp = maybe_unwrap_dist_model(model, use_distributed) state_dict = get_state_dict(model_no_ddp, drop_untrained_params) save_obj = { "model": state_dict, **objects_to_save, } logger.info("Saving checkpoint to {}.".format(save_path)) torch.save(save_obj, save_path)