|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import nnunet |
|
import torch |
|
from batchgenerators.utilities.file_and_folder_operations import * |
|
import importlib |
|
import pkgutil |
|
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer |
|
|
|
|
|
def recursive_find_python_class(folder, trainer_name, current_module):
    """
    Look up an attribute named ``trainer_name`` somewhere in the package tree rooted at ``folder``.

    Plain (non-package) modules directly inside ``folder`` are imported and inspected first, in the order
    ``pkgutil.iter_modules`` yields them. Scanning of plain modules stops at the first module that has the
    attribute. If that attribute is not None it is returned. Otherwise (no module has it, or its value is
    None) the sub-packages are searched depth-first in the same order, and the first non-None hit is
    returned.

    :param folder: list of directories as accepted by ``pkgutil.iter_modules``; sub-package paths are built
        from ``folder[0]`` only
    :param trainer_name: name of the attribute (typically a trainer class) to find
    :param current_module: dotted import path that corresponds to ``folder``; used to import found modules
    :return: the attribute found, or None if nothing in the tree provides it
    """
    # Pass 1: plain modules at this level.
    for info in pkgutil.iter_modules(folder):
        if info.ispkg:
            continue
        module = importlib.import_module(".".join((current_module, info.name)))
        if not hasattr(module, trainer_name):
            continue
        candidate = getattr(module, trainer_name)
        if candidate is not None:
            return candidate
        # The attribute exists but is None: stop scanning plain modules and fall through to sub-packages.
        break

    # Pass 2: descend into sub-packages, returning the first non-None result.
    for info in pkgutil.iter_modules(folder):
        if not info.ispkg:
            continue
        candidate = recursive_find_python_class([join(folder[0], info.name)], trainer_name,
                                                current_module=".".join((current_module, info.name)))
        if candidate is not None:
            return candidate

    return None
|
|
|
|
|
def restore_model(pkl_file, checkpoint=None, train=False, fp16=None):
    """
    This is a utility function to load any nnUNet trainer from a pkl. It will recursively search
    nnunet.training.network_training for the file that contains the trainer class and instantiate it with the
    init arguments saved in the pkl file. If checkpoint is specified, it will furthermore load the checkpoint file
    in train/test mode (as specified by train).
    The pkl file required here is the one that will be saved automatically when calling
    nnUNetTrainer.save_checkpoint.

    :param pkl_file: path to the pickled info file. Its content must provide the keys 'init' (positional args for
        the trainer constructor), 'name' (trainer class name) and 'plans' (passed to trainer.process_plans).
    :param checkpoint: optional checkpoint path forwarded to trainer.load_checkpoint; if None no weights are loaded
    :param train: forwarded as the second argument of trainer.load_checkpoint
    :param fp16: if None then we take no action. If True/False we overwrite what the model has in its init
    :return: the instantiated trainer, after process_plans (and load_checkpoint, if checkpoint was given)
    :raises RuntimeError: if no attribute named info['name'] is found in the searched packages
    :raises AssertionError: if the object found is not a subclass of nnUNetTrainer
    """
    # NOTE(security): load_pickle unpickles pkl_file, and unpickling can execute arbitrary code.
    # Only restore models from trusted sources.
    info = load_pickle(pkl_file)
    init = info['init']
    name = info['name']

    search_in = join(nnunet.__path__[0], "training", "network_training")
    tr = recursive_find_python_class([search_in], name, current_module="nnunet.training.network_training")

    if tr is None:
        # Fabian only. This will trigger searching for trainer classes in other repositories as well.
        try:
            import meddec
            search_in = join(meddec.__path__[0], "model_training")
            tr = recursive_find_python_class([search_in], name, current_module="meddec.model_training")
        except ImportError:
            # meddec is optional; if it is unavailable we fall through to the error below.
            pass

    if tr is None:
        raise RuntimeError("Could not find the model trainer specified in checkpoint in nnunet.training.network_training. If it "
                           "is not located there, please move it or change the code of restore_model. Your model "
                           "trainer can be located in any directory within nnunet.training.network_training (search is recursive)."
                           "\nDebug info: \ncheckpoint file: %s\npkl file: %s\nName of trainer: %s " % (checkpoint, pkl_file, name))
    # Check isinstance(tr, type) first so a non-class attribute with the same name yields this clear message
    # instead of issubclass's TypeError.
    assert isinstance(tr, type) and issubclass(tr, nnUNetTrainer), \
        "The network trainer was found but is not a subclass of nnUNetTrainer. Please make it so!"

    trainer = tr(*init)

    # Optionally override the mixed-precision flag set by the trainer's __init__.
    if fp16 is not None:
        trainer.fp16 = fp16

    trainer.process_plans(info['plans'])
    if checkpoint is not None:
        trainer.load_checkpoint(checkpoint, train)
    return trainer
