import os
import sys
import time
import random
import argparse
import logging
from collections import OrderedDict, defaultdict

import torch
import torch.utils.model_zoo as model_zoo

logger = logging.getLogger(__name__)

# Download URLs for pretrained ResNet weights hosted on download.pytorch.org.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def load_model(model, model_file, is_restore=False, map_location=None):
    """Load parameters into ``model`` non-strictly and return it.

    Args:
        model: The ``torch.nn.Module`` to load weights into (modified in place).
        model_file: ``None`` (no-op), a checkpoint path (``str`` or
            ``os.PathLike``) read with ``torch.load``, or an already-loaded
            state-dict mapping. If a loaded checkpoint contains a ``'model'``
            key, the value under that key is used as the state dict.
        is_restore: If true, every checkpoint key is prefixed with
            ``'module.'`` before loading, e.g. to load a plain checkpoint into
            a model whose parameters live under a ``module`` attribute.
            Keys that already carry the prefix are prefixed again.
        map_location: Forwarded to ``torch.load`` when ``model_file`` is a
            path (e.g. ``'cpu'``). ``None`` keeps ``torch.load``'s default.

    Returns:
        The same ``model`` object.

    Loading uses ``strict=False``. Checkpoint keys the model lacks, and model
    keys absent from the checkpoint, are logged as warnings instead of
    raising.
    """
    t_start = time.time()
    if model_file is None:
        return model

    if isinstance(model_file, (str, os.PathLike)):
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(model_file, map_location=map_location)
        if 'model' in state_dict:
            state_dict = state_dict['model']
    else:
        state_dict = model_file
    t_ioend = time.time()

    if is_restore:
        state_dict = OrderedDict(
            ('module.' + key, value) for key, value in state_dict.items()
        )

    model.load_state_dict(state_dict, strict=False)

    # strict=False silently tolerates mismatches, so surface them here.
    ckpt_keys = set(state_dict)
    own_keys = set(model.state_dict())
    missing_keys = own_keys - ckpt_keys
    unexpected_keys = ckpt_keys - own_keys
    if missing_keys:
        logger.warning('Missing key(s) in state_dict: %s',
                       ', '.join(sorted(missing_keys)))
    if unexpected_keys:
        logger.warning('Unexpected key(s) in state_dict: %s',
                       ', '.join(sorted(unexpected_keys)))

    # Drop the local reference to the checkpoint tensors before returning.
    del state_dict
    t_end = time.time()
    logger.debug('Load model, time usage: IO %.2fs, initialize parameters %.2fs',
                 t_ioend - t_start, t_end - t_ioend)
    return model