# Source: nnUNet_calvingfront_detection / nnunet/run/default_configuration.py
# Uploaded by ho11laqe, commit ecf08bc ("init"), 3.81 kB.
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import nnunet
from nnunet.paths import network_training_output_dir, preprocessing_output_dir, default_plans_identifier
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.experiment_planning.summarize_plans import summarize_plans
from nnunet.training.model_restore import recursive_find_python_class
def get_configuration_from_output_folder(folder, base_dir=None):
    """Split a training output folder path back into its configuration parts.

    The folder is expected to be laid out as
    ``<base_dir>/<configuration>/<task>/<trainer>__<plans_identifier>``, i.e. the
    same layout ``get_default_configuration`` builds for ``output_folder_name``.

    :param folder: path of the output folder. It must lie inside ``base_dir``;
        leading/trailing ``/`` around the relative part are tolerated.
    :param base_dir: root of the training output tree. Defaults to
        ``network_training_output_dir`` when ``None``.
    :return: tuple ``(configuration, task, trainer, plans_identifier)``
    :raises ValueError: if ``folder`` is not inside ``base_dir``, does not have
        exactly three path components below it, or its last component does not
        split into exactly two parts on ``"__"``.
    """
    if base_dir is None:
        base_dir = network_training_output_dir

    # Compare on a path-component boundary so that e.g. "/resX/..." is not
    # mistaken for a folder inside "/res".
    base = base_dir.rstrip("/")
    if folder != base and not folder.startswith(base + "/"):
        raise ValueError("%s is not located inside %s" % (folder, base_dir))

    # split off base_dir; strip separators on both ends so "a/b/c/" parses like "a/b/c"
    relative = folder[len(base):].strip("/")
    parts = relative.split("/")
    if len(parts) != 3:
        raise ValueError("Expected <configuration>/<task>/<trainer>__<plans_identifier> below %s, got %s"
                         % (base_dir, folder))
    configuration, task, trainer_and_plans_identifier = parts

    trainer, plans_identifier = trainer_and_plans_identifier.split("__")
    return configuration, task, trainer, plans_identifier
def get_default_configuration(network, task, network_trainer, plans_identifier=default_plans_identifier,
                              search_in=(nnunet.__path__[0], "training", "network_training"),
                              base_module='nnunet.training.network_training'):
    """Resolve the plans, folders, stage and trainer class for one nnU-Net configuration.

    Also prints a human-readable summary of the chosen configuration.

    :param network: one of '2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'
    :param task: name of the task folder inside ``preprocessing_output_dir``
    :param network_trainer: class name of the trainer to look up
    :param plans_identifier: prefix of the plans file name; '_plans_2D.pkl' (for '2d')
        or '_plans_3D.pkl' (all 3d configurations) is appended to it
    :param search_in: path components, joined with ``join``, of the directory that is
        searched for the trainer class
    :param base_module: module path passed as ``current_module`` to
        ``recursive_find_python_class``
    :return: tuple ``(plans_file, output_folder_name, dataset_directory, batch_dice,
        stage, trainer_class)``. ``trainer_class`` is returned unchecked.
        NOTE(review): presumably None when the class is not found; callers should
        verify it.
    :raises AssertionError: if ``network`` is not a supported configuration
    :raises RuntimeError: if '3d_lowres' or '3d_cascade_fullres' is requested but the
        plans contain only one stage
    """
    assert network in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'], \
        "network can only be one of the following: \'2d\', \'3d_lowres\', \'3d_fullres\', \'3d_cascade_fullres\'"

    dataset_directory = join(preprocessing_output_dir, task)
    # 2d has its own plans file; every 3d configuration reads the 3D plans
    if network == '2d':
        plans_file = join(dataset_directory, plans_identifier + "_plans_2D.pkl")
    else:
        plans_file = join(dataset_directory, plans_identifier + "_plans_3D.pkl")

    plans = load_pickle(plans_file)
    possible_stages = list(plans['plans_per_stage'].keys())

    if network in ('3d_cascade_fullres', '3d_lowres') and len(possible_stages) == 1:
        raise RuntimeError("3d_lowres/3d_cascade_fullres only applies if there is more than one stage. This task does "
                           "not require the cascade. Run 3d_fullres instead")

    # 2d and 3d_lowres use stage 0; 3d_fullres and the cascade use the last stage listed in the plans
    if network in ('2d', '3d_lowres'):
        stage = 0
    else:
        stage = possible_stages[-1]

    trainer_class = recursive_find_python_class([join(*search_in)], network_trainer,
                                                current_module=base_module)

    output_folder_name = join(network_training_output_dir, network, task, network_trainer + "__" + plans_identifier)

    print("###############################################")
    print("I am running the following nnUNet: %s" % network)
    print("My trainer class is: ", trainer_class)
    print("For that I will be using the following configuration:")
    summarize_plans(plans_file)
    print("I am using stage %d from these plans" % stage)

    # batch dice for 2d and for multi-stage 3d configurations, but never for 3d_lowres
    if (network == '2d' or len(possible_stages) > 1) and network != '3d_lowres':
        batch_dice = True
        print("I am using batch dice + CE loss")
    else:
        batch_dice = False
        print("I am using sample dice + CE loss")

    print("\nI am using data from this folder: ", join(dataset_directory, plans['data_identifier']))
    print("###############################################")
    return plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class