|
|
|
|
|
def load_best_model_for_inference(folder):
    """
    Restore the trainer saved in ``folder`` and load its ``model_best.model`` weights with train=False.

    :param folder: output folder containing ``model_best.model`` and ``model_best.model.pkl``
    :return: the trainer produced by restore_model
    """
    best_checkpoint = join(folder, "model_best.model")
    return restore_model(best_checkpoint + ".pkl", checkpoint=best_checkpoint, train=False)
|
|
|
|
|
def load_model_and_checkpoint_files(folder, folds=None, mixed_precision=None, checkpoint_name="model_best"):
    """
    used for if you need to ensemble the five models of a cross-validation. This will restore the trainer from the
    first selected fold folder, load the checkpoints of all selected folds into RAM (on the CPU) and return both.
    This will allow for fast switching between parameters (as opposed to loading them from disk each time).

    This is best used for inference and test prediction

    :param folder: training output folder that contains the ``fold_X`` / ``all`` subfolders
    :param folds: which folds to load:
        - str: any string selects the ``all`` folder
        - list/tuple: ``["all"]`` selects ``all``, otherwise a list of int fold indices (``fold_%d``)
        - int: a single fold index
        - None: every subfolder of ``folder`` whose name starts with "fold"
    :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init
    :param checkpoint_name: base name of the checkpoint files (``<name>.model`` and ``<name>.model.pkl``)
    :return: tuple (trainer, all_params) where all_params holds one torch.load result per selected fold folder
    :raises AssertionError: if a selected fold folder does not exist or no fold folder was selected/found
    :raises ValueError: if folds has an unsupported type
    """
    if isinstance(folds, str):
        folds = [join(folder, "all")]
        assert isdir(folds[0]), "no output folder for fold %s found" % folds[0]
    elif isinstance(folds, (list, tuple)):
        if len(folds) == 1 and folds[0] == "all":
            folds = [join(folder, "all")]
        else:
            folds = [join(folder, "fold_%d" % i) for i in folds]
        missing = [i for i in folds if not isdir(i)]
        assert not missing, "list of folds specified but not all output folders are present. Missing: %s" % missing
    elif isinstance(folds, int):
        # Format the message with the int fold index *before* rebinding `folds` to a list; the original
        # "%d" % <list> raised TypeError instead of the intended assertion.
        fold_folder = join(folder, "fold_%d" % folds)
        assert isdir(fold_folder), "output folder missing for fold %d" % folds
        folds = [fold_folder]
    elif folds is None:
        print("folds is None so we will automatically look for output folders (not using \'all\'!)")
        folds = subfolders(folder, prefix="fold")
        print("found the following folds: ", folds)
    else:
        raise ValueError("Unknown value for folds. Type: %s. Expected: list of int, int, str or None" % str(type(folds)))

    # Guard against an empty selection (e.g. folds=[] or no fold* subfolders) instead of a bare IndexError below.
    assert len(folds) > 0, "no fold output folders selected or found in %s" % folder

    trainer = restore_model(join(folds[0], "%s.model.pkl" % checkpoint_name), fp16=mixed_precision)
    trainer.output_folder = folder
    trainer.output_folder_base = folder
    trainer.update_fold(0)
    trainer.initialize(False)

    all_best_model_files = [join(i, "%s.model" % checkpoint_name) for i in folds]
    print("using the following model files: ", all_best_model_files)
    # Load everything onto the CPU so switching between folds does not touch the disk again.
    # NOTE(review): recent torch versions change the default of torch.load's weights_only argument;
    # confirm these checkpoints still load with the installed torch version.
    all_params = [torch.load(i, map_location=torch.device('cpu')) for i in all_best_model_files]
    return trainer, all_params
|
|
|
|
|
if __name__ == "__main__":
    # Developer smoke test: restore a trainer and its best checkpoint in inference mode (train=False).
    pkl_path = "/home/fabian/PhD/results/nnUNetV2/nnUNetV2_3D_fullres/Task004_Hippocampus/fold0/model_best.model.pkl"
    # The checkpoint path is the pkl path without its trailing ".pkl".
    trainer = restore_model(pkl_path, checkpoint=pkl_path[:-len(".pkl")], train=False)
|
|