File size: 11,803 Bytes
ecf08bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from batchgenerators.utilities.file_and_folder_operations import *
from nnunet.run.default_configuration import get_default_configuration
from nnunet.paths import default_plans_identifier
from nnunet.run.load_pretrained_weights import load_pretrained_weights
from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
from nnunet.training.network_training.nnUNetTrainerCascadeFullRes import nnUNetTrainerCascadeFullRes
from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
def _build_argument_parser():
    """Create the command line parser for nnU-Net training / validation.

    :return: an ``argparse.ArgumentParser`` with the positional arguments ``network``, ``network_trainer``,
        ``task`` and ``fold`` plus all optional training and validation flags.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("network")
    parser.add_argument("network_trainer")
    parser.add_argument("task", help="can be task name or task id")
    parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
    parser.add_argument("-val", "--validation_only", help="use this if you want to only run the validation",
                        action="store_true")
    parser.add_argument("-c", "--continue_training", help="use this if you want to continue a training",
                        action="store_true")
    parser.add_argument("-p", help="plans identifier. Only change this if you created a custom experiment planner",
                        default=default_plans_identifier, required=False)
    parser.add_argument("--use_compressed_data", default=False, action="store_true",
                        help="If you set use_compressed_data, the training cases will not be decompressed. Reading "
                             "compressed data is much more CPU and RAM intensive and should only be used if you know "
                             "what you are doing", required=False)
    parser.add_argument("--deterministic",
                        help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
                             "this is not necessary. Deterministic training will make you overfit to some random seed. "
                             "Don't use that.",
                        required=False, default=False, action="store_true")
    parser.add_argument("--local_rank", default=0, type=int)
    parser.add_argument("--fp32", required=False, default=False, action="store_true",
                        help="disable mixed precision training and run old school fp32")
    parser.add_argument("--dbs", required=False, default=False, action="store_true",
                        help="distribute batch size. If True then whatever batch_size is in plans will be distributed "
                             "over DDP models, if False then each model will have batch_size for a total of "
                             "GPUs*batch_size")
    parser.add_argument("--npz", required=False, default=False, action="store_true",
                        help="if set then nnUNet will export npz files of predicted segmentations in the validation "
                             "as well. This is needed to run the ensembling step so unless you are developing nnUNet "
                             "you should enable this")
    parser.add_argument("--valbest", required=False, default=False, action="store_true",
                        help="only used together with -val: validate the best checkpoint instead of the final one")
    parser.add_argument("--find_lr", required=False, default=False, action="store_true",
                        help="run the trainer's learning rate finder instead of training/validation")
    parser.add_argument("--val_folder", required=False, default="validation_raw",
                        help="name of the validation folder. No need to use this for most people")
    parser.add_argument("--disable_saving", required=False, action='store_true',
                        help="If set nnU-Net will not save any parameter files. Useful for development when you are "
                             "only interested in the results and want to save some disk space")
    parser.add_argument("--disable_postprocessing_on_folds", required=False, action='store_true',
                        help="Running postprocessing on each fold only makes sense when developing with nnU-Net and "
                             "closely observing the model performance on specific configurations. You do not need it "
                             "when applying nnU-Net because the postprocessing for this will be determined only once "
                             "all five folds have been trained and nnUNet_find_best_configuration is called. Usually "
                             "running postprocessing on each fold is computationally cheap, but some users have "
                             "reported issues with very large images. If your images are large (>600x600x600 voxels) "
                             "you should consider setting this flag.")
    parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
                        help='path to nnU-Net checkpoint file to be used as pretrained model (use .model '
                             'file, for example model_final_checkpoint.model). Will only be used when actually training. '
                             'Optional. Beta. Use with caution.')
    return parser


def _parse_fold(fold):
    """Convert the raw ``fold`` command line value into what the trainer expects.

    :param fold: fold string from the command line, e.g. ``"0"`` or ``"all"``.
    :return: the string ``'all'`` unchanged, otherwise ``int(fold)``.
    :raises ValueError: if ``fold`` is neither ``'all'`` nor an integer literal.
    """
    if fold == 'all':
        return fold
    return int(fold)


def main():
    """Train and/or validate one nnU-Net configuration on one fold, as requested on the command line.

    Depending on the flags this either runs the trainer's learning rate finder (``--find_lr``), trains a network
    (resuming with ``-c`` or starting from ``-pretrained_weights``) and then validates it, or only validates an
    existing checkpoint (``-val``). For the ``3d_lowres`` configuration the segmentations for the next stage of the
    cascade are predicted after validation.

    :raises RuntimeError: if the trainer class cannot be found or is not derived from the required base class.
    :raises ValueError: if ``task`` is neither a task name starting with "Task" nor an integer id, or if ``fold``
        is neither ``'all'`` nor an integer.
    """
    args = _build_argument_parser().parse_args()

    task = args.task
    network = args.network
    network_trainer = args.network_trainer
    validation_only = args.validation_only
    plans_identifier = args.p
    # training cases are unpacked unless the user explicitly asked to work on the compressed data
    decompress_data = not args.use_compressed_data
    deterministic = args.deterministic
    valbest = args.valbest
    find_lr = args.find_lr
    val_folder = args.val_folder
    fp32 = args.fp32
    disable_postprocessing_on_folds = args.disable_postprocessing_on_folds

    # the task may be given either as its full name ("Task...") or as its numeric id
    if not task.startswith("Task"):
        task = convert_id_to_task_name(int(task))
    fold = _parse_fold(args.fold)

    plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
        trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)

    if trainer_class is None:
        raise RuntimeError("Could not find trainer class %s in nnunet.training.network_training" % network_trainer)

    # explicit checks rather than `assert`, so they still run when Python is started with -O
    if network == "3d_cascade_fullres":
        if not issubclass(trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)):
            raise RuntimeError("If running 3d_cascade_fullres then your trainer class must be derived from "
                               "nnUNetTrainerCascadeFullRes or nnUNetTrainerV2CascadeFullRes")
    elif not issubclass(trainer_class, nnUNetTrainer):
        raise RuntimeError("network_trainer was found but is not derived from nnUNetTrainer")

    trainer = trainer_class(plans_file, fold, local_rank=args.local_rank, output_folder=output_folder_name,
                            dataset_directory=dataset_directory, batch_dice=batch_dice, stage=stage,
                            unpack_data=decompress_data, deterministic=deterministic, fp16=not fp32,
                            distribute_batch_size=args.dbs)

    if args.disable_saving:
        # switch off all checkpoint-writing flags on the trainer
        trainer.save_latest_only = False  # if false it will not store/overwrite _latest but separate files each
        trainer.save_intermediate_checkpoints = False  # whether or not to save checkpoint_latest
        trainer.save_best_checkpoint = False  # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
        trainer.save_final_checkpoint = False  # whether or not to save the final checkpoint

    # the flag passed here is True only when we are going to train
    trainer.initialize(not validation_only)

    if find_lr:
        trainer.find_lr()
        return

    if validation_only:
        # evaluate an already trained checkpoint
        if valbest:
            trainer.load_best_checkpoint(train=False)
        else:
            trainer.load_final_checkpoint(train=False)
    else:
        if args.continue_training:
            # -c was set: continue a previous training and ignore pretrained weights
            trainer.load_latest_checkpoint()
        elif args.pretrained_weights is not None:
            # new training initialised from the given pretrained weights
            load_pretrained_weights(trainer.network, args.pretrained_weights)
        # otherwise: new training from scratch, nothing to load
        trainer.run_training()

    trainer.network.eval()

    # predict the validation cases of this fold
    trainer.validate(save_softmax=args.npz, validation_folder_name=val_folder,
                     run_postprocessing_on_folds=not disable_postprocessing_on_folds)

    if network == '3d_lowres':
        print("predicting segmentations for the next stage of the cascade")
        predict_next_stage(trainer, join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1))
if "__main__" == __name__:
    # Start a training/validation run only when this file is executed as a script, not when it is imported.
    main()
|