Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,327 Bytes
b887ad8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
from .ema import ExponentialMovingAverage
def load_model_weights(model, ckpt_path, use_ema=True, device='cuda:0'):
    """
    Load weights of a model from a checkpoint file.

    The model is updated in place. When ``use_ema`` is true and the checkpoint
    contains a ``"model_ema"`` entry, that state dict is loaded; otherwise the
    ``"encoder"`` entry is loaded. All loads use ``strict=True``.

    Args:
        model (torch.nn.Module): The model to load weights into (modified in place).
        ckpt_path (str): Path to the checkpoint file.
        use_ema (bool): Whether to use Exponential Moving Average (EMA) weights if available.
        device (str | torch.device): Device that tensors saved on ``'cuda:0'``
            are remapped to while the checkpoint is loaded.

    Returns:
        int: The ``"total_it"`` value stored in the checkpoint, or 0 if absent.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    # NOTE(review): only tensors saved on 'cuda:0' are remapped (CPU tensors stay on
    # CPU); checkpoints saved on other CUDA devices may fail on machines lacking them.
    checkpoint = torch.load(ckpt_path, map_location={'cuda:0': str(device)})
    total_iter = checkpoint.get('total_it', 0)

    if "model_ema" in checkpoint and use_ema:
        ema_state = checkpoint["model_ema"]
        # A state dict saved from the EMA wrapper itself starts with wrapper keys
        # such as 'n_averaged' / 'module.<param>' instead of the bare model's keys.
        # Default "" avoids leaking StopIteration on an empty state dict.
        first_key = next(iter(ema_state), "")
        is_wrapped = ('module' in first_key) or ('n_averaged' in first_key)
        if is_wrapped:
            ema_model = ExponentialMovingAverage(model, decay=1.0)
            ema_model.load_state_dict(ema_state, strict=True)
            # The wrapper may hold its own copy of `model` (torch's AveragedModel
            # deep-copies it), so rebinding the local name would silently drop the
            # loaded weights; copy them back into the caller's model instead.
            model.load_state_dict(ema_model.module.state_dict(), strict=True)
            print(f'\nLoading EMA module model from {ckpt_path} with {total_iter} iterations')
        else:
            model.load_state_dict(ema_state, strict=True)
            print(f'\nLoading EMA model from {ckpt_path} with {total_iter} iterations')
    else:
        model.load_state_dict(checkpoint['encoder'], strict=True)
        print(f'\nLoading model from {ckpt_path} with {total_iter} iterations')
    return total_iter