# Synthrad2025 / nnunetv2 / inference / export_prediction.py
# FelixzeroSun — "update" (3de7bcb)
import os
from copy import deepcopy
from typing import Union, List
import numpy as np
import torch
from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice
from batchgenerators.utilities.file_and_folder_operations import load_json, isfile, save_pickle
from nnunetv2.configuration import default_num_processes
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
import SimpleITK as sitk
def convert_predicted_logits_to_segmentation_with_correct_shape(predicted_logits: Union[torch.Tensor, np.ndarray],
                                                                plans_manager: PlansManager,
                                                                configuration_manager: ConfigurationManager,
                                                                label_manager: LabelManager,
                                                                properties_dict: dict,
                                                                return_probabilities: bool = False,
                                                                num_threads_torch: int = default_num_processes):
    """Resample network logits back to the original image geometry and derive a segmentation.

    Steps:
      1. resample the logits from the configuration spacing to ``properties_dict['spacing']``
         (target shape ``shape_after_cropping_and_before_resampling``);
      2. apply the label manager's inference nonlinearity and convert to a segmentation;
      3. paste the segmentation into a zero array of ``shape_before_cropping`` at
         ``bbox_used_for_cropping`` (reverting the crop);
      4. undo the axis transposition with ``plans_manager.transpose_backward``.

    Args:
        predicted_logits: network output with a leading channel axis.
        plans_manager: provides ``transpose_backward``.
        configuration_manager: provides ``spacing`` and ``resampling_fn_probabilities``.
        label_manager: converts logits -> probabilities -> segmentation.
        properties_dict: per-case geometry recorded during preprocessing.
        return_probabilities: also return the un-cropped, back-transposed probabilities.
        num_threads_torch: torch intra-op thread count used while this function runs. The
            previous value is restored on exit, including when an exception is raised.

    Returns:
        The segmentation array, or ``(segmentation, probabilities)`` when
        ``return_probabilities`` is True.
    """
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)
    try:
        # A configuration with fewer spatial dims than the data (e.g. 2D config on 3D data)
        # borrows the case's own spacing for the leading axis.
        current_spacing = configuration_manager.spacing if \
            len(configuration_manager.spacing) == \
            len(properties_dict['shape_after_cropping_and_before_resampling']) else \
            [properties_dict['spacing'][0], *configuration_manager.spacing]
        predicted_logits = configuration_manager.resampling_fn_probabilities(
            predicted_logits,
            properties_dict['shape_after_cropping_and_before_resampling'],
            current_spacing,
            properties_dict['spacing'])
        # return value of resampling_fn_probabilities can be ndarray or Tensor but that does not matter because
        # apply_inference_nonlin will convert to torch
        predicted_probabilities = label_manager.apply_inference_nonlin(predicted_logits)
        del predicted_logits
        segmentation = label_manager.convert_probabilities_to_segmentation(predicted_probabilities)

        # segmentation may be torch.Tensor but we continue with numpy
        if isinstance(segmentation, torch.Tensor):
            segmentation = segmentation.cpu().numpy()

        # put segmentation in bbox (revert cropping); uint16 only when uint8 cannot hold all labels
        segmentation_reverted_cropping = np.zeros(
            properties_dict['shape_before_cropping'],
            dtype=np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16)
        slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
        segmentation_reverted_cropping[slicer] = segmentation
        del segmentation

        # revert transpose
        segmentation_reverted_cropping = segmentation_reverted_cropping.transpose(plans_manager.transpose_backward)

        if not return_probabilities:
            return segmentation_reverted_cropping

        # revert cropping
        predicted_probabilities = label_manager.revert_cropping_on_probabilities(
            predicted_probabilities,
            properties_dict['bbox_used_for_cropping'],
            properties_dict['shape_before_cropping'])
        predicted_probabilities = predicted_probabilities.cpu().numpy()
        # revert transpose; axis 0 is the channel axis and stays in place
        predicted_probabilities = predicted_probabilities.transpose(
            [0] + [i + 1 for i in plans_manager.transpose_backward])
        return segmentation_reverted_cropping, predicted_probabilities
    finally:
        # always restore the caller's thread setting, even if something above raised
        torch.set_num_threads(old_threads)
