# SwinUNETR_body_segmentation / postprocessing.py
import os
from pathlib import Path
from pydoc import locate
from typing import Dict, Optional, Type, Union

import nibabel as nib
import numpy as np
import torch

from monai.config import DtypeLike, KeysCollection, PathLike
from monai.data import image_writer
from monai.data.folder_layout import FolderLayout
from monai.data.meta_tensor import MetaTensor
from monai.transforms import MapTransform, Transform
from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep, optional_import
from monai.utils.enums import PostFix

DEFAULT_POST_FIX = PostFix.meta()
def set_header_info(nii_file, voxelsize, image_position_patient, contours_exist = None):
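    """Write voxel size, image-corner position and optional per-structure flags into a NIfTI header.

    The voxel size is stored both in ``pixdim`` and on the diagonal of the affine, the
    patient-space image corner goes into the affine translation column, and, when given,
    ``contours_exist`` is attached as a NIfTI header extension (code 0) so downstream code
    can tell which structures are actually present in the label map.
    """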
nii_file.header['pixdim'][1] = voxelsize[0]
nii_file.header['pixdim'][2] = voxelsize[1]
nii_file.header['pixdim'][3] = voxelsize[2]
#affine - voxelsize
nii_file.affine[0][0] = voxelsize[0]
nii_file.affine[1][1] = voxelsize[1]
nii_file.affine[2][2] = voxelsize[2]
#affine - imagecorner
nii_file.affine[0][3] = image_position_patient[0]
nii_file.affine[1][3] = image_position_patient[1]
nii_file.affine[2][3] = image_position_patient[2]
if contours_exist:
nii_file.header.extensions.append(nib.nifti1.Nifti1Extension(0, bytearray(contours_exist)))
return nii_file
def add_contours_exist(preddir, refCT):
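    """Collapse the saved one-hot ``RTStruct.nii.gz`` into a single bit-encoded label map.

    Channel 0 (background) is skipped; every remaining structure ``i`` that has at least one
    non-zero voxel is flagged with 1 in ``contours_exist`` and contributes ``2**i`` to the voxel
    value, so overlapping structures can share a voxel. Voxel size and image-corner position are
    taken from the reference CT object ``refCT`` and the result overwrites ``RTStruct.nii.gz``
    in ``preddir``.
    """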
img = nib.load(os.path.join(preddir, 'RTStruct.nii.gz'))
data = img.get_fdata().astype(int)
contours_exist = []
data_one_hot = np.zeros(data.shape[:3])
# We remove the first channel as it is the background
for i in range(data.shape[-1]-1):
if np.count_nonzero(data[:,:,:,i+1])>0:
contours_exist.append(1)
data_one_hot+=np.where(data[:,:,:,i+1]==1,2**i,0)
else:
contours_exist.append(0)
data_one_hot_nii = nib.Nifti1Image(data_one_hot, affine=np.eye(4))
data_one_hot_nii = set_header_info(data_one_hot_nii, voxelsize=np.array(refCT.PixelSpacing), image_position_patient=refCT.ImagePositionPatient, contours_exist=contours_exist)
nib.save(data_one_hot_nii,os.path.join(preddir, 'RTStruct.nii.gz'))
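

# A minimal decoding sketch (an illustration only; it is not called anywhere in this file)
# showing how the encoding above could be undone: bit ``i`` of a voxel marks structure
# ``i + 1`` of the one-hot prediction, and the per-structure flags live in the header
# extension written by ``set_header_info``. Note that NIfTI extensions are padded to 16-byte
# boundaries, so the flag list read back from disk may carry trailing zero bytes.
def decode_contours_exist(rtstruct_path):
    img = nib.load(rtstruct_path)
    encoded = img.get_fdata().astype(np.int64)
    # raw 0/1 flags appended as a header extension (possibly zero-padded, see note above)
    flags = [int(b) for b in bytes(img.header.extensions[0].get_content())]
    # one binary mask per flagged structure
    masks = [((encoded >> i) & 1).astype(np.uint8) for i, present in enumerate(flags) if present]
    return flags, masks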
class SaveImaged(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.SaveImage`.
Note:
Image should be channel-first shape: [C,H,W,[D]].
If the data is a patch of big image, will append the patch index to filename.
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
meta_keys: explicitly indicate the key of the corresponding metadata dictionary.
For example, for data with key `image`, the metadata by default is in `image_meta_dict`.
            The metadata is a dictionary that contains values such as filename, original_shape.
            This argument can be a sequence of strings, mapping to the `keys`.
If `None`, will try to construct meta_keys by `key_{meta_key_postfix}`.
meta_key_postfix: if `meta_keys` is `None`, use `key_{meta_key_postfix}` to retrieve the metadict.
output_dir: output image directory.
output_postfix: a string appended to all output file names, default to `trans`.
output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`.
output_dtype: data type for saving data. Defaults to ``np.float32``.
resample: whether to resample image (if needed) before saving the data array,
based on the `spatial_shape` (and `original_affine`) from metadata.
mode: This option is used when ``resample=True``. Defaults to ``"nearest"``.
Depending on the writers, the possible options are:
- {``"bilinear"``, ``"nearest"``, ``"bicubic"``}.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
- {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}.
See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate
padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``.
Possible options are {``"zeros"``, ``"border"``, ``"reflection"``}
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
[0, 255] (uint8) or [0, 65535] (uint16). Default is `None` (no scaling).
        dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision.
            if None, use the data type of input data. To be compatible with other modules,
            the output data type is always ``np.float32``.
allow_missing_keys: don't raise exception if key is missing.
squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel
has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and
            then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `False`,
image will always be saved as (H,W,D,C).
data_root_dir: if not empty, it specifies the beginning parts of the input file's
absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from
`data_root_dir` to preserve folder structure when saving in case there are files in different
folders with the same file names. For example, with the following inputs:
- input_file_name: `/foo/bar/test1/image.nii`
- output_postfix: `seg`
- output_ext: `.nii.gz`
- output_dir: `/output`
- data_root_dir: `/foo/bar`
The output will be: /output/test1/image/image_seg.nii.gz
        separate_folder: whether to save every file in a separate folder. For example: for the input filename
            `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as:
            `output/image/image_seg.nii`; if `False`, it will be saved as `output/image_seg.nii`. Defaults to `True`.
        print_log: whether to print logs when saving. Defaults to `True`.
output_format: an optional string to specify the output image writer.
see also: `monai.data.image_writer.SUPPORTED_WRITERS`.
writer: a customised `monai.data.ImageWriter` subclass to save data arrays.
if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`.
if it's a string, it's treated as a class name or dotted path;
the supported built-in writer classes are ``"NibabelWriter"``, ``"ITKWriter"``, ``"PILWriter"``.
"""
def __init__(
self,
keys: KeysCollection,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = DEFAULT_POST_FIX,
output_dir: Union[Path, str] = "./",
output_postfix: str = "trans",
output_ext: str = ".nii.gz",
resample: bool = True,
mode: str = "nearest",
padding_mode: str = GridSamplePadMode.BORDER,
scale: Optional[int] = None,
dtype: DtypeLike = np.float64,
output_dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
squeeze_end_dims: bool = True,
data_root_dir: str = "",
separate_folder: bool = True,
print_log: bool = True,
output_format: str = "",
writer: Union[Type[image_writer.ImageWriter], str, None] = None,
) -> None:
super().__init__(keys, allow_missing_keys)
self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
self.saver = SaveImage(
output_dir=output_dir,
output_postfix=output_postfix,
output_ext=output_ext,
resample=resample,
mode=mode,
padding_mode=padding_mode,
scale=scale,
dtype=dtype,
output_dtype=output_dtype,
squeeze_end_dims=squeeze_end_dims,
data_root_dir=data_root_dir,
separate_folder=separate_folder,
print_log=print_log,
output_format=output_format,
writer=writer,
)
def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):
self.saver.set_options(init_kwargs, data_kwargs, meta_kwargs, write_kwargs)
def __call__(self, data):
d = dict(data)
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
if meta_key is None and meta_key_postfix is not None:
meta_key = f"{key}_{meta_key_postfix}"
meta_data = d.get(meta_key) if meta_key is not None else None
self.saver(img=d[key], meta_data=meta_data)
return d
class SaveImage(Transform):
"""
Save the image (in the form of torch tensor or numpy ndarray) and metadata dictionary into files.
    The name of the saved file will be `{input_image_name}_{output_postfix}{output_ext}`,
    where the `input_image_name` is extracted from the provided metadata dictionary.
    If no metadata is provided, a running index starting from 0 will be used as the filename prefix.
Args:
output_dir: output image directory.
output_postfix: a string appended to all output file names, default to `trans`.
output_ext: output file extension name.
output_dtype: data type for saving data. Defaults to ``np.float32``.
resample: whether to resample image (if needed) before saving the data array,
based on the `spatial_shape` (and `original_affine`) from metadata.
mode: This option is used when ``resample=True``. Defaults to ``"nearest"``.
Depending on the writers, the possible options are
- {``"bilinear"``, ``"nearest"``, ``"bicubic"``}.
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
- {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}.
See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate
padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``.
Possible options are {``"zeros"``, ``"border"``, ``"reflection"``}
See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample
scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling
[0, 255] (uint8) or [0, 65535] (uint16). Default is `None` (no scaling).
        dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision.
            if None, use the data type of input data. To be compatible with other modules,
            the output data type is always ``np.float32``.
squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel
has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and
            then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `False`,
image will always be saved as (H,W,D,C).
data_root_dir: if not empty, it specifies the beginning parts of the input file's
absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from
`data_root_dir` to preserve folder structure when saving in case there are files in different
folders with the same file names. For example, with the following inputs:
- input_file_name: `/foo/bar/test1/image.nii`
- output_postfix: `seg`
- output_ext: `.nii.gz`
- output_dir: `/output`
- data_root_dir: `/foo/bar`
The output will be: /output/test1/image/image_seg.nii.gz
        separate_folder: whether to save every file in a separate folder. For example: for the input filename
            `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as:
            `output/image/image_seg.nii`; if `False`, it will be saved as `output/image_seg.nii`. Defaults to `True`.
        print_log: whether to print logs when saving. Defaults to `True`.
output_format: an optional string of filename extension to specify the output image writer.
see also: `monai.data.image_writer.SUPPORTED_WRITERS`.
writer: a customised `monai.data.ImageWriter` subclass to save data arrays.
if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`.
if it's a string, it's treated as a class name or dotted path (such as ``"monai.data.ITKWriter"``);
the supported built-in writer classes are ``"NibabelWriter"``, ``"ITKWriter"``, ``"PILWriter"``.
channel_dim: the index of the channel dimension. Default to `0`.
`None` to indicate no channel dimension.
"""
def __init__(
self,
output_dir: PathLike = "./",
output_postfix: str = "trans",
output_ext: str = ".nii.gz",
output_dtype: DtypeLike = np.float32,
resample: bool = True,
mode: str = "nearest",
padding_mode: str = GridSamplePadMode.BORDER,
scale: Optional[int] = None,
dtype: DtypeLike = np.float64,
squeeze_end_dims: bool = True,
data_root_dir: PathLike = "",
separate_folder: bool = True,
print_log: bool = True,
output_format: str = "",
writer: Union[Type[image_writer.ImageWriter], str, None] = None,
channel_dim: Optional[int] = 0,
) -> None:
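        # FolderLayout assembles the final output path for each subject from output_dir, the
        # postfix, the extension and (optionally) a per-subject sub-folder; makedirs=True
        # creates any missing output directories on demand.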
self.folder_layout = FolderLayout(
output_dir=output_dir,
postfix=output_postfix,
extension=output_ext,
parent=separate_folder,
makedirs=True,
data_root_dir=data_root_dir,
)
self.output_ext = output_ext.lower() or output_format.lower()
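        # resolve the writer: a string is first looked up as a MONAI built-in writer class,
        # then as a dotted import path; if no writer is given, fall back to the writers
        # registered for the chosen output extension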
if isinstance(writer, str):
writer_, has_built_in = optional_import("monai.data", name=f"{writer}") # search built-in
if not has_built_in:
writer_ = locate(f"{writer}") # search dotted path
if writer_ is None:
raise ValueError(f"writer {writer} not found")
writer = writer_
self.writers = image_writer.resolve_writer(self.output_ext) if writer is None else (writer,)
self.writer_obj = None
_output_dtype = output_dtype
if self.output_ext == ".png" and _output_dtype not in (np.uint8, np.uint16):
_output_dtype = np.uint8
if self.output_ext == ".dcm" and _output_dtype not in (np.uint8, np.uint16):
_output_dtype = np.uint8
self.init_kwargs = {"output_dtype": _output_dtype, "scale": scale}
self.data_kwargs = {"squeeze_end_dims": squeeze_end_dims, "channel_dim": channel_dim}
self.meta_kwargs = {"resample": resample, "mode": mode, "padding_mode": padding_mode, "dtype": dtype}
self.write_kwargs = {"verbose": print_log}
self._data_index = 0
def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):
"""
Set the options for the underlying writer by updating the `self.*_kwargs` dictionaries.
The arguments correspond to the following usage:
- `writer = ImageWriter(**init_kwargs)`
- `writer.set_data_array(array, **data_kwargs)`
- `writer.set_metadata(meta_data, **meta_kwargs)`
- `writer.write(filename, **write_kwargs)`
"""
if init_kwargs is not None:
self.init_kwargs.update(init_kwargs)
if data_kwargs is not None:
self.data_kwargs.update(data_kwargs)
if meta_kwargs is not None:
self.meta_kwargs.update(meta_kwargs)
if write_kwargs is not None:
self.write_kwargs.update(write_kwargs)
def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None):
"""
Args:
img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`.
meta_data: key-value pairs of metadata corresponding to the data.
"""
meta_data = img.meta if isinstance(img, MetaTensor) else meta_data
subject = "RTStruct"#meta_data["patient_name"] if meta_data else str(self._data_index)
patch_index = None#meta_data.get(Key.PATCH_INDEX, None) if meta_data else None
filename = self.folder_layout.filename(subject=f"{subject}", idx=patch_index)
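        # if the array's rank already matches the metadata's spatial_shape, there is no channel
        # dimension, so tell the writer not to look for one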
if meta_data and len(ensure_tuple(meta_data.get("spatial_shape", ()))) == len(img.shape):
self.data_kwargs["channel_dim"] = None
err = []
for writer_cls in self.writers:
try:
writer_obj = writer_cls(**self.init_kwargs)
writer_obj.set_data_array(data_array=img, **self.data_kwargs)
writer_obj.set_metadata(meta_dict=meta_data, **self.meta_kwargs)
writer_obj.write(filename, **self.write_kwargs)
self.writer_obj = writer_obj
            except Exception as e:
                # keep the failure so it can be reported if every registered writer fails
                err.append(f"{writer_cls.__name__}: {e}")
                print('err', e)
else:
self._data_index += 1
return img
msg = "\n".join([f"{e}" for e in err])
raise RuntimeError(
f"{self.__class__.__name__} cannot find a suitable writer for {filename}.\n"
" Please install the writer libraries, see also the installation instructions:\n"
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
)
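

# A minimal usage sketch, kept under a __main__ guard so importing this module stays side-effect
# free. The key name "pred", the input path and the output directory are illustrative
# assumptions only; the repository's inference pipeline wires these transforms up itself.
if __name__ == "__main__":
    from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged

    save = SaveImaged(
        keys="pred",
        output_dir="./out",        # hypothetical output folder
        output_postfix="",         # no postfix, so the file name comes from the subject ("RTStruct")
        separate_folder=False,     # write directly into output_dir instead of a per-subject folder
        output_dtype=np.uint8,
    )
    # optional: tweak the underlying writer, e.g. skip resampling back to the original grid
    save.set_options(meta_kwargs={"resample": False})

    pipeline = Compose([LoadImaged(keys="pred"), EnsureChannelFirstd(keys="pred"), save])
    # pipeline({"pred": "path/to/prediction.nii.gz"})  # expected to write ./out/RTStruct.nii.gz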