| """Module for training utilities. | |
| This module contains utility functions for training models. For example, saving model checkpoints. | |
| """ | |
| import logging | |
| import os | |
| import tempfile | |
| from typing import Any, Union | |
| import torch | |
| import torch.nn as nn | |
| logger = logging.getLogger(__name__) | |
def maybe_unwrap_dist_model(model: nn.Module, use_distributed: bool) -> nn.Module:
    """Return the inner module (e.g. the one wrapped by DistributedDataParallel) if the model is distributed."""
    return model.module if use_distributed else model


def get_state_dict(model: nn.Module, drop_untrained_params: bool = True) -> dict[str, Any]:
    """Get model state dict. Optionally drop untrained parameters to keep only those that require gradient.

    Args:
        model: Model to get state dict from
        drop_untrained_params: Whether to drop untrained parameters

    Returns:
        dict: Model state dict
    """
    if not drop_untrained_params:
        return model.state_dict()
    param_grad_dict = {k: v.requires_grad for (k, v) in model.named_parameters()}
    state_dict = model.state_dict()
    for k in list(state_dict.keys()):
        if k in param_grad_dict and not param_grad_dict[k]:
            # delete parameters that do not require gradient
            del state_dict[k]
    return state_dict


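# Illustrative sketch, not part of the original module: with drop_untrained_params=True,
# only tensors whose parameters still require gradients are kept. For a two-layer
# model whose first layer is frozen, only the second layer's tensors survive:
#
#     demo = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
#     for p in demo[0].parameters():
#         p.requires_grad = False  # freeze layer 0
#     kept = get_state_dict(demo, drop_untrained_params=True)
#     assert sorted(kept) == ["1.bias", "1.weight"]  # layer 0 entries were dropped

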
def save_model_checkpoint(
    model: nn.Module,
    save_path: Union[str, os.PathLike],
    use_distributed: bool = False,
    drop_untrained_params: bool = False,
    **objects_to_save,
) -> None:
    """Save model checkpoint.

    Args:
        model (nn.Module): Model to save
        save_path (Union[str, os.PathLike]): Path of the checkpoint file to write. Its parent directory must exist.
        use_distributed (bool): Whether the model is distributed, if so, unwrap it. Default: False.
        drop_untrained_params (bool): Whether to drop untrained parameters before saving. Default: False.
        **objects_to_save: Additional objects to save, e.g. optimizer state dict, etc.
    """
    save_dir = os.path.dirname(save_path)
    if save_dir and not os.path.exists(save_dir):
        raise FileNotFoundError(f"Directory {save_dir} does not exist.")
    model_no_ddp = maybe_unwrap_dist_model(model, use_distributed)
    state_dict = get_state_dict(model_no_ddp, drop_untrained_params)
    save_obj = {
        "model": state_dict,
        **objects_to_save,
    }
    logger.info("Saving checkpoint to %s.", save_path)
    torch.save(save_obj, save_path)
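

# A hedged usage sketch, not part of the original module: the toy model, optimizer,
# and temporary directory below stand in for a real training setup.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    toy_model = nn.Linear(8, 2)
    toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_path = os.path.join(tmp_dir, "checkpoint.pth")
        save_model_checkpoint(
            toy_model,
            checkpoint_path,
            use_distributed=False,
            drop_untrained_params=False,
            optimizer=toy_optimizer.state_dict(),
            epoch=0,
        )
        # the checkpoint holds the model weights plus every extra object passed above
        reloaded = torch.load(checkpoint_path)
        assert set(reloaded) == {"model", "optimizer", "epoch"}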