import torch
import numpy as np
def convert_predicted_image_to_original_shape(predicted_image: Union[torch.Tensor, np.ndarray],  # arthur
                                              plans_manager: PlansManager,
                                              configuration_manager: ConfigurationManager,
                                              properties_dict: dict,
                                              num_threads_torch: int = default_num_processes,
                                              *,
                                              channel: int = 0,
                                              verbose: bool = False):
    """Resample a predicted (continuous) image back to the original image geometry.

    Steps:
      1. resample with ``configuration_manager.resampling_fn_data`` from the configuration
         spacing to ``properties_dict['spacing']`` (target shape
         ``shape_after_cropping_and_before_resampling``);
      2. keep a single channel (``channel``, index into the leading axis);
      3. paste it into a zero array of ``shape_before_cropping`` at
         ``bbox_used_for_cropping`` (reverting the crop; voxels outside the bbox are 0);
      4. undo the axis transposition with ``plans_manager.transpose_backward``.

    Args:
        predicted_image: network output; the leading axis is indexed as the channel axis.
        plans_manager: provides ``transpose_backward``.
        configuration_manager: provides ``spacing`` and ``resampling_fn_data``.
        properties_dict: per-case geometry recorded during preprocessing.
        num_threads_torch: torch intra-op thread count used while this function runs. The
            previous value is restored on exit, including when an exception is raised.
        channel: which channel of the resampled output to export (previously hard-coded 0).
        verbose: print shape / min / max / dtype of the resampled channel (previously always
            printed; the full-volume min/max is not free on large volumes).

    Returns:
        numpy array of the selected channel in the original (pre-crop, pre-transpose) geometry,
        with the dtype produced by the resampling function.
    """
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)
    try:
        shape_before_resampling = properties_dict['shape_after_cropping_and_before_resampling']
        # A configuration with fewer spatial dims than the data (e.g. 2D config on 3D data)
        # borrows the case's own spacing for the leading axis.
        current_spacing = configuration_manager.spacing if \
            len(configuration_manager.spacing) == len(shape_before_resampling) else \
            [properties_dict['spacing'][0], *configuration_manager.spacing]
        predicted_resampled = configuration_manager.resampling_fn_data(predicted_image,
                                                                       shape_before_resampling,
                                                                       current_spacing,
                                                                       properties_dict['spacing'])
        # continue in numpy regardless of what the resampling function returned
        if isinstance(predicted_resampled, torch.Tensor):
            predicted_resampled = predicted_resampled.cpu().numpy()
        predicted_resampled = predicted_resampled[channel]  # arthur: single output channel

        if verbose:
            print(predicted_resampled.shape, np.min(predicted_resampled), np.max(predicted_resampled),
                  predicted_resampled.dtype)

        # Put the image back into its original bounding box (revert cropping)
        original_shape_image = np.zeros(properties_dict['shape_before_cropping'],
                                        dtype=predicted_resampled.dtype)
        slicer = bounding_box_to_slice(properties_dict['bbox_used_for_cropping'])
        original_shape_image[slicer] = predicted_resampled
        del predicted_resampled

        # Revert transpose
        return original_shape_image.transpose(plans_manager.transpose_backward)
    finally:
        # always restore the caller's thread setting, even if something above raised
        torch.set_num_threads(old_threads)
