# MotionGPT / mGPT / utils / load_checkpoint.py
# File-viewer header residue: "bill-jiang's picture", commit message "Init",
# revision 4409449, file size 1.18 kB.
import torch
def load_pretrained(cfg, model, logger, phase="train"):
    """Load a full-model checkpoint into ``model`` and return the model.

    Args:
        cfg: Config object. ``cfg.TRAIN.PRETRAINED`` is read when ``phase``
            is ``"train"``; ``cfg.TEST.CHECKPOINTS`` when it is ``"test"``.
        model: Object exposing ``load_state_dict``; it is updated in place.
        logger: Object exposing ``info``; receives one message naming the
            checkpoint path that is actually loaded.
        phase: Either ``"train"`` or ``"test"``.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        ValueError: If ``phase`` is neither ``"train"`` nor ``"test"``.
    """
    if phase == "train":
        ckpt_path = cfg.TRAIN.PRETRAINED
    elif phase == "test":
        ckpt_path = cfg.TEST.CHECKPOINTS
    else:
        # Fail early with a clear message instead of hitting an unbound
        # ``ckpt_path`` further down.
        raise ValueError(
            f"Unknown phase {phase!r}; expected 'train' or 'test'")

    # Log the path selected above so the message is correct in both phases.
    logger.info(f"Loading pretrain model from {ckpt_path}")

    # The checkpoint file is expected to hold a mapping with a "state_dict"
    # entry; tensors are mapped to CPU while loading.
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    # strict=True: the checkpoint keys must match the model's keys.
    model.load_state_dict(state_dict, strict=True)
    return model
def load_pretrained_vae(cfg, model, logger):
    """Load only the VAE weights from a checkpoint into ``model``.

    The checkpoint at ``cfg.TRAIN.PRETRAINED_VAE`` is read, and every entry
    of its ``'state_dict'`` whose key contains ``"motion_vae"`` or ``"vae"``
    is kept, with the ``"motion_vae."`` / ``"vae."`` text removed from the
    key. The result is loaded (strictly) into ``model.vae`` if that
    attribute exists, otherwise into ``model.motion_vae``.

    Args:
        cfg: Config object providing ``cfg.TRAIN.PRETRAINED_VAE``.
        model: Object with a ``vae`` or ``motion_vae`` attribute that
            exposes ``load_state_dict``.
        logger: Object exposing ``info``.

    Returns:
        The same ``model`` object that was passed in.
    """
    from collections import OrderedDict

    checkpoint = torch.load(cfg.TRAIN.PRETRAINED_VAE, map_location="cpu")
    weights = checkpoint['state_dict']
    logger.info(f"Loading pretrain vae from {cfg.TRAIN.PRETRAINED_VAE}")

    # Collect the VAE entries under their shortened names.
    vae_weights = OrderedDict()
    for key, tensor in weights.items():
        # "motion_vae" must be tested first: such keys also contain "vae".
        if "motion_vae" in key:
            marker = "motion_vae."
        elif "vae" in key:
            marker = "vae."
        else:
            continue  # not a VAE entry
        vae_weights[key.replace(marker, "")] = tensor

    target = model.vae if hasattr(model, 'vae') else model.motion_vae
    target.load_state_dict(vae_weights, strict=True)
    return model