""" Checkpoint loading / state_dict helpers

Copyright 2020 Ross Wightman
"""
import os
from collections import OrderedDict

import torch

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    # Fallback for torch versions where torch.hub does not provide this helper.
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


def load_checkpoint(model, checkpoint_path):
    """Load weights from a checkpoint file on disk into ``model``.

    The file may contain either a bare state dict, or a dict holding the
    state dict under the ``'state_dict'`` key. In the latter case a leading
    ``module.`` on any key is removed before loading.

    Args:
        model: Object exposing ``load_state_dict``; it is updated in place.
        checkpoint_path: Path of the checkpoint file.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` is empty or is not a file.
    """
    if not checkpoint_path or not os.path.isfile(checkpoint_path):
        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError("No checkpoint found at '{}'".format(checkpoint_path))

    print("=> Loading checkpoint '{}'".format(checkpoint_path))
    # Map storages to CPU so a checkpoint saved on an accelerator can still be
    # read on a host without one; load_state_dict copies the values into the
    # model's own parameters afterwards.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        # Match the full 'module.' prefix (with the dot) so that keys which
        # merely begin with the letters 'module' (e.g. 'module_a.weight')
        # are left untouched.
        prefix = 'module.'
        new_state_dict = OrderedDict()
        for k, v in checkpoint['state_dict'].items():
            name = k[len(prefix):] if k.startswith(prefix) else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(checkpoint)
    print("=> Loaded checkpoint '{}'".format(checkpoint_path))


def load_pretrained(model, url, filter_fn=None, strict=True):
    """Download a pretrained state dict from ``url`` and load it into ``model``.

    The model is read through its ``conv_stem`` and ``classifier`` attributes
    (each must have a ``weight``). When the model's input channel count or
    class count differs from the pretrained weights, the affected entries are
    adapted or dropped:

    * input conv, model has 1 input channel: the pretrained weight is summed
      over its input-channel dimension;
    * input conv, any other mismatch: the weight is discarded and loading
      becomes non-strict;
    * classifier mismatch: classifier weight (and bias, if present) are
      discarded and loading becomes non-strict.

    Args:
        model: Model to load the weights into (updated in place).
        url: Location of the pretrained weights. If empty, a warning is
            printed and the model is left untouched.
        filter_fn: Optional callable applied to the state dict (after the
            adaptations above); its return value is what gets loaded.
        strict: Passed to ``model.load_state_dict`` unless entries had to be
            discarded, in which case ``False`` is used.
    """
    if not url:
        print("=> Warning: Pretrained model URL is empty, using random initialization.")
        return

    state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu')

    input_conv = 'conv_stem'
    classifier = 'classifier'
    in_chans = getattr(model, input_conv).weight.shape[1]
    num_classes = getattr(model, classifier).weight.shape[0]

    input_conv_weight = input_conv + '.weight'
    pretrained_in_chans = state_dict[input_conv_weight].shape[1]
    if in_chans != pretrained_in_chans:
        if in_chans == 1:
            print('=> Converting pretrained input conv {} from {} to 1 channel'.format(
                input_conv_weight, pretrained_in_chans))
            conv1_weight = state_dict[input_conv_weight]
            state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True)
        else:
            print('=> Discarding pretrained input conv {} since input channel count != {}'.format(
                input_conv_weight, pretrained_in_chans))
            del state_dict[input_conv_weight]
            # The key is now missing, so a strict load would fail.
            strict = False

    classifier_weight = classifier + '.weight'
    pretrained_num_classes = state_dict[classifier_weight].shape[0]
    if num_classes != pretrained_num_classes:
        print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes))
        del state_dict[classifier_weight]
        # The pretrained classifier may have no bias entry; drop it only if present.
        state_dict.pop(classifier + '.bias', None)
        strict = False

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    model.load_state_dict(state_dict, strict=strict)