File size: 948 Bytes
7fc0372 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
"""
Contains various utility functions for PyTorch model training and saving.
"""
import torch
from pathlib import Path
def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str):
    """Save a PyTorch model's ``state_dict()`` to ``target_dir / model_name``.

    Only the model's ``state_dict()`` is written, not the pickled module
    object. Reloading therefore requires constructing the model class first
    and calling ``load_state_dict`` on it.

    Args:
        model: The model whose ``state_dict()`` is saved.
        target_dir: Directory to save into. It is created, including any
            missing parents, if it does not already exist. A
            ``pathlib.Path`` is accepted as well, since it is passed
            through ``Path(...)``.
        model_name: File name of the checkpoint. Must end with ``.pth``
            or ``.pt``.

    Raises:
        ValueError: If ``model_name`` does not end with ``.pth`` or ``.pt``.
    """
    # Validate before touching the filesystem. A bad name must not leave a
    # freshly created directory behind. Raise instead of `assert`, because
    # asserts are stripped when Python runs with -O.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("model_name should end with '.pt' or '.pth'")

    # Create the target directory. This is a no-op if it already exists.
    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    # Build the full save path.
    model_save_path = target_dir_path / model_name

    # Save only the state_dict (parameters and buffers).
    print(f"[INFO] Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(), f=model_save_path)
|