"""
Contains various utility functions for PyTorch model training and saving.
"""
import torch | |
from pathlib import Path | |
def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str) -> Path:
    """Save a PyTorch model's ``state_dict`` to a target directory.

    Args:
        model (torch.nn.Module): Model whose ``state_dict()`` is saved.
        target_dir (str): Directory to save into. It is created (including
            any missing parents) if it does not already exist.
        model_name (str): File name for the saved weights. Must end with
            ``".pth"`` or ``".pt"``.

    Returns:
        Path: The full path the ``state_dict`` was written to.

    Raises:
        ValueError: If ``model_name`` does not end with ``".pt"`` or ``".pth"``.
    """
    # Validate before touching the filesystem. Raising explicitly (instead of
    # `assert`) keeps the check active under `python -O`. Checking before
    # `mkdir` means a rejected call leaves no stray directory behind.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("model_name should end with '.pt' or '.pth'")

    # Create the target directory (no-op if it already exists).
    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    # Save only the parameters/buffers (state_dict), not the pickled module.
    model_save_path = target_dir_path / model_name
    print(f"[INFO] Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(), f=model_save_path)
    return model_save_path