import collections
import hashlib

import numpy as np
import torch


def wrap(func, *args, unsqueeze=False):
    """
    Wrap a torch function so it can be called with NumPy arrays.
    Input and return types are seamlessly converted.

    If unsqueeze is True, a leading batch dimension is added to each array
    input and removed again from each tensor output.
    """
    # Convert input types where applicable
    args = list(args)
    for i, arg in enumerate(args):
        if isinstance(arg, np.ndarray):
            args[i] = torch.from_numpy(arg)
            if unsqueeze:
                args[i] = args[i].unsqueeze(0)

    result = func(*args)

    # Convert output types where applicable
    if isinstance(result, tuple):
        result = list(result)
        for i, res in enumerate(result):
            if isinstance(res, torch.Tensor):
                if unsqueeze:
                    res = res.squeeze(0)
                result[i] = res.numpy()
        return tuple(result)
    elif isinstance(result, torch.Tensor):
        if unsqueeze:
            result = result.squeeze(0)
        return result.numpy()
    else:
        return result


def deterministic_random(min_value, max_value, data):
    """Map a string to a reproducible integer in [min_value, max_value] by
    hashing it with SHA-256 and rescaling the first four digest bytes."""
    digest = hashlib.sha256(data.encode()).digest()
    raw_value = int.from_bytes(digest[:4], byteorder='little', signed=False)
    return int(raw_value / (2 ** 32 - 1) * (max_value - min_value)) + min_value


def load_pretrained_weights(model, checkpoint):
    """Load pretrained weights into a model.

    Incompatible layers (unmatched in name or size) are ignored.

    Args:
    - model (nn.Module): network model, which must not be nn.DataParallel
    - checkpoint (dict): a loaded checkpoint, or a raw state_dict
    """
    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model_dict = model.state_dict()

    new_state_dict = collections.OrderedDict()
    matched_layers, discarded_layers = [], []
    for k, v in state_dict.items():
        # If the pretrained state_dict was saved by nn.DataParallel,
        # keys carry a "module." prefix, which must be stripped.
        if k.startswith('module.'):
            k = k[7:]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)
    print('load_pretrained_weights: matched {} layers, discarded {}'.format(
        len(matched_layers), len(discarded_layers)))
    return model
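

# ---------------------------------------------------------------------------
# Minimal usage sketch for wrap() (illustrative, not part of the original
# module). torch.topk is a stock PyTorch function; the scores array and the
# expected outputs in the comments are invented for the demo.
if __name__ == '__main__':
    # topk normally takes and returns tensors; wrap() lets us pass a NumPy
    # array and get NumPy arrays back. With unsqueeze=True, a batch dimension
    # is added on the way in and stripped from each output on the way out.
    scores = np.array([0.1, 0.9, 0.4], dtype=np.float32)
    values, indices = wrap(torch.topk, scores, 2, unsqueeze=True)
    print('top-2 values:', values)    # [0.9 0.4]
    print('top-2 indices:', indices)  # [1 2]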
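

# ---------------------------------------------------------------------------
# Minimal usage sketch for deterministic_random() (illustrative, not part of
# the original module). The 'S1/Walking' key below is an invented example
# string, not a reference to any particular dataset.
if __name__ == '__main__':
    # Hashing makes the result a pure function of the input string, so the
    # same string yields the same "random" value across runs and machines,
    # independent of any global RNG state.
    a = deterministic_random(0, 100, 'S1/Walking')
    b = deterministic_random(0, 100, 'S1/Walking')
    assert a == b and 0 <= a <= 100
    print('deterministic value for S1/Walking:', a)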
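

# ---------------------------------------------------------------------------
# Minimal usage sketch for load_pretrained_weights() (illustrative, not part
# of the original module). A tiny nn.Linear stands in for a real network, and
# the checkpoint is built in memory instead of being loaded with torch.load().
if __name__ == '__main__':
    import torch.nn as nn

    src = nn.Linear(4, 2)
    dst = nn.Linear(4, 2)
    # Simulate a checkpoint saved from nn.DataParallel: every key carries a
    # 'module.' prefix, which load_pretrained_weights() strips. The extra
    # mismatched entry is discarded rather than raising an error.
    checkpoint = {'state_dict': {'module.' + k: v
                                 for k, v in src.state_dict().items()}}
    checkpoint['state_dict']['module.unknown'] = torch.zeros(3)
    load_pretrained_weights(dst, checkpoint)
    assert torch.equal(dst.weight, src.weight)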