File size: 1,327 Bytes
b887ad8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from .ema import ExponentialMovingAverage

def load_model_weights(model, ckpt_path, use_ema=True, device='cuda:0'):
    """
    Load weights of a model from a checkpoint file.

    The weights are loaded into ``model`` in place. EMA weights stored under
    ``"model_ema"`` are preferred when ``use_ema`` is true and that entry
    exists; otherwise the weights stored under ``"encoder"`` are used.

    Args:
        model (torch.nn.Module): The model to load weights into.
        ckpt_path (str): Path to the checkpoint file.
        use_ema (bool): Whether to use Exponential Moving Average (EMA) weights if available.
        device (str or torch.device): Device that the checkpoint tensors are
            mapped to while the file is being read.

    Returns:
        int: The checkpoint's ``"total_it"`` value, or 0 when it is absent.

    Raises:
        KeyError: If the non-EMA path is taken and the checkpoint has no
            ``"encoder"`` entry.
        RuntimeError: If the stored keys do not match the model (strict loading).
    """
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    # Map *every* stored tensor to `device`; a {'cuda:0': ...} mapping would
    # leave tensors saved from any other device (e.g. 'cuda:1') unredirected.
    checkpoint = torch.load(ckpt_path, map_location=str(device))
    total_iter = checkpoint.get('total_it', 0)

    if use_ema and "model_ema" in checkpoint:
        ema_state = checkpoint["model_ema"]
        # Inspect the first key to tell a wrapped EMA state dict (keys carry a
        # 'module' prefix and/or an 'n_averaged' entry) from a plain one.
        # An empty dict falls through to strict loading, which reports the
        # missing keys instead of raising a bare StopIteration here.
        first_key = next(iter(ema_state), "")
        is_wrapped = ('module' in first_key) or ('n_averaged' in first_key)

        if is_wrapped:
            ema_model = ExponentialMovingAverage(model, decay=1.0)
            ema_model.load_state_dict(ema_state, strict=True)
            # NOTE(review): the wrapper may hold its own copy of `model`
            # (torch's AveragedModel deep-copies it) -- confirm against .ema.
            # Copying the averaged weights back guarantees the caller's model
            # is updated either way; it is a harmless self-copy if the wrapper
            # shares the module.
            model.load_state_dict(ema_model.module.state_dict(), strict=True)
            print(f'\nLoading EMA module model from {ckpt_path} with {total_iter} iterations')
        else:
            model.load_state_dict(ema_state, strict=True)
            print(f'\nLoading EMA model from {ckpt_path} with {total_iter} iterations')
    else:
        model.load_state_dict(checkpoint['encoder'], strict=True)
        print(f'\nLoading model from {ckpt_path} with {total_iter} iterations')

    return total_iter