|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
from batchgenerators.utilities.file_and_folder_operations import * |
|
from nnunet.run.default_configuration import get_default_configuration |
|
from nnunet.paths import default_plans_identifier |
|
from nnunet.run.load_pretrained_weights import load_pretrained_weights |
|
from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage |
|
from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer |
|
from nnunet.training.network_training.nnUNetTrainerCascadeFullRes import nnUNetTrainerCascadeFullRes |
|
from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name |
|
|
|
|
|
def main(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("network") |
|
parser.add_argument("network_trainer") |
|
parser.add_argument("task", help="can be task name or task id") |
|
parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'') |
|
parser.add_argument("-val", "--validation_only", help="use this if you want to only run the validation", |
|
action="store_true") |
|
parser.add_argument("-c", "--continue_training", help="use this if you want to continue a training", |
|
action="store_true") |
|
parser.add_argument("-p", help="plans identifier. Only change this if you created a custom experiment planner", |
|
default=default_plans_identifier, required=False) |
|
parser.add_argument("--use_compressed_data", default=False, action="store_true", |
|
help="If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data " |
|
"is much more CPU and RAM intensive and should only be used if you know what you are " |
|
"doing", required=False) |
|
parser.add_argument("--deterministic", |
|
help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think " |
|
"this is not necessary. Deterministic training will make you overfit to some random seed. " |
|
"Don't use that.", |
|
required=False, default=False, action="store_true") |
|
parser.add_argument("-gpus", help="number of gpus", required=True,type=int) |
|
parser.add_argument("--dbs", required=False, default=False, action="store_true", help="distribute batch size. If " |
|
"True then whatever " |
|
"batch_size is in plans will " |
|
"be distributed over DDP " |
|
"models, if False then each " |
|
"model will have batch_size " |
|
"for a total of " |
|
"GPUs*batch_size") |
|
parser.add_argument("--npz", required=False, default=False, action="store_true", help="if set then nnUNet will " |
|
"export npz files of " |
|
"predicted segmentations " |
|
"in the vlaidation as well. " |
|
"This is needed to run the " |
|
"ensembling step so unless " |
|
"you are developing nnUNet " |
|
"you should enable this") |
|
parser.add_argument("--valbest", required=False, default=False, action="store_true", help="") |
|
parser.add_argument("--find_lr", required=False, default=False, action="store_true", help="") |
|
parser.add_argument("--fp32", required=False, default=False, action="store_true", |
|
help="disable mixed precision training and run old school fp32") |
|
parser.add_argument("--val_folder", required=False, default="validation_raw", |
|
help="name of the validation folder. No need to use this for most people") |
|
parser.add_argument("--disable_saving", required=False, action='store_true', |
|
help="If set nnU-Net will not save any parameter files. Useful for development when you are " |
|
"only interested in the results and want to save some disk space") |
|
parser.add_argument("--disable_postprocessing_on_folds", required=False, action='store_true', |
|
help="Running postprocessing on each fold only makes sense when developing with nnU-Net and " |
|
"closely observing the model performance on specific configurations. You do not need it " |
|
"when applying nnU-Net because the postprocessing for this will be determined only once " |
|
"all five folds have been trained and nnUNet_find_best_configuration is called. Usually " |
|
"running postprocessing on each fold is computationally cheap, but some users have " |
|
"reported issues with very large images. If your images are large (>600x600x600 voxels) " |
|
"you should consider setting this flag.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
parser.add_argument('-pretrained_weights', type=str, required=False, default=None, |
|
help='path to nnU-Net checkpoint file to be used as pretrained model (use .model ' |
|
'file, for example model_final_checkpoint.model). Will only be used when actually training. ' |
|
'Optional. Beta. Use with caution.') |
|
|
|
args = parser.parse_args() |
|
|
|
task = args.task |
|
fold = args.fold |
|
network = args.network |
|
network_trainer = args.network_trainer |
|
validation_only = args.validation_only |
|
plans_identifier = args.p |
|
disable_postprocessing_on_folds = args.disable_postprocessing_on_folds |
|
|
|
use_compressed_data = args.use_compressed_data |
|
decompress_data = not use_compressed_data |
|
|
|
deterministic = args.deterministic |
|
valbest = args.valbest |
|
find_lr = args.find_lr |
|
num_gpus = args.gpus |
|
fp32 = args.fp32 |
|
val_folder = args.val_folder |
|
|
|
|
|
|
|
|
|
if not task.startswith("Task"): |
|
task_id = int(task) |
|
task = convert_id_to_task_name(task_id) |
|
|
|
if fold == 'all': |
|
pass |
|
else: |
|
fold = int(fold) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
plans_file, output_folder_name, dataset_directory, batch_dice, stage, \ |
|
trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier) |
|
|
|
if trainer_class is None: |
|
raise RuntimeError("Could not find trainer class") |
|
|
|
if network == "3d_cascade_fullres": |
|
assert issubclass(trainer_class, nnUNetTrainerCascadeFullRes), "If running 3d_cascade_fullres then your " \ |
|
"trainer class must be derived from " \ |
|
"nnUNetTrainerCascadeFullRes" |
|
else: |
|
assert issubclass(trainer_class, nnUNetTrainer), "network_trainer was found but is not derived from " \ |
|
"nnUNetTrainer" |
|
|
|
trainer = trainer_class(plans_file, fold, output_folder=output_folder_name, |
|
dataset_directory=dataset_directory, batch_dice=batch_dice, stage=stage, |
|
unpack_data=decompress_data, deterministic=deterministic, |
|
distribute_batch_size=args.dbs, num_gpus=num_gpus, fp16=not fp32) |
|
|
|
if args.disable_saving: |
|
trainer.save_latest_only = False |
|
trainer.save_intermediate_checkpoints = False |
|
trainer.save_best_checkpoint = False |
|
trainer.save_final_checkpoint = False |
|
|
|
trainer.initialize(not validation_only) |
|
|
|
if find_lr: |
|
trainer.find_lr() |
|
else: |
|
if not validation_only: |
|
if args.continue_training: |
|
|
|
trainer.load_latest_checkpoint() |
|
elif (not args.continue_training) and (args.pretrained_weights is not None): |
|
|
|
load_pretrained_weights(trainer.network, args.pretrained_weights) |
|
else: |
|
|
|
pass |
|
trainer.run_training() |
|
else: |
|
if valbest: |
|
trainer.load_best_checkpoint(train=False) |
|
else: |
|
trainer.load_final_checkpoint(train=False) |
|
|
|
trainer.network.eval() |
|
|
|
|
|
trainer.validate(save_softmax=args.npz, validation_folder_name=val_folder, |
|
run_postprocessing_on_folds=not disable_postprocessing_on_folds) |
|
|
|
if network == '3d_lowres': |
|
print("predicting segmentations for the next stage of the cascade") |
|
predict_next_stage(trainer, join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1)) |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|