|
import torch |
|
import numpy as np |
|
import hashlib |
|
|
|
def _tensor_to_numpy(tensor, squeeze):
    """Convert a tensor to a NumPy array, optionally squeezing dimension 0 first."""
    if squeeze:
        tensor = tensor.squeeze(0)
    # .numpy() refuses tensors that require grad or that are not on the CPU,
    # so detach from the autograd graph and move to host memory first. For a
    # plain CPU tensor the result still shares storage with the original.
    return tensor.detach().cpu().numpy()


def wrap(func, *args, unsqueeze=False):
    """
    Wrap a torch function so it can be called with NumPy arrays.

    Input and return types are seamlessly converted: every positional
    argument that is a ``np.ndarray`` is turned into a tensor with
    ``torch.from_numpy`` before ``func`` is called, and every tensor in the
    result is turned back into a ``np.ndarray``.

    Args:
        func: Callable to invoke as ``func(*converted_args)``.
        *args: Positional arguments for ``func``; values that are not NumPy
            arrays are passed through unchanged.
        unsqueeze: If True, a leading dimension of size 1 is added to each
            converted array argument, and ``squeeze(0)`` is applied to each
            returned tensor before conversion.

    Returns:
        A ``np.ndarray`` if ``func`` returns a tensor; a plain tuple with each
        tensor element converted if ``func`` returns a tuple (non-tensor
        elements are kept as-is); otherwise ``func``'s result unchanged.
    """
    # Convert NumPy inputs to tensors; everything else is passed through.
    args = list(args)
    for i, arg in enumerate(args):
        if isinstance(arg, np.ndarray):
            tensor = torch.from_numpy(arg)
            args[i] = tensor.unsqueeze(0) if unsqueeze else tensor

    result = func(*args)

    # Convert tensor outputs back to NumPy.
    if isinstance(result, tuple):
        return tuple(
            _tensor_to_numpy(res, unsqueeze) if isinstance(res, torch.Tensor) else res
            for res in result
        )
    if isinstance(result, torch.Tensor):
        return _tensor_to_numpy(result, unsqueeze)
    return result
|
|
|
def deterministic_random(min_value, max_value, data):
    """Map the string ``data`` to a reproducible value derived from its SHA-256 hash.

    The first four bytes of ``sha256(data.encode())`` are read as an unsigned
    little-endian integer and divided by ``2**32 - 1``, giving a fraction in
    [0, 1]. That fraction scales the span ``max_value - min_value``, the
    product is truncated with ``int()``, and ``min_value`` is added back.

    For integer bounds with ``max_value >= min_value`` the result lies in
    [min_value, max_value]; ``max_value`` itself is reached only when all four
    bytes are 0xFF. The output depends solely on the arguments, so it is
    identical across processes and runs.
    """
    leading_bytes = hashlib.sha256(data.encode()).digest()[:4]
    fraction = int.from_bytes(leading_bytes, 'little') / 0xFFFFFFFF
    offset = int(fraction * (max_value - min_value))
    return min_value + offset
|
|
|
def load_pretrained_weights(model, checkpoint):
    """Load pretrained weights into a model, skipping incompatible entries.

    Checkpoint entries that are unmatched in name or size against
    ``model.state_dict()`` are ignored; every model parameter/buffer without
    a compatible checkpoint entry keeps its current value.

    Args:
        model (nn.Module): network model, which must not be nn.DataParallel.
            A leading ``'module.'`` is stripped from checkpoint keys before
            matching, so the model's own keys must not carry that prefix.
        checkpoint (dict): either a state dict, or a dict holding one under
            the ``'state_dict'`` key.

    Returns:
        The same ``model`` object, updated in place.

    Warns:
        UserWarning: if no checkpoint entry matched the model, or (when some
        did) listing the checkpoint entries that were discarded.
    """
    import collections
    import warnings

    if 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint

    model_dict = model.state_dict()
    new_state_dict = collections.OrderedDict()
    matched_layers, discarded_layers = [], []

    prefix = 'module.'
    for k, v in state_dict.items():
        # Weights saved from a wrapped model (e.g. nn.DataParallel) carry a
        # 'module.' prefix; strip it so names line up with the bare model.
        if k.startswith(prefix):
            k = k[len(prefix):]
        if k in model_dict and model_dict[k].size() == v.size():
            new_state_dict[k] = v
            matched_layers.append(k)
        else:
            discarded_layers.append(k)

    # Start from the model's own state so unmatched entries keep their values
    # and load_state_dict (strict by default) receives every expected key.
    model_dict.update(new_state_dict)
    model.load_state_dict(model_dict)

    print('load_weight', len(matched_layers))

    # Surface mismatches instead of silently dropping them.
    if not matched_layers:
        warnings.warn(
            'No pretrained weights were loaded: no checkpoint entry matched '
            'the model in name and size'
        )
    elif discarded_layers:
        warnings.warn(
            'Discarded {} checkpoint entries unmatched in name or size: {}'.format(
                len(discarded_layers), discarded_layers
            )
        )

    return model
|
|