|
import os |
|
|
|
import torch |
|
|
|
|
|
def save_weights(model, filename, path="./saved_models"):
    """Write ``model.state_dict()`` to the file ``<path>/<filename>``.

    The destination directory (and any missing parents) is created first;
    an already-existing directory is accepted.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        filename: name of the file to create inside *path*.
        path: directory that receives the file.

    Returns:
        None.
    """
    target_dir = path
    os.makedirs(target_dir, exist_ok=True)

    weights = model.state_dict()
    torch.save(weights, os.path.join(target_dir, filename))
    return None
|
|
|
def save_checkpoint(model, optimizer, epoch, filename, root="./checkpoints"):
    """Save model weights, optimizer state and epoch to ``<root>/<filename>``.

    The file holds a dict with the keys ``"model"`` (``model.state_dict()``),
    ``"optimizer"`` (``optimizer.state_dict()``) and ``"epoch"`` (*epoch* as
    given). This is the layout that ``load_checkpoint`` reads back.

    Args:
        model: object exposing ``state_dict()``.
        optimizer: object exposing ``state_dict()``.
        epoch: value stored under ``"epoch"``.
        filename: name of the file to create inside *root*.
        root: directory that receives the file; created if missing.

    Returns:
        None.
    """
    # exist_ok=True avoids the race in "if not isdir: makedirs", where another
    # process could create the directory between the check and the call. It
    # also matches save_weights.
    os.makedirs(root, exist_ok=True)

    fpath = os.path.join(root, filename)
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        },
        fpath,
    )
|
|
|
def load_weights(model, filename, path="./saved_models", map_location=None):
    """Load a state dict from ``<path>/<filename>`` into *model*.

    Args:
        model: object exposing ``load_state_dict()``; it is modified in place.
        filename: name of the weights file inside *path*.
        path: directory containing the file.
        map_location: forwarded to ``torch.load``. The default ``None`` keeps
            torch's behaviour of restoring tensors to their saved device. Pass
            ``'cpu'`` to load GPU-saved weights on a machine without that
            device, as ``load_checkpoint`` does.

    Returns:
        The same *model* object, after loading.
    """
    fpath = os.path.join(path, filename)
    state_dict = torch.load(fpath, map_location=map_location)
    model.load_state_dict(state_dict)
    return model
|
|
|
# Key prefixes that are rewritten before loading: any key starting with the
# first string has that prefix replaced by the second. At most one rule
# applies per key, and the first match wins.
_LEGACY_KEY_PREFIXES = (
    ('adaptive_bins_layer.embedding_conv.',
     'adaptive_bins_layer.conv3x3.'),
    ('adaptive_bins_layer.patch_transformer.embedding_encoder',
     'adaptive_bins_layer.patch_transformer.embedding_convPxP'),
)

_WRAPPER_PREFIX = 'module.'


def _remap_state_dict_keys(state_dict):
    """Return a new dict with wrapper and legacy key prefixes rewritten."""
    remapped = {}
    for key, value in state_dict.items():
        # Strip only *leading* 'module.' prefixes (typically added when the
        # model was wrapped in torch.nn.DataParallel). str.replace would also
        # mangle names that merely contain 'module.', e.g. 'se_module.weight'.
        while key.startswith(_WRAPPER_PREFIX):
            key = key[len(_WRAPPER_PREFIX):]
        for old, new in _LEGACY_KEY_PREFIXES:
            if key.startswith(old):
                key = new + key[len(old):]
                break
        remapped[key] = value
    return remapped


def load_checkpoint(fpath, model, optimizer=None):
    """Load weights (and optionally optimizer state) from a checkpoint file.

    The file is read with ``torch.load(..., map_location='cpu')``. It may be
    either a full checkpoint dict (weights under ``'model'``, plus
    ``'optimizer'`` / ``'epoch'``) or a bare state dict. Before loading, the
    weight keys have leading ``'module.'`` prefixes stripped and legacy
    ``adaptive_bins_layer`` prefixes renamed.

    Args:
        fpath: path of the checkpoint file.
        model: object exposing ``load_state_dict()``; it is modified in place.
        optimizer: if given, ``optimizer.load_state_dict(ckpt['optimizer'])``
            is called. If None, the raw ``ckpt['optimizer']`` entry (or None
            when absent) is returned in its place.

    Returns:
        ``(model, optimizer, epoch)``. *epoch* is ``ckpt['epoch']``, or None
        when the file has no ``'epoch'`` entry (e.g. a bare state dict).

    Raises:
        Exception: if ``torch.load`` returns None.
        KeyError: if *optimizer* is given but the file has no ``'optimizer'``.
    """
    ckpt = torch.load(fpath, map_location='cpu')

    # NOTE(review): plain torch.load returns None only for a file that stores
    # None; presumably this guards a wrapped/patched loader that returns None
    # on a rejected file. Confirm.
    if ckpt is None:
        raise Exception(f"\nERROR Loading AdaBins_nyu.pt. Read this for a fix:\nhttps://github.com/deforum-art/deforum-for-automatic1111-webui/wiki/FAQ-&-Troubleshooting#3d-animation-mode-is-not-working-only-2d-works")

    if optimizer is None:
        optimizer = ckpt.get('optimizer', None)
    else:
        optimizer.load_state_dict(ckpt['optimizer'])

    # .get rather than ['epoch']: a bare state dict has no 'epoch'. Indexing
    # would raise here and make the bare-state-dict fallback below unreachable.
    epoch = ckpt.get('epoch')

    # Full checkpoints nest the weights under 'model'; otherwise the loaded
    # object itself is treated as the state dict.
    state_dict = ckpt['model'] if 'model' in ckpt else ckpt

    model.load_state_dict(_remap_state_dict_keys(state_dict))
    return model, optimizer, epoch