def export_prediction_from_logits(predicted_array_or_file: Union[np.ndarray, torch.Tensor], properties_dict: dict,
                                  configuration_manager: ConfigurationManager,
                                  plans_manager: PlansManager,
                                  dataset_json_dict_or_file: Union[dict, str], output_file_truncated: str,
                                  save_probabilities: bool = False):
    """Convert a network prediction to the original geometry and write it to disk.

    In this fork the prediction is treated as a continuous image, not as segmentation
    logits. It is processed with ``convert_predicted_image_to_original_shape``, which
    returns a single array and no class probabilities.

    Args:
        predicted_array_or_file: raw network output for one case.
        properties_dict: per-case geometry recorded during preprocessing.
        configuration_manager: passed through to the shape conversion.
        plans_manager: provides the shape conversion info and ``image_reader_writer_class``.
        dataset_json_dict_or_file: dataset.json as a dict or a path to it; its
            ``file_ending`` is appended to ``output_file_truncated``.
        output_file_truncated: output path without extension.
        save_probabilities: additionally write ``<output>.npz`` and ``<output>.pkl``
            (``properties_dict``). No probabilities exist for an image prediction, so the
            ``.npz`` holds the reconstructed image under the key ``'probabilities'``.

    Returns:
        ``(image_in_original_geometry, properties_dict)``.
    """
    if isinstance(dataset_json_dict_or_file, str):
        dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

    # arthur: replaces convert_predicted_logits_to_segmentation_with_correct_shape (segmentation path)
    image_final = convert_predicted_image_to_original_shape(
        predicted_array_or_file, plans_manager, configuration_manager, properties_dict
    )
    del predicted_array_or_file

    if save_probabilities:
        # The image conversion returns ONE array, not a (segmentation, probabilities) pair;
        # unpacking it as a pair would either raise or split the volume along axis 0.
        # Store the reconstructed image under the key the segmentation export used.
        np.savez_compressed(output_file_truncated + '.npz', probabilities=image_final)
        save_pickle(properties_dict, output_file_truncated + '.pkl')

    rw = plans_manager.image_reader_writer_class()
    # NOTE(review): write_seg is the reader/writer's *segmentation* writer. Confirm that the
    # configured reader/writer does not cast this continuous image to an integer label dtype.
    rw.write_seg(image_final, output_file_truncated + dataset_json_dict_or_file['file_ending'],
                 properties_dict)
    return image_final, properties_dict
def resample_and_save(predicted: Union[torch.Tensor, np.ndarray], target_shape: List[int], output_file: str,
                      plans_manager: PlansManager, configuration_manager: ConfigurationManager, properties_dict: dict,
                      dataset_json_dict_or_file: Union[dict, str], num_threads_torch: int = default_num_processes) \
        -> None:
    """Resample logits to ``target_shape``, convert them to a segmentation and save it as ``.npz``.

    The segmentation is stored under the key ``'seg'`` in ``output_file``. It is stored as
    uint8, or as uint16 when the label manager has 255 or more foreground labels.

    Args:
        predicted: logits for one case.
        target_shape: spatial shape to resample to.
        output_file: destination passed to ``np.savez_compressed``.
        plans_manager: provides the label manager for the dataset.
        configuration_manager: provides ``spacing`` and ``resampling_fn_probabilities``.
        properties_dict: per-case geometry recorded during preprocessing.
        dataset_json_dict_or_file: dataset.json as a dict or a path to it.
        num_threads_torch: torch intra-op thread count used while this function runs. The
            previous value is restored on exit, including when an exception is raised.
    """
    old_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads_torch)
    try:
        if isinstance(dataset_json_dict_or_file, str):
            dataset_json_dict_or_file = load_json(dataset_json_dict_or_file)

        # A configuration with fewer spatial dims than the data (e.g. 2D config on 3D data)
        # borrows the case's own spacing for the leading axis.
        spacing = configuration_manager.spacing if \
            len(configuration_manager.spacing) == len(properties_dict['shape_after_cropping_and_before_resampling']) \
            else [properties_dict['spacing'][0], *configuration_manager.spacing]
        # NOTE(review): source and target spacing are identical (as in the original code),
        # so the resampling is presumably driven by target_shape. Confirm against
        # resampling_fn_probabilities.
        resampled = configuration_manager.resampling_fn_probabilities(predicted, target_shape, spacing, spacing)

        # create segmentation (argmax, regions, etc)
        label_manager = plans_manager.get_label_manager(dataset_json_dict_or_file)
        segmentation = label_manager.convert_logits_to_segmentation(resampled)
        del resampled
        # segmentation may be torch.Tensor but we continue with numpy
        if isinstance(segmentation, torch.Tensor):
            segmentation = segmentation.cpu().numpy()

        # same dtype rule as the full-resolution export: avoid uint8 wrap-around for many labels
        seg_dtype = np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16
        np.savez_compressed(output_file, seg=segmentation.astype(seg_dtype))
    finally:
        # always restore the caller's thread setting, even if something above raised
        torch.set_num_threads(old_threads)