Lewislou commited on
Commit
0764ef3
·
1 Parent(s): 72897a5

Upload 2 files

Browse files
Files changed (2) hide show
  1. classifiers.py +261 -0
  2. utils_modify.py +743 -0
classifiers.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, Callable, List, Optional, Type, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+
8
+ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
9
+ """3x3 convolution with padding"""
10
+ return nn.Conv2d(
11
+ in_planes,
12
+ out_planes,
13
+ kernel_size=3,
14
+ stride=stride,
15
+ padding=dilation,
16
+ groups=groups,
17
+ bias=False,
18
+ dilation=dilation,
19
+ )
20
+
21
+
22
+ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
23
+ """1x1 convolution"""
24
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
25
+
26
+
27
+ class BasicBlock(nn.Module):
28
+ expansion: int = 1
29
+
30
+ def __init__(
31
+ self,
32
+ inplanes: int,
33
+ planes: int,
34
+ stride: int = 1,
35
+ downsample: Optional[nn.Module] = None,
36
+ groups: int = 1,
37
+ base_width: int = 64,
38
+ dilation: int = 1,
39
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
40
+ ) -> None:
41
+ super().__init__()
42
+ if norm_layer is None:
43
+ norm_layer = nn.BatchNorm2d
44
+ if groups != 1 or base_width != 64:
45
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
46
+ if dilation > 1:
47
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
48
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
49
+ self.conv1 = conv3x3(inplanes, planes, stride)
50
+ self.bn1 = norm_layer(planes)
51
+ self.relu = nn.ReLU(inplace=True)
52
+ self.conv2 = conv3x3(planes, planes)
53
+ self.bn2 = norm_layer(planes)
54
+ self.downsample = downsample
55
+ self.stride = stride
56
+
57
+ def forward(self, x: Tensor) -> Tensor:
58
+ identity = x
59
+
60
+ out = self.conv1(x)
61
+ out = self.bn1(out)
62
+ out = self.relu(out)
63
+
64
+ out = self.conv2(out)
65
+ out = self.bn2(out)
66
+
67
+ if self.downsample is not None:
68
+ identity = self.downsample(x)
69
+
70
+ out += identity
71
+ out = self.relu(out)
72
+
73
+ return out
74
+
75
+ class Bottleneck(nn.Module):
76
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
77
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
78
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
79
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
80
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
81
+
82
+ expansion: int = 4
83
+
84
+ def __init__(
85
+ self,
86
+ inplanes: int,
87
+ planes: int,
88
+ stride: int = 1,
89
+ downsample: Optional[nn.Module] = None,
90
+ groups: int = 1,
91
+ base_width: int = 64,
92
+ dilation: int = 1,
93
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
94
+ ) -> None:
95
+ super().__init__()
96
+ if norm_layer is None:
97
+ norm_layer = nn.BatchNorm2d
98
+ width = int(planes * (base_width / 64.0)) * groups
99
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
100
+ self.conv1 = conv1x1(inplanes, width)
101
+ self.bn1 = norm_layer(width)
102
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
103
+ self.bn2 = norm_layer(width)
104
+ self.conv3 = conv1x1(width, planes * self.expansion)
105
+ self.bn3 = norm_layer(planes * self.expansion)
106
+ self.relu = nn.ReLU(inplace=True)
107
+ self.downsample = downsample
108
+ self.stride = stride
109
+
110
+ def forward(self, x: Tensor) -> Tensor:
111
+ identity = x
112
+
113
+ out = self.conv1(x)
114
+ out = self.bn1(out)
115
+ out = self.relu(out)
116
+
117
+ out = self.conv2(out)
118
+ out = self.bn2(out)
119
+ out = self.relu(out)
120
+
121
+ out = self.conv3(out)
122
+ out = self.bn3(out)
123
+
124
+ if self.downsample is not None:
125
+ identity = self.downsample(x)
126
+
127
+ out += identity
128
+ out = self.relu(out)
129
+
130
+ return out
131
+
132
+ class ResNet(nn.Module):
133
+ def __init__(
134
+ self,
135
+ block: Type[Union[BasicBlock, Bottleneck]],
136
+ layers: List[int],
137
+ num_classes: int = 1000,
138
+ zero_init_residual: bool = False,
139
+ groups: int = 1,
140
+ width_per_group: int = 64,
141
+ replace_stride_with_dilation: Optional[List[bool]] = None,
142
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
143
+ ) -> None:
144
+ super().__init__()
145
+ # _log_api_usage_once(self)
146
+ if norm_layer is None:
147
+ norm_layer = nn.BatchNorm2d
148
+ self._norm_layer = norm_layer
149
+
150
+ self.inplanes = 64
151
+ self.dilation = 1
152
+ if replace_stride_with_dilation is None:
153
+ # each element in the tuple indicates if we should replace
154
+ # the 2x2 stride with a dilated convolution instead
155
+ replace_stride_with_dilation = [False, False, False]
156
+ if len(replace_stride_with_dilation) != 3:
157
+ raise ValueError(
158
+ "replace_stride_with_dilation should be None "
159
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
160
+ )
161
+ self.groups = groups
162
+ self.base_width = width_per_group
163
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
164
+ self.bn1 = norm_layer(self.inplanes)
165
+ self.relu = nn.ReLU(inplace=True)
166
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
167
+ self.layer1 = self._make_layer(block, 64, layers[0])
168
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
169
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
170
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
171
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
172
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
173
+
174
+ for m in self.modules():
175
+ if isinstance(m, nn.Conv2d):
176
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
177
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
178
+ nn.init.constant_(m.weight, 1)
179
+ nn.init.constant_(m.bias, 0)
180
+
181
+ # Zero-initialize the last BN in each residual branch,
182
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
183
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
184
+ if zero_init_residual:
185
+ for m in self.modules():
186
+ if isinstance(m, Bottleneck) and m.bn3.weight is not None:
187
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
188
+ elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
189
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
190
+
191
+ def _make_layer(
192
+ self,
193
+ block: Type[Union[BasicBlock, Bottleneck]],
194
+ planes: int,
195
+ blocks: int,
196
+ stride: int = 1,
197
+ dilate: bool = False,
198
+ ) -> nn.Sequential:
199
+ norm_layer = self._norm_layer
200
+ downsample = None
201
+ previous_dilation = self.dilation
202
+ if dilate:
203
+ self.dilation *= stride
204
+ stride = 1
205
+ if stride != 1 or self.inplanes != planes * block.expansion:
206
+ downsample = nn.Sequential(
207
+ conv1x1(self.inplanes, planes * block.expansion, stride),
208
+ norm_layer(planes * block.expansion),
209
+ )
210
+
211
+ layers = []
212
+ layers.append(
213
+ block(
214
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
215
+ )
216
+ )
217
+ self.inplanes = planes * block.expansion
218
+ for _ in range(1, blocks):
219
+ layers.append(
220
+ block(
221
+ self.inplanes,
222
+ planes,
223
+ groups=self.groups,
224
+ base_width=self.base_width,
225
+ dilation=self.dilation,
226
+ norm_layer=norm_layer,
227
+ )
228
+ )
229
+
230
+ return nn.Sequential(*layers)
231
+
232
+ def _forward_impl(self, x: Tensor) -> Tensor:
233
+ # See note [TorchScript super()]
234
+ x = self.conv1(x)
235
+ x = self.bn1(x)
236
+ x = self.relu(x)
237
+ x = self.maxpool(x)
238
+
239
+ x = self.layer1(x)
240
+ x = self.layer2(x)
241
+ x = self.layer3(x)
242
+ x = self.layer4(x)
243
+
244
+ x = self.avgpool(x)
245
+ x = torch.flatten(x, 1)
246
+ x = self.fc(x)
247
+
248
+ return x
249
+
250
+ def forward(self, x: Tensor) -> Tensor:
251
+ return self._forward_impl(x)
252
+
253
+ def resnet18(weights=None):
254
+ # weights: path
255
+ model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=4)
256
+ if weights is not None:
257
+ model.load_state_dict(torch.load(weights))
258
+ return model
259
+
260
+ def resnet10():
261
+ return ResNet(BasicBlock, [1, 1, 1, 1], num_classes=4)
utils_modify.py ADDED
@@ -0,0 +1,743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import warnings
13
+ from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from stardist_pkg.big import _grid_divisible, BlockND, OBJECT_KEYS#, repaint_labels
18
+ from stardist_pkg.matching import relabel_sequential
19
+ from stardist_pkg import dist_to_coord, non_maximum_suppression, polygons_to_label
20
+ #from stardist_pkg import dist_to_coord, polygons_to_label
21
+ from stardist_pkg import star_dist,edt_prob
22
+ from monai.data.meta_tensor import MetaTensor
23
+ from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
24
+ from monai.transforms import Resize
25
+ from monai.utils import (
26
+ BlendMode,
27
+ PytorchPadMode,
28
+ convert_data_type,
29
+ convert_to_dst_type,
30
+ ensure_tuple,
31
+ fall_back_tuple,
32
+ look_up_option,
33
+ optional_import,
34
+ )
35
+ import cv2
36
+ from scipy import ndimage
37
+ from scipy.ndimage.filters import gaussian_filter
38
+ from scipy.ndimage.interpolation import affine_transform, map_coordinates
39
+ from skimage import morphology as morph
40
+ from scipy.ndimage import filters, measurements
41
+ from scipy.ndimage.morphology import (
42
+ binary_dilation,
43
+ binary_fill_holes,
44
+ distance_transform_cdt,
45
+ distance_transform_edt,
46
+ )
47
+
48
+ from skimage.segmentation import watershed
49
+ tqdm, _ = optional_import("tqdm", name="tqdm")
50
+
51
+ __all__ = ["sliding_window_inference"]
52
+
53
+
54
+ ####
55
+ def normalize(mask, dtype=np.uint8):
56
+ return (255 * mask / np.amax(mask)).astype(dtype)
57
+
58
+ def fix_mirror_padding(ann):
59
+ """Deal with duplicated instances due to mirroring in interpolation
60
+ during shape augmentation (scale, rotation etc.).
61
+
62
+ """
63
+ current_max_id = np.amax(ann)
64
+ inst_list = list(np.unique(ann))
65
+ if 0 in inst_list:
66
+ inst_list.remove(0) # 0 is background
67
+ for inst_id in inst_list:
68
+ inst_map = np.array(ann == inst_id, np.uint8)
69
+ remapped_ids = measurements.label(inst_map)[0]
70
+ remapped_ids[remapped_ids > 1] += current_max_id
71
+ ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
72
+ current_max_id = np.amax(ann)
73
+ return ann
74
+
75
+ ####
76
+ def get_bounding_box(img):
77
+ """Get bounding box coordinate information."""
78
+ rows = np.any(img, axis=1)
79
+ cols = np.any(img, axis=0)
80
+ rmin, rmax = np.where(rows)[0][[0, -1]]
81
+ cmin, cmax = np.where(cols)[0][[0, -1]]
82
+ # due to python indexing, need to add 1 to max
83
+ # else accessing will be 1px in the box, not out
84
+ rmax += 1
85
+ cmax += 1
86
+ return [rmin, rmax, cmin, cmax]
87
+
88
+
89
+ ####
90
+ def cropping_center(x, crop_shape, batch=False):
91
+ """Crop an input image at the centre.
92
+
93
+ Args:
94
+ x: input array
95
+ crop_shape: dimensions of cropped array
96
+
97
+ Returns:
98
+ x: cropped array
99
+
100
+ """
101
+ orig_shape = x.shape
102
+ if not batch:
103
+ h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
104
+ w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
105
+ x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
106
+ else:
107
+ h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
108
+ w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
109
+ x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
110
+ return x
111
+
112
+ def gen_instance_hv_map(ann, crop_shape):
113
+ """Input annotation must be of original shape.
114
+
115
+ The map is calculated only for instances within the crop portion
116
+ but based on the original shape in original image.
117
+
118
+ Perform following operation:
119
+ Obtain the horizontal and vertical distance maps for each
120
+ nuclear instance.
121
+
122
+ """
123
+ orig_ann = ann.copy() # instance ID map
124
+ fixed_ann = fix_mirror_padding(orig_ann)
125
+ # re-cropping with fixed instance id map
126
+ crop_ann = cropping_center(fixed_ann, crop_shape)
127
+ # TODO: deal with 1 label warning
128
+ crop_ann = morph.remove_small_objects(crop_ann, min_size=30)
129
+
130
+ x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
131
+ y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
132
+
133
+ inst_list = list(np.unique(crop_ann))
134
+ if 0 in inst_list:
135
+ inst_list.remove(0) # 0 is background
136
+ for inst_id in inst_list:
137
+ inst_map = np.array(fixed_ann == inst_id, np.uint8)
138
+ inst_box = get_bounding_box(inst_map) # rmin, rmax, cmin, cmax
139
+
140
+ # expand the box by 2px
141
+ # Because we first pad the ann at line 207, the bboxes
142
+ # will remain valid after expansion
143
+ inst_box[0] -= 2
144
+ inst_box[2] -= 2
145
+ inst_box[1] += 2
146
+ inst_box[3] += 2
147
+
148
+ # fix inst_box
149
+ inst_box[0] = max(inst_box[0], 0)
150
+ inst_box[2] = max(inst_box[2], 0)
151
+ # inst_box[1] = min(inst_box[1], fixed_ann.shape[0])
152
+ # inst_box[3] = min(inst_box[3], fixed_ann.shape[1])
153
+
154
+ inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
155
+
156
+ if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
157
+ print(f'inst_map.shape < 2: {inst_map.shape}, {inst_box}, {get_bounding_box(np.array(fixed_ann == inst_id, np.uint8))}')
158
+ continue
159
+
160
+ # instance center of mass, rounded to nearest pixel
161
+ inst_com = list(measurements.center_of_mass(inst_map))
162
+ if np.isnan(measurements.center_of_mass(inst_map)).any():
163
+ print(inst_id, fixed_ann.shape, np.array(fixed_ann == inst_id, np.uint8).shape)
164
+ print(get_bounding_box(np.array(fixed_ann == inst_id, np.uint8)))
165
+ print(inst_map)
166
+ print(inst_list)
167
+ print(inst_box)
168
+ print(np.count_nonzero(np.array(fixed_ann == inst_id, np.uint8)))
169
+
170
+ inst_com[0] = int(inst_com[0] + 0.5)
171
+ inst_com[1] = int(inst_com[1] + 0.5)
172
+
173
+ inst_x_range = np.arange(1, inst_map.shape[1] + 1)
174
+ inst_y_range = np.arange(1, inst_map.shape[0] + 1)
175
+ # shifting center of pixels grid to instance center of mass
176
+ inst_x_range -= inst_com[1]
177
+ inst_y_range -= inst_com[0]
178
+
179
+ inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)
180
+
181
+ # remove coord outside of instance
182
+ inst_x[inst_map == 0] = 0
183
+ inst_y[inst_map == 0] = 0
184
+ inst_x = inst_x.astype("float32")
185
+ inst_y = inst_y.astype("float32")
186
+
187
+ # normalize min into -1 scale
188
+ if np.min(inst_x) < 0:
189
+ inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
190
+ if np.min(inst_y) < 0:
191
+ inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
192
+ # normalize max into +1 scale
193
+ if np.max(inst_x) > 0:
194
+ inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
195
+ if np.max(inst_y) > 0:
196
+ inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])
197
+
198
+ ####
199
+ x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
200
+ x_map_box[inst_map > 0] = inst_x[inst_map > 0]
201
+
202
+ y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
203
+ y_map_box[inst_map > 0] = inst_y[inst_map > 0]
204
+
205
+ hv_map = np.dstack([x_map, y_map])
206
+ return hv_map
207
+
208
+ def remove_small_objects(pred, min_size=64, connectivity=1):
209
+ """Remove connected components smaller than the specified size.
210
+
211
+ This function is taken from skimage.morphology.remove_small_objects, but the warning
212
+ is removed when a single label is provided.
213
+
214
+ Args:
215
+ pred: input labelled array
216
+ min_size: minimum size of instance in output array
217
+ connectivity: The connectivity defining the neighborhood of a pixel.
218
+
219
+ Returns:
220
+ out: output array with instances removed under min_size
221
+
222
+ """
223
+ out = pred
224
+
225
+ if min_size == 0: # shortcut for efficiency
226
+ return out
227
+
228
+ if out.dtype == bool:
229
+ selem = ndimage.generate_binary_structure(pred.ndim, connectivity)
230
+ ccs = np.zeros_like(pred, dtype=np.int32)
231
+ ndimage.label(pred, selem, output=ccs)
232
+ else:
233
+ ccs = out
234
+
235
+ try:
236
+ component_sizes = np.bincount(ccs.ravel())
237
+ except ValueError:
238
+ raise ValueError(
239
+ "Negative value labels are not supported. Try "
240
+ "relabeling the input with `scipy.ndimage.label` or "
241
+ "`skimage.morphology.label`."
242
+ )
243
+
244
+ too_small = component_sizes < min_size
245
+ too_small_mask = too_small[ccs]
246
+ out[too_small_mask] = 0
247
+
248
+ return out
249
+
250
+ ####
251
+ def gen_targets(ann, crop_shape, **kwargs):
252
+ """Generate the targets for the network."""
253
+ hv_map = gen_instance_hv_map(ann, crop_shape)
254
+ np_map = ann.copy()
255
+ np_map[np_map > 0] = 1
256
+
257
+ hv_map = cropping_center(hv_map, crop_shape)
258
+ np_map = cropping_center(np_map, crop_shape)
259
+
260
+ target_dict = {
261
+ "hv_map": hv_map,
262
+ "np_map": np_map,
263
+ }
264
+
265
+ return target_dict
266
+ def __proc_np_hv(pred, np_thres=0.5, ksize=21, overall_thres=0.4, obj_size_thres=10):
267
+ """Process Nuclei Prediction with XY Coordinate Map.
268
+
269
+ Args:
270
+ pred: prediction output, assuming
271
+ channel 0 contain probability map of nuclei
272
+ channel 1 containing the regressed X-map
273
+ channel 2 containing the regressed Y-map
274
+
275
+ """
276
+ pred = np.array(pred, dtype=np.float32)
277
+
278
+ blb_raw = pred[..., 0]
279
+ h_dir_raw = pred[..., 1]
280
+ v_dir_raw = pred[..., 2]
281
+
282
+ # processing
283
+ blb = np.array(blb_raw >= np_thres, dtype=np.int32)
284
+
285
+ blb = measurements.label(blb)[0]
286
+ blb = remove_small_objects(blb, min_size=10)
287
+ blb[blb > 0] = 1 # background is 0 already
288
+
289
+ h_dir = cv2.normalize(
290
+ h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
291
+ )
292
+ v_dir = cv2.normalize(
293
+ v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
294
+ )
295
+
296
+ sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
297
+ sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)
298
+
299
+ sobelh = 1 - (
300
+ cv2.normalize(
301
+ sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
302
+ )
303
+ )
304
+ sobelv = 1 - (
305
+ cv2.normalize(
306
+ sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
307
+ )
308
+ )
309
+
310
+ overall = np.maximum(sobelh, sobelv)
311
+ overall = overall - (1 - blb)
312
+ overall[overall < 0] = 0
313
+
314
+ dist = (1.0 - overall) * blb
315
+ ## nuclei values form mountains so inverse to get basins
316
+ dist = -cv2.GaussianBlur(dist, (3, 3), 0)
317
+
318
+ overall = np.array(overall >= overall_thres, dtype=np.int32)
319
+
320
+ marker = blb - overall
321
+ marker[marker < 0] = 0
322
+ marker = binary_fill_holes(marker).astype("uint8")
323
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
324
+ marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
325
+ marker = measurements.label(marker)[0]
326
+ marker = remove_small_objects(marker, min_size=obj_size_thres)
327
+
328
+ proced_pred = watershed(dist, markers=marker, mask=blb)
329
+
330
+ return proced_pred
331
+
332
+ ####
333
+ def colorize(ch, vmin, vmax):
334
+ """Will clamp value value outside the provided range to vmax and vmin."""
335
+ cmap = plt.get_cmap("jet")
336
+ ch = np.squeeze(ch.astype("float32"))
337
+ vmin = vmin if vmin is not None else ch.min()
338
+ vmax = vmax if vmax is not None else ch.max()
339
+ ch[ch > vmax] = vmax # clamp value
340
+ ch[ch < vmin] = vmin
341
+ ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
342
+ # take RGB from RGBA heat map
343
+ ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
344
+ return ch_cmap
345
+
346
+
347
+ ####
348
+ def random_colors(N, bright=True):
349
+ """Generate random colors.
350
+
351
+ To get visually distinct colors, generate them in HSV space then
352
+ convert to RGB.
353
+ """
354
+ brightness = 1.0 if bright else 0.7
355
+ hsv = [(i / N, 1, brightness) for i in range(N)]
356
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
357
+ random.shuffle(colors)
358
+ return colors
359
+
360
+
361
+ ####
362
+ def visualize_instances_map(
363
+ input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
364
+ ):
365
+ """Overlays segmentation results on image as contours.
366
+
367
+ Args:
368
+ input_image: input image
369
+ inst_map: instance mask with unique value for every object
370
+ type_map: type mask with unique value for every class
371
+ type_colour: a dict of {type : colour} , `type` is from 0-N
372
+ and `colour` is a tuple of (R, G, B)
373
+ line_thickness: line thickness of contours
374
+
375
+ Returns:
376
+ overlay: output image with segmentation overlay as contours
377
+ """
378
+ overlay = np.copy((input_image).astype(np.uint8))
379
+
380
+ inst_list = list(np.unique(inst_map)) # get list of instances
381
+ inst_list.remove(0) # remove background
382
+
383
+ inst_rng_colors = random_colors(len(inst_list))
384
+ inst_rng_colors = np.array(inst_rng_colors) * 255
385
+ inst_rng_colors = inst_rng_colors.astype(np.uint8)
386
+
387
+ for inst_idx, inst_id in enumerate(inst_list):
388
+ inst_map_mask = np.array(inst_map == inst_id, np.uint8) # get single object
389
+ y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
390
+ y1 = y1 - 2 if y1 - 2 >= 0 else y1
391
+ x1 = x1 - 2 if x1 - 2 >= 0 else x1
392
+ x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
393
+ y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
394
+ inst_map_crop = inst_map_mask[y1:y2, x1:x2]
395
+ contours_crop = cv2.findContours(
396
+ inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
397
+ )
398
+ # only has 1 instance per map, no need to check #contour detected by opencv
399
+ contours_crop = np.squeeze(
400
+ contours_crop[0][0].astype("int32")
401
+ ) # * opencv protocol format may break
402
+ contours_crop += np.asarray([[x1, y1]]) # index correction
403
+ if type_map is not None:
404
+ type_map_crop = type_map[y1:y2, x1:x2]
405
+ type_id = np.unique(type_map_crop).max() # non-zero
406
+ inst_colour = type_colour[type_id]
407
+ else:
408
+ inst_colour = (inst_rng_colors[inst_idx]).tolist()
409
+ cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
410
+ return overlay
411
+
412
+
413
+ def sliding_window_inference_large(inputs,block_size,min_overlap,context,roi_size,sw_batch_size,predictor,device):
414
+
415
+ h,w = inputs.shape[0],inputs.shape[1]
416
+ if h < 5000 or w < 5000:
417
+ test_tensor = torch.from_numpy(np.expand_dims(inputs, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device)
418
+ output_dist,output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
419
+ prob = output_prob[0][0].cpu().numpy()
420
+ dist = output_dist[0].cpu().numpy()
421
+ dist = np.transpose(dist,(1,2,0))
422
+ dist = np.maximum(1e-3, dist)
423
+ if h*w < 1500*1500:
424
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.55, nms_thresh=0.4,cut=True)
425
+ else:
426
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4)
427
+
428
+
429
+ labels_out = polygons_to_label(disti, points, prob=probi,shape=prob.shape)
430
+ else:
431
+ n = inputs.ndim
432
+ axes = 'YXC'
433
+ grid = (1,1,1)
434
+ if np.isscalar(block_size): block_size = n*[block_size]
435
+ if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
436
+ if np.isscalar(context): context = n*[context]
437
+ shape_out = (inputs.shape[0],inputs.shape[1])
438
+ labels_out = np.zeros(shape_out, dtype=np.uint64)
439
+ #print(inputs.dtype)
440
+ block_size[2] = inputs.shape[2]
441
+ min_overlap[2] = context[2] = 0
442
+ block_size = tuple(_grid_divisible(g, v, name='block_size', verbose=False) for v,g,a in zip(block_size, grid,axes))
443
+ min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v,g,a in zip(min_overlap,grid,axes))
444
+ context = tuple(_grid_divisible(g, v, name='context', verbose=False) for v,g,a in zip(context, grid,axes))
445
+ print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)
446
+ blocks = BlockND.cover(inputs.shape, axes, block_size, min_overlap, context)
447
+ label_offset = 1
448
+ blocks = tqdm(blocks)
449
+ for block in blocks:
450
+ image = block.read(inputs, axes=axes)
451
+ test_tensor = torch.from_numpy(np.expand_dims(image, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device)
452
+ output_dist,output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
453
+ prob = output_prob[0][0].cpu().numpy()
454
+ dist = output_dist[0].cpu().numpy()
455
+ dist = np.transpose(dist,(1,2,0))
456
+ dist = np.maximum(1e-3, dist)
457
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4)
458
+
459
+ coord = dist_to_coord(disti,points)
460
+ polys = dict(coord=coord, points=points, prob=probi)
461
+ labels = polygons_to_label(disti, points, prob=probi,shape=prob.shape)
462
+ labels = block.crop_context(labels, axes='YX')
463
+ labels, polys = block.filter_objects(labels, polys, axes='YX')
464
+ labels = relabel_sequential(labels, label_offset)[0]
465
+ if labels_out is not None:
466
+ block.write(labels_out, labels, axes='YX')
467
+ #for k,v in polys.items():
468
+ #polys_all.setdefault(k,[]).append(v)
469
+ label_offset += len(polys['prob'])
470
+ del labels
471
+ #polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k,v in polys_all.items()}
472
+ return labels_out
473
+ def sliding_window_inference(
474
+ inputs: torch.Tensor,
475
+ roi_size: Union[Sequence[int], int],
476
+ sw_batch_size: int,
477
+ predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
478
+ overlap: float = 0.25,
479
+ mode: Union[BlendMode, str] = BlendMode.CONSTANT,
480
+ sigma_scale: Union[Sequence[float], float] = 0.125,
481
+ padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
482
+ cval: float = 0.0,
483
+ sw_device: Union[torch.device, str, None] = None,
484
+ device: Union[torch.device, str, None] = None,
485
+ progress: bool = False,
486
+ roi_weight_map: Union[torch.Tensor, None] = None,
487
+ *args: Any,
488
+ **kwargs: Any,
489
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
490
+ """
491
+ Sliding window inference on `inputs` with `predictor`.
492
+
493
+ The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
494
+ Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
495
+ e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
496
+ could be ([128,64,256], [64,32,128]).
497
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
498
+ an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
499
+ so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
500
+
501
+ When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
502
+ To maintain the same spatial sizes, the output image will be cropped to the original input size.
503
+
504
+ Args:
505
+ inputs: input image to be processed (assuming NCHW[D])
506
+ roi_size: the spatial window size for inferences.
507
+ When its components have None or non-positives, the corresponding inputs dimension will be used.
508
+ if the components of the `roi_size` are non-positive values, the transform will use the
509
+ corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
510
+ to `(32, 64)` if the second spatial dimension size of img is `64`.
511
+ sw_batch_size: the batch size to run window slices.
512
+ predictor: given input tensor ``patch_data`` in shape NCHW[D],
513
+ The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
514
+ with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
515
+ where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
516
+ N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
517
+ the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
518
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
519
+ to ensure the scaled output ROI sizes are still integers.
520
+ If the `predictor`'s input and output spatial sizes are different,
521
+ we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
522
+ overlap: Amount of overlap between scans.
523
+ mode: {``"constant"``, ``"gaussian"``}
524
+ How to blend output of overlapping windows. Defaults to ``"constant"``.
525
+
526
+ - ``"constant``": gives equal weight to all predictions.
527
+ - ``"gaussian``": gives less weight to predictions on edges of windows.
528
+
529
+ sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
530
+ Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
531
+ When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
532
+ spatial dimensions.
533
+ padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
534
+ Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
535
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
536
+ cval: fill value for 'constant' padding mode. Default: 0
537
+ sw_device: device for the window data.
538
+ By default the device (and accordingly the memory) of the `inputs` is used.
539
+ Normally `sw_device` should be consistent with the device where `predictor` is defined.
540
+ device: device for the stitched output prediction.
541
+ By default the device (and accordingly the memory) of the `inputs` is used. If for example
542
+ set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
543
+ `inputs` and `roi_size`. Output is on the `device`.
544
+ progress: whether to print a `tqdm` progress bar.
545
+ roi_weight_map: pre-computed (non-negative) weight map for each ROI.
546
+ If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
547
+ args: optional args to be passed to ``predictor``.
548
+ kwargs: optional keyword args to be passed to ``predictor``.
549
+
550
+ Note:
551
+ - input must be channel-first and have a batch dim, supports N-D sliding window.
552
+
553
+ """
554
+ compute_dtype = inputs.dtype
555
+ num_spatial_dims = len(inputs.shape) - 2
556
+ if overlap < 0 or overlap >= 1:
557
+ raise ValueError("overlap must be >= 0 and < 1.")
558
+
559
+ # determine image spatial size and batch size
560
+ # Note: all input images must have the same image size and batch size
561
+ batch_size, _, *image_size_ = inputs.shape
562
+
563
+ if device is None:
564
+ device = inputs.device
565
+ if sw_device is None:
566
+ sw_device = inputs.device
567
+
568
+ roi_size = fall_back_tuple(roi_size, image_size_)
569
+ # in case that image size is smaller than roi size
570
+ image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
571
+ pad_size = []
572
+ for k in range(len(inputs.shape) - 1, 1, -1):
573
+ diff = max(roi_size[k - 2] - inputs.shape[k], 0)
574
+ half = diff // 2
575
+ pad_size.extend([half, diff - half])
576
+ inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
577
+ #print('inputs',inputs.shape)
578
+ scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
579
+
580
+ # Store all slices in list
581
+ slices = dense_patch_slices(image_size, roi_size, scan_interval)
582
+ num_win = len(slices) # number of windows per image
583
+ total_slices = num_win * batch_size # total number of windows
584
+
585
+ # Create window-level importance map
586
+ valid_patch_size = get_valid_patch_size(image_size, roi_size)
587
+ if valid_patch_size == roi_size and (roi_weight_map is not None):
588
+ importance_map = roi_weight_map
589
+ else:
590
+ try:
591
+ importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
592
+ except BaseException as e:
593
+ raise RuntimeError(
594
+ "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
595
+ ) from e
596
+ importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
597
+ # handle non-positive weights
598
+ min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
599
+ importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
600
+
601
+ # Perform predictions
602
+ dict_key, output_image_list, count_map_list = None, [], []
603
+ _initialized_ss = -1
604
+ is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple)
605
+
606
+ # for each patch
607
+ for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
608
+ slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
609
+ unravel_slice = [
610
+ [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
611
+ for idx in slice_range
612
+ ]
613
+ window_data = torch.cat(
614
+ [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
615
+ ).to(sw_device)
616
+ seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation
617
+ #print('seg_prob_out',seg_prob_out[0].shape)
618
+ # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
619
+ seg_prob_tuple: Tuple[torch.Tensor, ...]
620
+ if isinstance(seg_prob_out, torch.Tensor):
621
+ seg_prob_tuple = (seg_prob_out,)
622
+ elif isinstance(seg_prob_out, Mapping):
623
+ if dict_key is None:
624
+ dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys
625
+ seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
626
+ is_tensor_output = False
627
+ else:
628
+ seg_prob_tuple = ensure_tuple(seg_prob_out)
629
+ is_tensor_output = False
630
+
631
+ # for each output in multi-output list
632
+ for ss, seg_prob in enumerate(seg_prob_tuple):
633
+ seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN
634
+
635
+ # compute zoom scale: out_roi_size/in_roi_size
636
+ zoom_scale = []
637
+ for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
638
+ zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
639
+ ):
640
+ _scale = out_w_i / float(in_w_i)
641
+ if not (img_s_i * _scale).is_integer():
642
+ warnings.warn(
643
+ f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
644
+ f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
645
+ )
646
+ zoom_scale.append(_scale)
647
+
648
+ if _initialized_ss < ss: # init. the ss-th buffer at the first iteration
649
+ # construct multi-resolution outputs
650
+ output_classes = seg_prob.shape[1]
651
+ output_shape = [batch_size, output_classes] + [
652
+ int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
653
+ ]
654
+ # allocate memory to store the full output and the count for overlapping parts
655
+ output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
656
+ count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
657
+ _initialized_ss += 1
658
+
659
+ # resizing the importance_map
660
+ resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
661
+
662
+ # store the result in the proper location of the full output. Apply weights from importance map.
663
+ for idx, original_idx in zip(slice_range, unravel_slice):
664
+ # zoom roi
665
+ original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image
666
+ for axis in range(2, len(original_idx_zoom)):
667
+ zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
668
+ zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
669
+ if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
670
+ warnings.warn(
671
+ f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
672
+ f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
673
+ f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
674
+ f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
675
+ f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
676
+ "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
677
+ )
678
+ original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
679
+ importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
680
+ # store results and weights
681
+ output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
682
+ count_map_list[ss][original_idx_zoom] += (
683
+ importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
684
+ )
685
+
686
+ # account for any overlapping sections
687
+ for ss in range(len(output_image_list)):
688
+ output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)
689
+
690
+ # remove padding if image_size smaller than roi_size
691
+ for ss, output_i in enumerate(output_image_list):
692
+ if torch.isnan(output_i).any() or torch.isinf(output_i).any():
693
+ warnings.warn("Sliding window inference results contain NaN or Inf.")
694
+
695
+ zoom_scale = [
696
+ seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
697
+ ]
698
+
699
+ final_slicing: List[slice] = []
700
+ for sp in range(num_spatial_dims):
701
+ slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
702
+ slice_dim = slice(
703
+ int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
704
+ int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
705
+ )
706
+ final_slicing.insert(0, slice_dim)
707
+ while len(final_slicing) < len(output_i.shape):
708
+ final_slicing.insert(0, slice(None))
709
+ output_image_list[ss] = output_i[final_slicing]
710
+
711
+ if dict_key is not None: # if output of predictor is a dict
712
+ final_output = dict(zip(dict_key, output_image_list))
713
+ else:
714
+ final_output = tuple(output_image_list) # type: ignore
715
+ final_output = final_output[0] if is_tensor_output else final_output
716
+
717
+ if isinstance(inputs, MetaTensor):
718
+ final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore
719
+ return final_output
720
+
721
+
722
+ def _get_scan_interval(
723
+ image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
724
+ ) -> Tuple[int, ...]:
725
+ """
726
+ Compute scan interval according to the image size, roi size and overlap.
727
+ Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
728
+ use 1 instead to make sure sliding window works.
729
+
730
+ """
731
+ if len(image_size) != num_spatial_dims:
732
+ raise ValueError("image coord different from spatial dims.")
733
+ if len(roi_size) != num_spatial_dims:
734
+ raise ValueError("roi coord different from spatial dims.")
735
+
736
+ scan_interval = []
737
+ for i in range(num_spatial_dims):
738
+ if roi_size[i] == image_size[i]:
739
+ scan_interval.append(int(roi_size[i]))
740
+ else:
741
+ interval = int(roi_size[i] * (1 - overlap))
742
+ scan_interval.append(interval if interval > 0 else 1)
743
+ return tuple(scan_interval)