"""Utility functions for PyTorch model training and saving."""

import torch
from pathlib import Path


def save_model(model: torch.nn.Module, target_dir: str, model_name: str) -> Path:
    """Save a PyTorch model's ``state_dict`` to ``target_dir / model_name``.

    Args:
        model (torch.nn.Module): Model whose ``state_dict()`` is saved.
        target_dir (str): Directory to save into. It is created, including any
            missing parents, if it does not already exist. A ``pathlib.Path``
            is also accepted, since the value is passed through ``Path()``.
        model_name (str): File name for the saved weights. Must end with
            ``".pt"`` or ``".pth"``.

    Returns:
        Path: The full path the ``state_dict`` was written to.

    Raises:
        ValueError: If ``model_name`` does not end with ``".pt"`` or ``".pth"``.
    """
    # Validate before touching the filesystem, so that a rejected call does not
    # leave an empty directory behind. Use an explicit raise rather than
    # `assert`, because assertions are stripped when Python runs with -O.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("model_name should end with '.pt' or '.pth'")

    # Create the target directory (and any missing parents) if needed.
    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    model_save_path = target_dir_path / model_name

    # Save only the state_dict (the parameters and buffers), not the pickled
    # module object.
    print(f"[INFO] Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(), f=model_save_path)
    return model_save_path