tinyvgg / utils.py
ajitsi's picture
tinyvgg cnn model for image classification
7fc0372
raw
history blame contribute delete
948 Bytes
"""
Contains various utility functions for PyTorch model training and saving.
"""
import torch
from pathlib import Path
def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str):
    """Save a PyTorch model's ``state_dict`` to ``target_dir / model_name``.

    Only the parameters/buffers (``model.state_dict()``) are saved, not the
    pickled module object.

    Args:
        model (torch.nn.Module): Model whose ``state_dict`` is saved.
        target_dir (str): Directory to save into; created (including any
            missing parents) if it does not exist. A ``pathlib.Path`` also
            works, since it is passed straight to ``Path()``.
        model_name (str): File name for the checkpoint. Must end with
            ``".pt"`` or ``".pth"``.

    Returns:
        pathlib.Path: Full path of the file that was written.

    Raises:
        ValueError: If ``model_name`` does not end with ``".pt"`` or ``".pth"``.
    """
    # Validate before touching the filesystem, so a bad name does not leave
    # an empty directory behind. Use an explicit raise rather than `assert`,
    # which is stripped when Python runs with -O.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("model_name should end with '.pt' or '.pth'")

    # Create the target directory (no error if it already exists).
    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    model_save_path = target_dir_path / model_name

    # Save the model state_dict().
    print(f"[INFO] Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(), f=model_save_path)
    return model_save_path