masabhuq's picture
Initial Commit
0c7049d
from pathlib import Path
import torch
def save_model(model: torch.nn.Module,
               model_name: str,
               target_dir: str) -> None:
    """Save a model's ``state_dict`` to ``target_dir / model_name``.

    Only the tensors returned by ``model.state_dict()`` are serialized (via
    ``torch.save``), not the module object itself.

    Args:
        model: Module whose ``state_dict()`` is written to disk.
        model_name: File name for the checkpoint; must end with ``.pth``
            or ``.pt``.
        target_dir: Directory to write into; created (including any missing
            parents) if it does not already exist.

    Raises:
        ValueError: If ``model_name`` does not end with ``.pth`` or ``.pt``.
    """
    # Validate before touching the filesystem so a bad name leaves no stray
    # directory behind. Use an explicit raise rather than ``assert``, because
    # assertions are stripped when Python runs with ``-O``.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("Model name should end with .pth or .pt")

    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    model_save_path = target_dir_path / model_name
    torch.save(obj=model.state_dict(), f=model_save_path)