stable-fashion / utils /saving_utils.py
ovshake
add app.py and related files
6724ca0
raw history blame
No virus
1.36 kB
import os
import copy
import cv2
import numpy as np
from collections import OrderedDict
import torch
def load_checkpoint(model, checkpoint_path):
    """Load a state dict stored at ``checkpoint_path`` into ``model``.

    Tensors are mapped onto the CPU while loading.

    Args:
        model: Module whose ``load_state_dict`` receives the saved weights.
        checkpoint_path: Filesystem path of the saved state dict.

    Returns:
        ``model`` (updated in place) on success, or ``None`` when
        ``checkpoint_path`` does not exist. A message is printed either way.
    """
    if os.path.exists(checkpoint_path):
        cpu = torch.device("cpu")
        weights = torch.load(checkpoint_path, map_location=cpu)
        model.load_state_dict(weights)
        print("----checkpoints loaded from path: {}----".format(checkpoint_path))
        return model

    # Missing checkpoint: report it and hand back None rather than raising.
    print("----No checkpoints at given path----")
    return None
def load_checkpoint_mgpu(model, checkpoint_path):
    """Load a checkpoint saved from a ``module.``-wrapped model into ``model``.

    Checkpoints written from a wrapped model (for example one wrapped in
    ``torch.nn.DataParallel``) have every key prefixed with ``module.``.
    That prefix is stripped before loading. Keys without the prefix are
    passed through unchanged instead of being truncated. Tensors are mapped
    onto the CPU while loading.

    Args:
        model: Unwrapped module whose ``load_state_dict`` receives the weights.
        checkpoint_path: Filesystem path of the saved state dict.

    Returns:
        ``model`` (updated in place) on success, or ``None`` when
        ``checkpoint_path`` does not exist. A message is printed either way.
    """
    if not os.path.exists(checkpoint_path):
        print("----No checkpoints at given path----")
        return
    model_state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    prefix = "module."
    # Strip the prefix only when present; slicing unconditionally would
    # corrupt keys of checkpoints that were saved without the wrapper.
    new_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in model_state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    print("----checkpoints loaded from path: {}----".format(checkpoint_path))
    return model
def save_checkpoint(model, save_path):
    """Write ``model.state_dict()`` to ``save_path``, creating parent directories.

    Args:
        model: Module whose ``state_dict()`` is serialized with ``torch.save``.
        save_path: Destination file path. Missing parent directories are created.
            A bare filename (no directory part) is saved in the current
            working directory.
    """
    print(save_path)
    save_dir = os.path.dirname(save_path)
    # dirname is "" for a bare filename, and os.makedirs("") would raise.
    # exist_ok=True also removes the race between checking for and creating
    # the directory.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), save_path)
def save_checkpoints(opt, itr, net):
    """Save ``net`` under ``<opt.save_dir>/checkpoints`` for iteration ``itr``.

    The file is named ``itr_<itr zero-padded to 8 digits>_u2net.pth``.
    Writing is delegated to ``save_checkpoint``.
    """
    ckpt_dir = os.path.join(opt.save_dir, "checkpoints")
    filename = "itr_{:08d}_u2net.pth".format(itr)
    save_checkpoint(net, os.path.join(ckpt_dir, filename))