import torch
import os
import json
import sys
from collections import OrderedDict

from utils import pickle_util

# Paths of the checkpoints saved so far in this run, newest last.
history_array = []

def save_model(epoch, model, optimizer, file_save_path):
    """Save a checkpoint (epoch, model and optimizer state) to file_save_path."""
    dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir))
    if not os.path.exists(dirpath):
        print("mkdir:", dirpath)
        os.makedirs(dirpath)
    opti = None
    if optimizer is not None:
        opti = optimizer.state_dict()
    torch.save(obj={
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': opti,
    }, f=file_save_path)
    history_array.append(file_save_path)
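
# Example usage (a minimal sketch; `net`, `opt`, and the path below are
# hypothetical, not part of this module):
#     net = torch.nn.Linear(4, 2)
#     opt = torch.optim.Adam(net.parameters())
#     save_model(epoch=3, model=net, optimizer=opt,
#                file_save_path="checkpoints/epoch_3.pth")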

def save_model_v4(epoch, model, optimizer, file_save_path, discriminator):
    """Like save_model, but also stores the discriminator object as passed in
    (the caller decides whether to pass the module or its state_dict)."""
    dirpath = os.path.abspath(os.path.join(file_save_path, os.pardir))
    if not os.path.exists(dirpath):
        print("mkdir:", dirpath)
        os.makedirs(dirpath)
    opti = None
    if optimizer is not None:
        opti = optimizer.state_dict()
    torch.save(obj={
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': opti,
        "discriminator": discriminator,
    }, f=file_save_path)
    history_array.append(file_save_path)

def delete_last_saved_model():
    """Remove the most recently saved checkpoint (and its .json sidecar, if any)."""
    if len(history_array) == 0:
        return
    last_path = history_array.pop()
    if os.path.exists(last_path):
        os.remove(last_path)
        print("delete model:", last_path)
    if os.path.exists(last_path + ".json"):
        os.remove(last_path + ".json")
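
# A pattern this enables (sketch; `net` and the loop body are hypothetical):
# keep only the newest checkpoint by deleting the previous one before each save.
#     for epoch in range(10):
#         ...  # train one epoch
#         delete_last_saved_model()
#         save_model(epoch, net, None, f"checkpoints/epoch_{epoch}.pth")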

def load_model(resume_path, model, optimizer=None, strict=True):
    """Load a checkpoint written by save_model; returns the epoch to resume from."""
    checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
    start_epoch = checkpoint['epoch'] + 1
    model.load_state_dict(checkpoint['model'], strict=strict)
    # Checkpoints written with optimizer=None store None here, so guard both sides.
    if optimizer is not None and checkpoint['optimizer'] is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    print("checkpoint loaded!")
    return start_epoch
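
# Resuming a run (sketch; `net`, `opt`, and the path are hypothetical):
#     net = torch.nn.Linear(4, 2)
#     opt = torch.optim.Adam(net.parameters())
#     start_epoch = load_model("checkpoints/epoch_3.pth", net, optimizer=opt)
#     # training would continue from `start_epoch`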

def save_model_v2(model, args, model_save_name):
    """Save model weights (no optimizer) under <model_save_folder>/<project>/<name>/."""
    model_save_path = os.path.join(args.model_save_folder, args.project, args.name, model_save_name)
    save_model(0, model, None, model_save_path)
    print("save:", model_save_path)

def save_project_info(args):
    """Record the command line and parsed args as run_info.json in the run folder."""
    run_info = {
        "cmd_str": ' '.join(sys.argv[1:]),
        "args": vars(args),
    }
    name = "run_info.json"
    folder = os.path.join(args.model_save_folder, args.project, args.name)
    if not os.path.exists(folder):
        os.makedirs(folder)
    json_file_path = os.path.join(folder, name)
    with open(json_file_path, "w") as f:
        json.dump(run_info, f)
    print("save_project_info:", json_file_path)

def get_pkl_json(folder):
    """Find the single *.pkl.json file in folder and return its parsed contents."""
    names = [i for i in os.listdir(folder) if i.endswith(".pkl.json")]
    assert len(names) == 1
    json_path = os.path.join(folder, names[0])
    obj = pickle_util.read_json(json_path)
    return obj

# DataParallel helpers
def is_data_parallel_checkpoint(state_dict):
    """True if the state dict was saved from a torch.nn.DataParallel wrapper."""
    return any(key.startswith('module.') for key in state_dict.keys())


def map_state_dict(state_dict):
    """Strip the 'module.' prefix added by DataParallel so the weights load into a plain model."""
    if is_data_parallel_checkpoint(state_dict):
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k  # remove the 'module.' prefix
            new_state_dict[name] = v
        return new_state_dict
    return state_dict
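
# Example (sketch; `model` and the checkpoint path are hypothetical): load
# weights saved from a DataParallel-wrapped model into a plain, unwrapped model.
#     ckpt = torch.load("checkpoints/dp_model.pth", map_location="cpu")
#     model.load_state_dict(map_state_dict(ckpt['model']))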