"""
Multi-task sliding window inference utilities.

Created on Fri Apr 1 19:18:58 2022

@author: jma
"""
|
|
|
from typing import Any, Callable, List, Sequence, Tuple, Union

import torch
import torch.nn.functional as F

from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option

__all__ = ["multi_task_sliding_window_inference"]
|
|
def multi_task_sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Tuple[torch.Tensor, torch.Tensor]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = "constant",
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = "constant",
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sliding window inference on `inputs` with `predictor`.

    When `roi_size` is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D]).
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            If the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of the img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor `patch_data` in shape NCHW[D], `predictor(patch_data)`
            should return two predictions (here, segmentation logits and a distance map) that share
            the batch size N (`sw_batch_size`); each prediction is resized to the window size HW[D]
            and may have its own number of output channels.
        overlap: amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for ``"constant"`` padding mode. Default: 0.
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Returns:
        A tuple of two stitched tensors: the softmax segmentation output and the sigmoid distance output,
        both cropped to the original input spatial size.

    Note:
        - input must be channel-first and have a batch dim; the window predictions are resized with
          bilinear interpolation, so the current implementation expects 2D (NCHW) inputs.
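    Example (illustrative only; ``net`` stands for any two-head model that returns
    segmentation logits and a distance map for each window)::

        image = torch.randn(1, 3, 300, 300)  # NCHW input
        seg, dist = multi_task_sliding_window_inference(
            inputs=image,
            roi_size=(256, 256),
            sw_batch_size=4,
            predictor=net,
            overlap=0.25,
            mode="gaussian",
        )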
|
""" |
|
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    image_size_ = list(inputs.shape[2:])
    batch_size = inputs.shape[0]

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)

    # in case the image size is smaller than the roi size, pad symmetrically
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval)
|
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

    # store all slices in a list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # importance map used to blend overlapping windows (respects the documented `mode` argument)
    importance_map = compute_importance_map(
        get_valid_patch_size(image_size, roi_size), mode=look_up_option(mode, BlendMode), sigma_scale=sigma_scale, device=device
    )
|
    # lazily allocated output buffers; the real shapes depend on the predictor's channel counts
    output_image = torch.tensor(0.0, device=device)
    output_dist = torch.tensor(0.0, device=device)
    count_map = torch.tensor(0.0, device=device)
    count_dist_map = torch.tensor(0.0, device=device)
    _initialized = False

    for slice_g in range(0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        unravel_slice = [
            [slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device)
        seg_logit, seg_dist = predictor(window_data, *args, **kwargs)
        # resize both heads back to the window size, map to probabilities, and move to the output device
        seg_logit = F.interpolate(seg_logit, size=roi_size, mode="bilinear", align_corners=False)
        seg_logit = torch.softmax(seg_logit, dim=1).to(device)
        seg_dist = F.interpolate(seg_dist, size=roi_size, mode="bilinear", align_corners=False)
        seg_dist = torch.sigmoid(seg_dist).to(device)

        if not _initialized:  # allocate the stitched outputs once the channel counts are known
            output_classes = seg_logit.shape[1]
            dist_classes = seg_dist.shape[1]
            output_shape = [batch_size, output_classes] + list(image_size)
            output_dist_shape = [batch_size, dist_classes] + list(image_size)
            output_image = torch.zeros(output_shape, dtype=torch.float32, device=device)
            output_dist = torch.zeros(output_dist_shape, dtype=torch.float32, device=device)
            count_map = torch.zeros(output_shape, dtype=torch.float32, device=device)
            count_dist_map = torch.zeros(output_dist_shape, dtype=torch.float32, device=device)
            _initialized = True

        # accumulate the importance-weighted window predictions into the stitched outputs
        for idx, original_idx in zip(slice_range, unravel_slice):
            output_image[original_idx] += importance_map * seg_logit[idx - slice_g]
            output_dist[original_idx] += importance_map * seg_dist[idx - slice_g]
            count_map[original_idx] += importance_map
            count_dist_map[original_idx] += importance_map
|
    # account for any overlapping sections
    output_image = output_image / count_map
    output_dist = output_dist / count_dist_map

    # remove padding and crop back to the original input spatial size
    final_slicing: List[slice] = []
    for sp in range(num_spatial_dims):
        slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
        final_slicing.insert(0, slice_dim)
    while len(final_slicing) < len(output_image.shape):
        final_slicing.insert(0, slice(None))
    return output_image[tuple(final_slicing)], output_dist[tuple(final_slicing)]
|
|
|
|
|
|
|
def _get_scan_interval(
    image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
) -> Tuple[int, ...]:
    """
    Compute scan interval according to the image size, roi size and overlap.
    Scan interval will be `int((1 - overlap) * roi_size)`; if the interval is 0,
    use 1 instead to make sure the sliding window works.
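    For example, with ``roi_size=(256, 256)`` and ``overlap=0.25`` the interval is
    ``int(0.75 * 256) = 192`` along each dimension, except for dimensions where the roi size
    equals the image size, in which case the interval is the full roi size.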
|
""" |
|
    if len(image_size) != num_spatial_dims:
        raise ValueError("image coord different from spatial dims.")
    if len(roi_size) != num_spatial_dims:
        raise ValueError("roi coord different from spatial dims.")

    scan_interval = []
    for i in range(num_spatial_dims):
        if roi_size[i] == image_size[i]:
            scan_interval.append(int(roi_size[i]))
        else:
            interval = int(roi_size[i] * (1 - overlap))
            scan_interval.append(interval if interval > 0 else 1)
    return tuple(scan_interval)
|
|
|
|
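if __name__ == "__main__":
    # Minimal smoke test (illustrative only): ``_dummy_predictor`` is a hypothetical stand-in for a
    # real two-head network; it returns random 3-channel "logits" and a random 1-channel "distance" map.
    def _dummy_predictor(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return (
            torch.randn(x.shape[0], 3, *x.shape[2:]),
            torch.randn(x.shape[0], 1, *x.shape[2:]),
        )

    image = torch.randn(1, 3, 300, 300)  # a single 2D image in NCHW layout
    seg, dist = multi_task_sliding_window_inference(
        image, roi_size=(256, 256), sw_batch_size=4, predictor=_dummy_predictor, overlap=0.25, mode="gaussian"
    )
    print(seg.shape, dist.shape)  # expected: torch.Size([1, 3, 300, 300]) torch.Size([1, 1, 300, 300])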