|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
def load_pretrained_weights(network, fname, verbose=False):
    """
    Copy matching weights from a checkpoint file into ``network``.

    THIS DOES NOT TRANSFER SEGMENTATION HEADS!
    (More precisely: only entries whose name AND shape match the network are
    copied; everything else in the network keeps its current values. Presumably
    segmentation heads are what differs between tasks -- confirm against the
    architectures in use.)

    The checkpoint must be a mapping with a ``'state_dict'`` entry. A leading
    ``'module.'`` on checkpoint keys is removed before names are compared.

    Args:
        network: module exposing ``state_dict()`` and ``load_state_dict()``.
        fname: path (or file-like object) handed to ``torch.load``.
        verbose: if True, print the name of every transferred entry.

    Raises:
        RuntimeError: if any network entry whose name contains ``'conv_blocks'``
            is missing from the checkpoint or has a different shape there.
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
    """
    parallel_prefix = 'module.'

    # Deserialize onto the CPU so that a checkpoint written on a GPU can be read
    # on a machine without that device. load_state_dict() below copies the values
    # into the network's existing tensors, so the network's device is unaffected.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    saved_model = torch.load(fname, map_location='cpu')

    # Strip the wrapper prefix so keys line up with an unwrapped network.
    pretrained_dict = {
        (name[len(parallel_prefix):] if name.startswith(parallel_prefix) else name): tensor
        for name, tensor in saved_model['state_dict'].items()
    }

    model_dict = network.state_dict()

    def _matches(name):
        # True when the checkpoint has this entry with the same shape as the network.
        return name in pretrained_dict and model_dict[name].shape == pretrained_dict[name].shape

    # Every 'conv_blocks' entry of the network must be available in the checkpoint;
    # other entries are allowed to be missing or mismatched.
    compatible = all(_matches(name) for name in model_dict if 'conv_blocks' in name)
    if not compatible:
        raise RuntimeError("Pretrained weights are not compatible with the current network architecture")

    # Keep only the checkpoint entries the network can actually accept.
    pretrained_dict = {
        name: tensor
        for name, tensor in pretrained_dict.items()
        if name in model_dict and model_dict[name].shape == tensor.shape
    }

    # Overwrite the matching entries, leaving all others at their current values.
    model_dict.update(pretrained_dict)
    print("################### Loading pretrained weights from file ", fname, '###################')
    if verbose:
        print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:")
        for name in pretrained_dict:
            print(name)
    print("################### Done ###################")
    network.load_state_dict(model_dict)
|
|
|
|