import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor


def predictNNUNet(
    model_dir,
    input_dir,
    output_dir,
    folds,
    *,
    device=None,
    tile_step_size=0.9,
    checkpoint_name="checkpoint_final.pth",
    num_processes=2,
):
    """Run nnU-Net v2 inference on the images in ``input_dir``.

    Builds an ``nnUNetPredictor``, loads a trained model from ``model_dir``
    and writes segmentations to ``output_dir``. Existing outputs are
    overwritten and no probability maps are saved.

    Args:
        model_dir: nnU-Net trained-model output folder, passed to
            ``initialize_from_trained_model_folder``.
        input_dir: Folder with the raw input images to segment.
        output_dir: Folder that receives the predicted segmentations.
        folds: Folds to ensemble, forwarded as ``use_folds``. ``None`` lets
            nnU-Net auto-detect the available folds.
        device: Torch device (or anything ``torch.device`` accepts) used for
            inference. Defaults to ``torch.device('cpu', 0)``, which matches
            the original hard-coded behaviour.
        tile_step_size: Sliding-window step size handed to
            ``nnUNetPredictor``. The default of 0.9 is the value the original
            hard-coded.
        checkpoint_name: Checkpoint file loaded from each fold folder.
        num_processes: Worker count used for both preprocessing and
            segmentation export.
    """
    # Keep the original CPU default; otherwise normalise str/device input.
    device = torch.device("cpu", 0) if device is None else torch.device(device)

    predictor = nnUNetPredictor(
        tile_step_size=tile_step_size,
        use_gaussian=True,
        use_mirroring=False,  # same as the CLI's --disable_tta: no mirroring TTA
        device=device,
        verbose=True,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor.initialize_from_trained_model_folder(
        model_dir,
        use_folds=folds,  # None -> auto-detect folds
        checkpoint_name=checkpoint_name,
    )
    print("input_dir", input_dir)
    predictor.predict_from_files(
        input_dir,
        output_dir,
        save_probabilities=False,
        overwrite=True,
        num_processes_preprocessing=num_processes,
        num_processes_segmentation_export=num_processes,
        folder_with_segs_from_prev_stage=None,
        # Single-part run: this call processes all cases itself.
        num_parts=1,
        part_id=0,
    )