from pathlib import Path

import torch


def save_model(model: torch.nn.Module, model_name: str, target_dir: str) -> Path:
    """Save ``model``'s ``state_dict`` to ``target_dir / model_name``.

    Only ``model.state_dict()`` is written, not the module object itself, so
    restoring it requires building the same model class and calling
    ``load_state_dict`` with the loaded mapping.

    Args:
        model: The module whose ``state_dict`` is saved.
        model_name: File name for the checkpoint; must end with ``.pth`` or
            ``.pt``.
        target_dir: Directory to save into; it is created (including any
            missing parents) if it does not already exist.

    Returns:
        The path the checkpoint was written to (``Path(target_dir) / model_name``).

    Raises:
        ValueError: If ``model_name`` does not end with ``.pth`` or ``.pt``.
    """
    # Validate before touching the filesystem so a bad name leaves no stray
    # directory behind. A real exception is used instead of ``assert`` because
    # assertions are stripped when Python runs with ``-O``.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("Model name should end with .pth or .pt")

    target_dir_path = Path(target_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    model_save_path = target_dir_path / model_name
    torch.save(obj=model.state_dict(), f=model_save_path)
    return model_save_path