import logging
import os

import torch


def load_checkpoint(checkpoint_path, model):
    """Load weights from a checkpoint into `model`, keeping the model's own
    values for any parameters missing from the checkpoint."""
    assert os.path.isfile(checkpoint_path), f"No checkpoint found at '{checkpoint_path}'"
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    saved_state_dict = checkpoint_dict['model']

    # Unwrap DataParallel/DistributedDataParallel, which store the real
    # model under `.module`.
    if hasattr(model, 'module'):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()

    # Copy matching parameters from the checkpoint; fall back to the model's
    # current (e.g. freshly initialized) values for keys the checkpoint lacks.
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            logging.info("%s is not in the checkpoint", k)
            new_state_dict[k] = v

    if hasattr(model, 'module'):
        model.module.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(new_state_dict)
    logging.info("Loaded checkpoint '%s'", checkpoint_path)
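

if __name__ == '__main__':
    # Minimal self-contained usage sketch: save a checkpoint in the
    # {'model': state_dict} format this loader expects, then restore it.
    # The nn.Linear stand-in model and the 'demo_checkpoint.pt' filename
    # are illustrative assumptions, not part of the original code.
    logging.basicConfig(level=logging.INFO)
    model = torch.nn.Linear(4, 2)
    torch.save({'model': model.state_dict()}, 'demo_checkpoint.pt')
    load_checkpoint('demo_checkpoint.pt', model)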