# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import colorsys  # needed by `random_colors` below
import random  # needed by `random_colors` below
import warnings
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from stardist_pkg.big import _grid_divisible, BlockND, OBJECT_KEYS  # , repaint_labels
from stardist_pkg.matching import relabel_sequential
from stardist_pkg import dist_to_coord, non_maximum_suppression, polygons_to_label
# from stardist_pkg import dist_to_coord, polygons_to_label
from stardist_pkg import star_dist, edt_prob

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.transforms import Resize
from monai.utils import (
    BlendMode,
    PytorchPadMode,
    convert_data_type,
    convert_to_dst_type,
    ensure_tuple,
    fall_back_tuple,
    look_up_option,
    optional_import,
)

import cv2
import matplotlib.pyplot as plt  # needed by `colorize` below
from scipy import ndimage
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import affine_transform, map_coordinates
from skimage import morphology as morph
from scipy.ndimage import filters, measurements
from scipy.ndimage.morphology import (
    binary_dilation,
    binary_fill_holes,
    distance_transform_cdt,
    distance_transform_edt,
)
from skimage.segmentation import watershed

tqdm, _ = optional_import("tqdm", name="tqdm")

__all__ = ["sliding_window_inference"]

####
def normalize(mask, dtype=np.uint8):
    return (255 * mask / np.amax(mask)).astype(dtype)
def fix_mirror_padding(ann):
    """Deal with duplicated instances due to mirroring in interpolation
    during shape augmentation (scale, rotation, etc.).
    """
    current_max_id = np.amax(ann)
    inst_list = list(np.unique(ann))
    if 0 in inst_list:
        inst_list.remove(0)  # 0 is background
    for inst_id in inst_list:
        inst_map = np.array(ann == inst_id, np.uint8)
        remapped_ids = measurements.label(inst_map)[0]
        remapped_ids[remapped_ids > 1] += current_max_id
        ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
        current_max_id = np.amax(ann)
    return ann
####
def get_bounding_box(img):
    """Get bounding box coordinate information."""
    rows = np.any(img, axis=1)
    cols = np.any(img, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # due to python indexing, need to add 1 to max
    # else accessing will be 1px in the box, not out
    rmax += 1
    cmax += 1
    return [rmin, rmax, cmin, cmax]
####
def cropping_center(x, crop_shape, batch=False):
    """Crop an input image at the centre.
    Args:
        x: input array
        crop_shape: dimensions of the cropped array
        batch: whether the first dimension of `x` is a batch dimension
    Returns:
        x: cropped array
    """
    orig_shape = x.shape
    if not batch:
        h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
        w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
        x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
    else:
        h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
        w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
        x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
    return x
def gen_instance_hv_map(ann, crop_shape):
    """Input annotation must be of original shape.
    The map is calculated only for instances within the crop portion
    but based on the original shape in the original image.
    Perform the following operation:
    Obtain the horizontal and vertical distance maps for each
    nuclear instance.
    """
    orig_ann = ann.copy()  # instance ID map
    fixed_ann = fix_mirror_padding(orig_ann)
    # re-cropping with fixed instance id map
    crop_ann = cropping_center(fixed_ann, crop_shape)
    # TODO: deal with 1 label warning
    crop_ann = morph.remove_small_objects(crop_ann, min_size=30)
    x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
    y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
    inst_list = list(np.unique(crop_ann))
    if 0 in inst_list:
        inst_list.remove(0)  # 0 is background
    for inst_id in inst_list:
        inst_map = np.array(fixed_ann == inst_id, np.uint8)
        inst_box = get_bounding_box(inst_map)  # rmin, rmax, cmin, cmax
        # expand the box by 2px
        # Because we first pad the ann at line 207, the bboxes
        # will remain valid after expansion
        inst_box[0] -= 2
        inst_box[2] -= 2
        inst_box[1] += 2
        inst_box[3] += 2
        # fix inst_box
        inst_box[0] = max(inst_box[0], 0)
        inst_box[2] = max(inst_box[2], 0)
        # inst_box[1] = min(inst_box[1], fixed_ann.shape[0])
        # inst_box[3] = min(inst_box[3], fixed_ann.shape[1])
        inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
        if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
            print(f'inst_map.shape < 2: {inst_map.shape}, {inst_box}, {get_bounding_box(np.array(fixed_ann == inst_id, np.uint8))}')
            continue
        # instance center of mass, rounded to the nearest pixel
        inst_com = list(measurements.center_of_mass(inst_map))
        if np.isnan(inst_com).any():
            print(inst_id, fixed_ann.shape, np.array(fixed_ann == inst_id, np.uint8).shape)
            print(get_bounding_box(np.array(fixed_ann == inst_id, np.uint8)))
            print(inst_map)
            print(inst_list)
            print(inst_box)
            print(np.count_nonzero(np.array(fixed_ann == inst_id, np.uint8)))
        inst_com[0] = int(inst_com[0] + 0.5)
        inst_com[1] = int(inst_com[1] + 0.5)
        inst_x_range = np.arange(1, inst_map.shape[1] + 1)
        inst_y_range = np.arange(1, inst_map.shape[0] + 1)
        # shifting center of pixels grid to instance center of mass
        inst_x_range -= inst_com[1]
        inst_y_range -= inst_com[0]
        inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)
        # remove coord outside of instance
        inst_x[inst_map == 0] = 0
        inst_y[inst_map == 0] = 0
        inst_x = inst_x.astype("float32")
        inst_y = inst_y.astype("float32")
        # normalize min into -1 scale
        if np.min(inst_x) < 0:
            inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
        if np.min(inst_y) < 0:
            inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
        # normalize max into +1 scale
        if np.max(inst_x) > 0:
            inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
        if np.max(inst_y) > 0:
            inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])
        ####
        x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
        x_map_box[inst_map > 0] = inst_x[inst_map > 0]
        y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
        y_map_box[inst_map > 0] = inst_y[inst_map > 0]
    hv_map = np.dstack([x_map, y_map])
    return hv_map
def remove_small_objects(pred, min_size=64, connectivity=1):
    """Remove connected components smaller than the specified size.
    This function is taken from skimage.morphology.remove_small_objects, but the warning
    is removed when a single label is provided.
    Args:
        pred: input labelled array
        min_size: minimum size of instance in output array
        connectivity: the connectivity defining the neighborhood of a pixel
    Returns:
        out: output array with instances under min_size removed
    """
    out = pred
    if min_size == 0:  # shortcut for efficiency
        return out
    if out.dtype == bool:
        selem = ndimage.generate_binary_structure(pred.ndim, connectivity)
        ccs = np.zeros_like(pred, dtype=np.int32)
        ndimage.label(pred, selem, output=ccs)
    else:
        ccs = out
    try:
        component_sizes = np.bincount(ccs.ravel())
    except ValueError:
        raise ValueError(
            "Negative value labels are not supported. Try "
            "relabeling the input with `scipy.ndimage.label` or "
            "`skimage.morphology.label`."
        )
    too_small = component_sizes < min_size
    too_small_mask = too_small[ccs]
    out[too_small_mask] = 0
    return out
####
def gen_targets(ann, crop_shape, **kwargs):
    """Generate the targets for the network."""
    hv_map = gen_instance_hv_map(ann, crop_shape)
    np_map = ann.copy()
    np_map[np_map > 0] = 1
    hv_map = cropping_center(hv_map, crop_shape)
    np_map = cropping_center(np_map, crop_shape)
    target_dict = {
        "hv_map": hv_map,
        "np_map": np_map,
    }
    return target_dict
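# Illustrative usage of `gen_targets` (a minimal sketch; the annotation array and the
# 256x256 crop size below are assumptions for the example, not values fixed by this module):
#     ann = np.zeros((540, 540), dtype=np.int32)          # instance ID map, 0 = background
#     ann[260:300, 260:310] = 1                            # one fake nucleus inside the crop
#     targets = gen_targets(ann, crop_shape=(256, 256))
#     hv, np_mask = targets["hv_map"], targets["np_map"]   # (256, 256, 2) float32 and (256, 256) binary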
def __proc_np_hv(pred, np_thres=0.5, ksize=21, overall_thres=0.4, obj_size_thres=10):
    """Process nuclei prediction with XY coordinate map.
    Args:
        pred: prediction output, assuming
            channel 0 contains the probability map of nuclei
            channel 1 contains the regressed X-map
            channel 2 contains the regressed Y-map
    """
    pred = np.array(pred, dtype=np.float32)
    blb_raw = pred[..., 0]
    h_dir_raw = pred[..., 1]
    v_dir_raw = pred[..., 2]
    # processing
    blb = np.array(blb_raw >= np_thres, dtype=np.int32)
    blb = measurements.label(blb)[0]
    blb = remove_small_objects(blb, min_size=10)
    blb[blb > 0] = 1  # background is 0 already
    h_dir = cv2.normalize(
        h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
    )
    v_dir = cv2.normalize(
        v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
    )
    sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
    sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)
    sobelh = 1 - (
        cv2.normalize(
            sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
        )
    )
    sobelv = 1 - (
        cv2.normalize(
            sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
        )
    )
    overall = np.maximum(sobelh, sobelv)
    overall = overall - (1 - blb)
    overall[overall < 0] = 0
    dist = (1.0 - overall) * blb
    ## nuclei values form mountains so inverse to get basins
    dist = -cv2.GaussianBlur(dist, (3, 3), 0)
    overall = np.array(overall >= overall_thres, dtype=np.int32)
    marker = blb - overall
    marker[marker < 0] = 0
    marker = binary_fill_holes(marker).astype("uint8")
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
    marker = measurements.label(marker)[0]
    marker = remove_small_objects(marker, min_size=obj_size_thres)
    proced_pred = watershed(dist, markers=marker, mask=blb)
    return proced_pred
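# Illustrative usage of `__proc_np_hv` (a minimal sketch; the HxWx3 layout below is an
# assumption matching the docstring, and prob_map/x_map/y_map are hypothetical HxW arrays):
#     pred = np.stack([prob_map, x_map, y_map], axis=-1)
#     inst_map = __proc_np_hv(pred, np_thres=0.5, ksize=21)   # int map, one label per separated nucleus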
####
def colorize(ch, vmin, vmax):
    """Clamp values outside the provided range to vmax and vmin."""
    cmap = plt.get_cmap("jet")
    ch = np.squeeze(ch.astype("float32"))
    vmin = vmin if vmin is not None else ch.min()
    vmax = vmax if vmax is not None else ch.max()
    ch[ch > vmax] = vmax  # clamp value
    ch[ch < vmin] = vmin
    ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
    # take RGB from RGBA heat map
    ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
    return ch_cmap
####
def random_colors(N, bright=True):
    """Generate random colors.
    To get visually distinct colors, generate them in HSV space then
    convert to RGB.
    """
    brightness = 1.0 if bright else 0.7
    hsv = [(i / N, 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    random.shuffle(colors)
    return colors
####
def visualize_instances_map(
    input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
):
    """Overlays segmentation results on the image as contours.
    Args:
        input_image: input image
        inst_map: instance mask with a unique value for every object
        type_map: type mask with a unique value for every class
        type_colour: a dict of {type: colour}, `type` is from 0-N
            and `colour` is a tuple of (R, G, B)
        line_thickness: line thickness of contours
    Returns:
        overlay: output image with segmentation overlay as contours
    """
    overlay = np.copy((input_image).astype(np.uint8))
    inst_list = list(np.unique(inst_map))  # get list of instances
    inst_list.remove(0)  # remove background
    inst_rng_colors = random_colors(len(inst_list))
    inst_rng_colors = np.array(inst_rng_colors) * 255
    inst_rng_colors = inst_rng_colors.astype(np.uint8)
    for inst_idx, inst_id in enumerate(inst_list):
        inst_map_mask = np.array(inst_map == inst_id, np.uint8)  # get single object
        y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
        y1 = y1 - 2 if y1 - 2 >= 0 else y1
        x1 = x1 - 2 if x1 - 2 >= 0 else x1
        x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
        y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
        inst_map_crop = inst_map_mask[y1:y2, x1:x2]
        contours_crop = cv2.findContours(
            inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        # only has 1 instance per map, no need to check #contour detected by opencv
        contours_crop = np.squeeze(
            contours_crop[0][0].astype("int32")
        )  # * opencv protocol format may break
        contours_crop += np.asarray([[x1, y1]])  # index correction
        if type_map is not None:
            type_map_crop = type_map[y1:y2, x1:x2]
            type_id = np.unique(type_map_crop).max()  # non-zero
            inst_colour = type_colour[type_id]
        else:
            inst_colour = (inst_rng_colors[inst_idx]).tolist()
        cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
    return overlay
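# Illustrative usage of `visualize_instances_map` (a minimal sketch; `image` and
# `inst_map` are hypothetical inputs of matching height and width):
#     overlay = visualize_instances_map(image, inst_map, line_thickness=2)
#     cv2.imwrite("overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))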
def sliding_window_inference_large(inputs, block_size, min_overlap, context, roi_size, sw_batch_size, predictor, device):
    """Run StarDist-style inference on a (potentially very large) HxWxC image.

    Images with either side shorter than 5000 px are processed in a single pass with
    `sliding_window_inference`; larger images are split into overlapping blocks
    (`stardist_pkg.big.BlockND`), each block is processed independently, and the
    per-block labels are stitched back into a single instance map.
    """
    h, w = inputs.shape[0], inputs.shape[1]
    if h < 5000 or w < 5000:
        test_tensor = torch.from_numpy(np.expand_dims(inputs, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
        output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
        prob = output_prob[0][0].cpu().numpy()
        dist = output_dist[0].cpu().numpy()
        dist = np.transpose(dist, (1, 2, 0))
        dist = np.maximum(1e-3, dist)
        if h * w < 1500 * 1500:
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.55, nms_thresh=0.4, cut=True)
        else:
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)
        labels_out = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
    else:
        n = inputs.ndim
        axes = 'YXC'
        grid = (1, 1, 1)
        if np.isscalar(block_size):
            block_size = n * [block_size]
        if np.isscalar(min_overlap):
            min_overlap = n * [min_overlap]
        if np.isscalar(context):
            context = n * [context]
        shape_out = (inputs.shape[0], inputs.shape[1])
        labels_out = np.zeros(shape_out, dtype=np.uint64)
        # print(inputs.dtype)
        block_size[2] = inputs.shape[2]
        min_overlap[2] = context[2] = 0
        block_size = tuple(_grid_divisible(g, v, name='block_size', verbose=False) for v, g, a in zip(block_size, grid, axes))
        min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v, g, a in zip(min_overlap, grid, axes))
        context = tuple(_grid_divisible(g, v, name='context', verbose=False) for v, g, a in zip(context, grid, axes))
        print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)
        blocks = BlockND.cover(inputs.shape, axes, block_size, min_overlap, context)
        label_offset = 1
        blocks = tqdm(blocks)
        for block in blocks:
            image = block.read(inputs, axes=axes)
            test_tensor = torch.from_numpy(np.expand_dims(image, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
            output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
            prob = output_prob[0][0].cpu().numpy()
            dist = output_dist[0].cpu().numpy()
            dist = np.transpose(dist, (1, 2, 0))
            dist = np.maximum(1e-3, dist)
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)
            coord = dist_to_coord(disti, points)
            polys = dict(coord=coord, points=points, prob=probi)
            labels = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
            labels = block.crop_context(labels, axes='YX')
            labels, polys = block.filter_objects(labels, polys, axes='YX')
            labels = relabel_sequential(labels, label_offset)[0]
            if labels_out is not None:
                block.write(labels_out, labels, axes='YX')
            # for k, v in polys.items():
            #     polys_all.setdefault(k, []).append(v)
            label_offset += len(polys['prob'])
            del labels
        # polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k, v in polys_all.items()}
    return labels_out
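# Illustrative usage of `sliding_window_inference_large` (a minimal sketch; `img` is a
# hypothetical HxWx3 float array and `model` a StarDist-style predictor returning (dist, prob)):
#     labels = sliding_window_inference_large(
#         img, block_size=4096, min_overlap=128, context=128,
#         roi_size=(512, 512), sw_batch_size=4, predictor=model, device="cuda")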
def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
""" | |
Sliding window inference on `inputs` with `predictor`. | |
The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors. | |
Each output in the tuple or dict value is allowed to have different resolutions with respect to the input. | |
e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes | |
could be ([128,64,256], [64,32,128]). | |
In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still | |
an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters | |
so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension). | |
When roi_size is larger than the inputs' spatial size, the input image are padded during inference. | |
To maintain the same spatial sizes, the output image will be cropped to the original input size. | |
Args: | |
inputs: input image to be processed (assuming NCHW[D]) | |
roi_size: the spatial window size for inferences. | |
When its components have None or non-positives, the corresponding inputs dimension will be used. | |
if the components of the `roi_size` are non-positive values, the transform will use the | |
corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted | |
to `(32, 64)` if the second spatial dimension size of img is `64`. | |
sw_batch_size: the batch size to run window slices. | |
predictor: given input tensor ``patch_data`` in shape NCHW[D], | |
The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary | |
with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D']; | |
where H'W'[D'] represents the output patch's spatial size, M is the number of output channels, | |
N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128), | |
the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)). | |
In this case, the parameter `overlap` and `roi_size` need to be carefully chosen | |
to ensure the scaled output ROI sizes are still integers. | |
If the `predictor`'s input and output spatial sizes are different, | |
we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension. | |
overlap: Amount of overlap between scans. | |
mode: {``"constant"``, ``"gaussian"``} | |
How to blend output of overlapping windows. Defaults to ``"constant"``. | |
- ``"constant``": gives equal weight to all predictions. | |
- ``"gaussian``": gives less weight to predictions on edges of windows. | |
sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``. | |
Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``. | |
When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding | |
spatial dimensions. | |
padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``} | |
Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"`` | |
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html | |
cval: fill value for 'constant' padding mode. Default: 0 | |
sw_device: device for the window data. | |
By default the device (and accordingly the memory) of the `inputs` is used. | |
Normally `sw_device` should be consistent with the device where `predictor` is defined. | |
device: device for the stitched output prediction. | |
By default the device (and accordingly the memory) of the `inputs` is used. If for example | |
set to device=torch.device('cpu') the gpu memory consumption is less and independent of the | |
`inputs` and `roi_size`. Output is on the `device`. | |
progress: whether to print a `tqdm` progress bar. | |
roi_weight_map: pre-computed (non-negative) weight map for each ROI. | |
If not given, and ``mode`` is not `constant`, this map will be computed on the fly. | |
args: optional args to be passed to ``predictor``. | |
kwargs: optional keyword args to be passed to ``predictor``. | |
Note: | |
- input must be channel-first and have a batch dim, supports N-D sliding window. | |
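    Example:
        A minimal illustrative call (``model`` below is a hypothetical 2D network returning an
        NCHW tensor; it is not defined in this module)::

            image = torch.zeros(1, 3, 1000, 1000)  # NCHW input
            output = sliding_window_inference(
                image, roi_size=(256, 256), sw_batch_size=4, predictor=model,
                overlap=0.25, mode="gaussian",
            )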
""" | |
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")
    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape
    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device
    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
    # print('inputs', inputs.shape)
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows
    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
        except BaseException as e:
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]  # type: ignore
    # handle non-positive weights
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)
    # for each patch
    for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        unravel_slice = [
            [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat(
            [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
        ).to(sw_device)
        seg_prob_out = predictor(window_data, *args, **kwargs)  # batched patch segmentation
        # print('seg_prob_out', seg_prob_out[0].shape)
        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out,)
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False
        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN
            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
            ):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)
            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count for overlapping parts
                output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                _initialized_ss += 1
            # resizing the importance_map
            resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
                importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
                # store results and weights
                output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
                count_map_list[ss][original_idx_zoom] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
                )
    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)
    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn("Sliding window inference results contain NaN or Inf.")
        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
        ]
        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
                int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[final_slicing]
    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    final_output = final_output[0] if is_tensor_output else final_output
    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output, inputs, device=device)[0]  # type: ignore
    return final_output
def _get_scan_interval(
    image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
) -> Tuple[int, ...]:
    """
    Compute scan interval according to the image size, roi size and overlap.
    Scan interval will be `int((1 - overlap) * roi_size)`; if the interval is 0,
    use 1 instead to make sure the sliding window works.
    """
    if len(image_size) != num_spatial_dims:
        raise ValueError("image coord different from spatial dims.")
    if len(roi_size) != num_spatial_dims:
        raise ValueError("roi coord different from spatial dims.")
    scan_interval = []
    for i in range(num_spatial_dims):
        if roi_size[i] == image_size[i]:
            scan_interval.append(int(roi_size[i]))
        else:
            interval = int(roi_size[i] * (1 - overlap))
            scan_interval.append(interval if interval > 0 else 1)
    return tuple(scan_interval)
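# Illustrative check of `_get_scan_interval` (a minimal sketch; the sizes below are assumptions):
#     _get_scan_interval(image_size=(512, 512), roi_size=(256, 256), num_spatial_dims=2, overlap=0.25)
#     # -> (192, 192), i.e. int((1 - 0.25) * 256) per axis; an axis whose roi covers the image steps by roi_size.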