|
import torch
|
|
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
|
|
|
|
def predictNNUNet(model_dir, input_dir, output_dir, folds, *, device=None,
                  use_mirroring=False, checkpoint_name='checkpoint_final.pth'):
    """Run nnU-Net v2 inference on the images in ``input_dir``.

    Builds an ``nnUNetPredictor`` and loads the trained model from
    ``model_dir``. It then predicts every case in ``input_dir`` and writes the
    segmentations to ``output_dir``. Existing outputs are overwritten, and
    probabilities are not saved.

    Args:
        model_dir: Trained-model folder passed to
            ``initialize_from_trained_model_folder``.
        input_dir: Input folder handed to ``predict_from_files``.
        output_dir: Output folder handed to ``predict_from_files``.
        folds: Fold(s) to use. A single fold (e.g. ``0`` or ``'all'``) is
            wrapped into a 1-tuple. A list or tuple is passed through as-is.
            ``None`` is forwarded unchanged; presumably nnU-Net then
            auto-detects the folds (confirm for the installed version).
        device: ``torch.device`` given to the predictor. Defaults to
            ``torch.device('cpu', 0)``, the previously hard-coded value.
        use_mirroring: Forwarded to ``nnUNetPredictor(use_mirroring=...)``.
            Defaults to ``False``, the previously hard-coded value.
        checkpoint_name: Checkpoint file name to load. Defaults to
            ``'checkpoint_final.pth'``, the previously hard-coded value.
    """
    if device is None:
        device = torch.device('cpu', 0)

    # nnU-Net iterates over ``use_folds``. A bare int would raise TypeError,
    # and the string 'all' would be split into single characters, so wrap
    # scalar folds into a 1-tuple.
    if isinstance(folds, (int, str)):
        folds = (folds,)

    predictor = nnUNetPredictor(
        tile_step_size=0.9,
        use_gaussian=True,
        use_mirroring=use_mirroring,
        device=device,
        verbose=True,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )

    predictor.initialize_from_trained_model_folder(
        model_dir,
        use_folds=folds,
        checkpoint_name=checkpoint_name,
    )

    print("input_dir", input_dir)

    predictor.predict_from_files(
        input_dir,
        output_dir,
        save_probabilities=False,
        overwrite=True,
        num_processes_preprocessing=2,
        num_processes_segmentation_export=2,
        folder_with_segs_from_prev_stage=None,
        num_parts=1,
        part_id=0,
    )