NatureLM-Audio / NatureLM / checkpoint_utils.py
"""Module for training utilities.
This module contains utility functions for training models. For example, saving model checkpoints.
"""
import logging
import os
import tempfile
from typing import Any, Union

import torch
import torch.nn as nn
logger = logging.getLogger(__name__)


def maybe_unwrap_dist_model(model: nn.Module, use_distributed: bool) -> nn.Module:
    """Return the inner module if the model is wrapped for distributed training, else the model itself."""
    return model.module if use_distributed else model


def get_state_dict(model: nn.Module, drop_untrained_params: bool = True) -> dict[str, Any]:
    """Get the model state dict, optionally dropping parameters that do not require gradients.

    Args:
        model: Model to get the state dict from.
        drop_untrained_params: Whether to drop parameters that do not require gradients.

    Returns:
        dict: Model state dict.
    """
    if not drop_untrained_params:
        return model.state_dict()

    param_grad_dict = {k: v.requires_grad for k, v in model.named_parameters()}
    state_dict = model.state_dict()
    for k in list(state_dict.keys()):
        if k in param_grad_dict and not param_grad_dict[k]:
            # Delete parameters that do not require gradients.
            del state_dict[k]
    return state_dict
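
# Illustrative example (an assumption, not from the original file): with a
# partially frozen model, drop_untrained_params=True keeps only the trainable
# parameters in the returned state dict:
#
#     model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
#     for p in model[0].parameters():
#         p.requires_grad = False  # freeze the first layer
#     sd = get_state_dict(model, drop_untrained_params=True)
#     # sd now contains only "1.weight" and "1.bias"
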
def save_model_checkpoint(
model: nn.Module,
save_path: Union[str, os.PathLike],
use_distributed: bool = False,
drop_untrained_params: bool = False,
**objects_to_save,
) -> None:
"""Save model checkpoint.
Args:
model (nn.Module): Model to save
output_dir (str): Output directory to save checkpoint
use_distributed (bool): Whether the model is distributed, if so, unwrap it. Default: False.
is_best (bool): Whether the model is the best in the training run. Default: False.
drop_untrained_params (bool): Whether to drop untrained parameters to save. Default: True.
prefix (str): Prefix to add to the checkpoint file name. Default: "".
extention (str): Extension to use for the checkpoint file. Default: "pth".
**objects_to_save: Additional objects to save, e.g. optimizer state dict, etc.
"""
if not os.path.exists(os.path.dirname(save_path)):
raise FileNotFoundError(f"Directory {os.path.dirname(save_path)} does not exist.")
model_no_ddp = maybe_unwrap_dist_model(model, use_distributed)
state_dict = get_state_dict(model_no_ddp, drop_untrained_params)
save_obj = {
"model": state_dict,
**objects_to_save,
}
    logger.info("Saving checkpoint to %s.", save_path)
torch.save(save_obj, save_path)
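

# A minimal usage sketch (not part of the original module). The model,
# optimizer, and checkpoint path below are hypothetical stand-ins; a real
# training loop would pass its own model, optimizer state, and save path.
if __name__ == "__main__":
    demo_model = nn.Linear(8, 2)  # stand-in for an actual NatureLM model
    demo_optimizer = torch.optim.AdamW(demo_model.parameters(), lr=1e-4)
    ckpt_dir = tempfile.mkdtemp()  # existing directory, as save_model_checkpoint requires
    save_model_checkpoint(
        demo_model,
        save_path=os.path.join(ckpt_dir, "checkpoint_0.pth"),
        use_distributed=False,
        drop_untrained_params=False,
        optimizer=demo_optimizer.state_dict(),  # collected into **objects_to_save
        epoch=0,
    )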