Upload 2 files

- classifiers.py +261 -0
- utils_modify.py +743 -0
classifiers.py
ADDED
@@ -0,0 +1,261 @@
from functools import partial
from typing import Any, Callable, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        # _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def resnet18(weights=None):
    # weights: path to a saved state_dict
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=4)
    if weights is not None:
        model.load_state_dict(torch.load(weights))
    return model


def resnet10():
    return ResNet(BasicBlock, [1, 1, 1, 1], num_classes=4)
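A minimal usage sketch for the classifiers defined above (illustrative, not part of the upload; the 224x224 input size and the checkpoint path are assumptions):

    import torch
    from classifiers import resnet18

    model = resnet18()                        # 4-class ResNet-18, randomly initialised
    # model = resnet18(weights="ckpt.pth")    # hypothetical path to a saved state_dict
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))  # shape (1, 4): class logits
        pred = logits.argmax(dim=1)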
utils_modify.py
ADDED
@@ -0,0 +1,743 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import colorsys  # needed by random_colors() below
import random  # needed by random_colors() below
import warnings
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union

import matplotlib.pyplot as plt  # needed by colorize() below
import numpy as np
import torch
import torch.nn.functional as F

from stardist_pkg.big import _grid_divisible, BlockND, OBJECT_KEYS  # , repaint_labels
from stardist_pkg.matching import relabel_sequential
from stardist_pkg import dist_to_coord, non_maximum_suppression, polygons_to_label
# from stardist_pkg import dist_to_coord, polygons_to_label
from stardist_pkg import star_dist, edt_prob

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.transforms import Resize
from monai.utils import (
    BlendMode,
    PytorchPadMode,
    convert_data_type,
    convert_to_dst_type,
    ensure_tuple,
    fall_back_tuple,
    look_up_option,
    optional_import,
)

import cv2
from scipy import ndimage
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import affine_transform, map_coordinates
from skimage import morphology as morph
from scipy.ndimage import filters, measurements
from scipy.ndimage.morphology import (
    binary_dilation,
    binary_fill_holes,
    distance_transform_cdt,
    distance_transform_edt,
)

from skimage.segmentation import watershed

tqdm, _ = optional_import("tqdm", name="tqdm")

__all__ = ["sliding_window_inference"]

####
def normalize(mask, dtype=np.uint8):
    return (255 * mask / np.amax(mask)).astype(dtype)


def fix_mirror_padding(ann):
    """Deal with duplicated instances due to mirroring in interpolation
    during shape augmentation (scale, rotation etc.).

    """
    current_max_id = np.amax(ann)
    inst_list = list(np.unique(ann))
    if 0 in inst_list:
        inst_list.remove(0)  # 0 is background
    for inst_id in inst_list:
        inst_map = np.array(ann == inst_id, np.uint8)
        remapped_ids = measurements.label(inst_map)[0]
        remapped_ids[remapped_ids > 1] += current_max_id
        ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
        current_max_id = np.amax(ann)
    return ann


####
def get_bounding_box(img):
    """Get bounding box coordinate information."""
    rows = np.any(img, axis=1)
    cols = np.any(img, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # due to python indexing, need to add 1 to max
    # else accessing will be 1px in the box, not out
    rmax += 1
    cmax += 1
    return [rmin, rmax, cmin, cmax]


####
def cropping_center(x, crop_shape, batch=False):
    """Crop an input image at the centre.

    Args:
        x: input array
        crop_shape: dimensions of cropped array

    Returns:
        x: cropped array

    """
    orig_shape = x.shape
    if not batch:
        h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
        w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
        x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
    else:
        h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
        w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
        x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
    return x


def gen_instance_hv_map(ann, crop_shape):
    """Input annotation must be of original shape.

    The map is calculated only for instances within the crop portion
    but based on the original shape in the original image.

    Perform the following operation:
    Obtain the horizontal and vertical distance maps for each
    nuclear instance.

    """
    orig_ann = ann.copy()  # instance ID map
    fixed_ann = fix_mirror_padding(orig_ann)
    # re-cropping with fixed instance id map
    crop_ann = cropping_center(fixed_ann, crop_shape)
    # TODO: deal with 1 label warning
    crop_ann = morph.remove_small_objects(crop_ann, min_size=30)

    x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
    y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)

    inst_list = list(np.unique(crop_ann))
    if 0 in inst_list:
        inst_list.remove(0)  # 0 is background
    for inst_id in inst_list:
        inst_map = np.array(fixed_ann == inst_id, np.uint8)
        inst_box = get_bounding_box(inst_map)  # rmin, rmax, cmin, cmax

        # expand the box by 2px
        # Because we first pad the ann at line 207, the bboxes
        # will remain valid after expansion
        inst_box[0] -= 2
        inst_box[2] -= 2
        inst_box[1] += 2
        inst_box[3] += 2

        # fix inst_box
        inst_box[0] = max(inst_box[0], 0)
        inst_box[2] = max(inst_box[2], 0)
        # inst_box[1] = min(inst_box[1], fixed_ann.shape[0])
        # inst_box[3] = min(inst_box[3], fixed_ann.shape[1])

        inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]

        if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
            print(f'inst_map.shape < 2: {inst_map.shape}, {inst_box}, {get_bounding_box(np.array(fixed_ann == inst_id, np.uint8))}')
            continue

        # instance center of mass, rounded to nearest pixel
        inst_com = list(measurements.center_of_mass(inst_map))
        if np.isnan(measurements.center_of_mass(inst_map)).any():
            print(inst_id, fixed_ann.shape, np.array(fixed_ann == inst_id, np.uint8).shape)
            print(get_bounding_box(np.array(fixed_ann == inst_id, np.uint8)))
            print(inst_map)
            print(inst_list)
            print(inst_box)
            print(np.count_nonzero(np.array(fixed_ann == inst_id, np.uint8)))

        inst_com[0] = int(inst_com[0] + 0.5)
        inst_com[1] = int(inst_com[1] + 0.5)

        inst_x_range = np.arange(1, inst_map.shape[1] + 1)
        inst_y_range = np.arange(1, inst_map.shape[0] + 1)
        # shifting center of pixels grid to instance center of mass
        inst_x_range -= inst_com[1]
        inst_y_range -= inst_com[0]

        inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)

        # remove coord outside of instance
        inst_x[inst_map == 0] = 0
        inst_y[inst_map == 0] = 0
        inst_x = inst_x.astype("float32")
        inst_y = inst_y.astype("float32")

        # normalize min into -1 scale
        if np.min(inst_x) < 0:
            inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
        if np.min(inst_y) < 0:
            inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
        # normalize max into +1 scale
        if np.max(inst_x) > 0:
            inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
        if np.max(inst_y) > 0:
            inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])

        ####
        x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
        x_map_box[inst_map > 0] = inst_x[inst_map > 0]

        y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
        y_map_box[inst_map > 0] = inst_y[inst_map > 0]

    hv_map = np.dstack([x_map, y_map])
    return hv_map


def remove_small_objects(pred, min_size=64, connectivity=1):
    """Remove connected components smaller than the specified size.

    This function is taken from skimage.morphology.remove_small_objects, but the warning
    is removed when a single label is provided.

    Args:
        pred: input labelled array
        min_size: minimum size of instance in output array
        connectivity: the connectivity defining the neighborhood of a pixel.

    Returns:
        out: output array with instances removed under min_size

    """
    out = pred

    if min_size == 0:  # shortcut for efficiency
        return out

    if out.dtype == bool:
        selem = ndimage.generate_binary_structure(pred.ndim, connectivity)
        ccs = np.zeros_like(pred, dtype=np.int32)
        ndimage.label(pred, selem, output=ccs)
    else:
        ccs = out

    try:
        component_sizes = np.bincount(ccs.ravel())
    except ValueError:
        raise ValueError(
            "Negative value labels are not supported. Try "
            "relabeling the input with `scipy.ndimage.label` or "
            "`skimage.morphology.label`."
        )

    too_small = component_sizes < min_size
    too_small_mask = too_small[ccs]
    out[too_small_mask] = 0

    return out


####
def gen_targets(ann, crop_shape, **kwargs):
    """Generate the targets for the network."""
    hv_map = gen_instance_hv_map(ann, crop_shape)
    np_map = ann.copy()
    np_map[np_map > 0] = 1

    hv_map = cropping_center(hv_map, crop_shape)
    np_map = cropping_center(np_map, crop_shape)

    target_dict = {
        "hv_map": hv_map,
        "np_map": np_map,
    }

    return target_dict

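
# Usage sketch (illustrative; the 256x256 crop shape is an assumption, not a value fixed
# by this module): given an HxW integer instance-label map `ann`,
#
#     targets = gen_targets(ann, (256, 256))
#     targets["hv_map"]  # (256, 256, 2) float32 horizontal/vertical distance maps in [-1, 1]
#     targets["np_map"]  # (256, 256) binary nucleus/background map
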
def __proc_np_hv(pred, np_thres=0.5, ksize=21, overall_thres=0.4, obj_size_thres=10):
    """Process Nuclei Prediction with XY Coordinate Map.

    Args:
        pred: prediction output, assuming
            channel 0 contains the probability map of nuclei
            channel 1 contains the regressed X-map
            channel 2 contains the regressed Y-map

    """
    pred = np.array(pred, dtype=np.float32)

    blb_raw = pred[..., 0]
    h_dir_raw = pred[..., 1]
    v_dir_raw = pred[..., 2]

    # processing
    blb = np.array(blb_raw >= np_thres, dtype=np.int32)

    blb = measurements.label(blb)[0]
    blb = remove_small_objects(blb, min_size=10)
    blb[blb > 0] = 1  # background is 0 already

    h_dir = cv2.normalize(
        h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
    )
    v_dir = cv2.normalize(
        v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
    )

    sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
    sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)

    sobelh = 1 - (
        cv2.normalize(
            sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
        )
    )
    sobelv = 1 - (
        cv2.normalize(
            sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
        )
    )

    overall = np.maximum(sobelh, sobelv)
    overall = overall - (1 - blb)
    overall[overall < 0] = 0

    dist = (1.0 - overall) * blb
    ## nuclei values form mountains so inverse to get basins
    dist = -cv2.GaussianBlur(dist, (3, 3), 0)

    overall = np.array(overall >= overall_thres, dtype=np.int32)

    marker = blb - overall
    marker[marker < 0] = 0
    marker = binary_fill_holes(marker).astype("uint8")
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
    marker = measurements.label(marker)[0]
    marker = remove_small_objects(marker, min_size=obj_size_thres)

    proced_pred = watershed(dist, markers=marker, mask=blb)

    return proced_pred

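
# Usage sketch (illustrative): __proc_np_hv expects the three network outputs stacked
# channel-last and returns an HxW integer map with one unique label per nucleus, e.g.
#
#     pred = np.dstack([prob_map, h_map, v_map])   # HxWx3; hypothetical input arrays
#     inst_map = __proc_np_hv(pred, np_thres=0.5, ksize=21)
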

####
def colorize(ch, vmin, vmax):
    """Will clamp values outside the provided range to vmax and vmin."""
    cmap = plt.get_cmap("jet")
    ch = np.squeeze(ch.astype("float32"))
    vmin = vmin if vmin is not None else ch.min()
    vmax = vmax if vmax is not None else ch.max()
    ch[ch > vmax] = vmax  # clamp value
    ch[ch < vmin] = vmin
    ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
    # take RGB from RGBA heat map
    ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
    return ch_cmap


####
def random_colors(N, bright=True):
    """Generate random colors.

    To get visually distinct colors, generate them in HSV space then
    convert to RGB.
    """
    brightness = 1.0 if bright else 0.7
    hsv = [(i / N, 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    random.shuffle(colors)
    return colors


####
def visualize_instances_map(
    input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
):
    """Overlays segmentation results on image as contours.

    Args:
        input_image: input image
        inst_map: instance mask with unique value for every object
        type_map: type mask with unique value for every class
        type_colour: a dict of {type: colour}, `type` is from 0-N
            and `colour` is a tuple of (R, G, B)
        line_thickness: line thickness of contours

    Returns:
        overlay: output image with segmentation overlay as contours
    """
    overlay = np.copy((input_image).astype(np.uint8))

    inst_list = list(np.unique(inst_map))  # get list of instances
    inst_list.remove(0)  # remove background

    inst_rng_colors = random_colors(len(inst_list))
    inst_rng_colors = np.array(inst_rng_colors) * 255
    inst_rng_colors = inst_rng_colors.astype(np.uint8)

    for inst_idx, inst_id in enumerate(inst_list):
        inst_map_mask = np.array(inst_map == inst_id, np.uint8)  # get single object
        y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
        y1 = y1 - 2 if y1 - 2 >= 0 else y1
        x1 = x1 - 2 if x1 - 2 >= 0 else x1
        x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
        y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
        inst_map_crop = inst_map_mask[y1:y2, x1:x2]
        contours_crop = cv2.findContours(
            inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        # only has 1 instance per map, no need to check #contour detected by opencv
        contours_crop = np.squeeze(
            contours_crop[0][0].astype("int32")
        )  # * opencv protocol format may break
        contours_crop += np.asarray([[x1, y1]])  # index correction
        if type_map is not None:
            type_map_crop = type_map[y1:y2, x1:x2]
            type_id = np.unique(type_map_crop).max()  # non-zero
            inst_colour = type_colour[type_id]
        else:
            inst_colour = (inst_rng_colors[inst_idx]).tolist()
        cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
    return overlay

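
# Usage sketch (illustrative; "overlay.png" is an arbitrary output path):
#
#     overlay = visualize_instances_map(image_rgb, inst_map, line_thickness=2)  # RGB uint8 in/out
#     cv2.imwrite("overlay.png", cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
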

def sliding_window_inference_large(inputs, block_size, min_overlap, context, roi_size, sw_batch_size, predictor, device):

    h, w = inputs.shape[0], inputs.shape[1]
    if h < 5000 or w < 5000:
        test_tensor = torch.from_numpy(np.expand_dims(inputs, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
        output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
        prob = output_prob[0][0].cpu().numpy()
        dist = output_dist[0].cpu().numpy()
        dist = np.transpose(dist, (1, 2, 0))
        dist = np.maximum(1e-3, dist)
        if h * w < 1500 * 1500:
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.55, nms_thresh=0.4, cut=True)
        else:
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)

        labels_out = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
    else:
        n = inputs.ndim
        axes = 'YXC'
        grid = (1, 1, 1)
        if np.isscalar(block_size):
            block_size = n * [block_size]
        if np.isscalar(min_overlap):
            min_overlap = n * [min_overlap]
        if np.isscalar(context):
            context = n * [context]
        shape_out = (inputs.shape[0], inputs.shape[1])
        labels_out = np.zeros(shape_out, dtype=np.uint64)
        # print(inputs.dtype)
        block_size[2] = inputs.shape[2]
        min_overlap[2] = context[2] = 0
        block_size = tuple(_grid_divisible(g, v, name='block_size', verbose=False) for v, g, a in zip(block_size, grid, axes))
        min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v, g, a in zip(min_overlap, grid, axes))
        context = tuple(_grid_divisible(g, v, name='context', verbose=False) for v, g, a in zip(context, grid, axes))
        print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)
        blocks = BlockND.cover(inputs.shape, axes, block_size, min_overlap, context)
        label_offset = 1
        blocks = tqdm(blocks)
        for block in blocks:
            image = block.read(inputs, axes=axes)
            test_tensor = torch.from_numpy(np.expand_dims(image, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
            output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
            prob = output_prob[0][0].cpu().numpy()
            dist = output_dist[0].cpu().numpy()
            dist = np.transpose(dist, (1, 2, 0))
            dist = np.maximum(1e-3, dist)
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)

            coord = dist_to_coord(disti, points)
            polys = dict(coord=coord, points=points, prob=probi)
            labels = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
            labels = block.crop_context(labels, axes='YX')
            labels, polys = block.filter_objects(labels, polys, axes='YX')
            labels = relabel_sequential(labels, label_offset)[0]
            if labels_out is not None:
                block.write(labels_out, labels, axes='YX')
            # for k, v in polys.items():
            #     polys_all.setdefault(k, []).append(v)
            label_offset += len(polys['prob'])
            del labels
        # polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k, v in polys_all.items()}
    return labels_out

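
# Usage sketch (illustrative; the block/ROI sizes below are assumptions, not values fixed by
# this module). `predictor` must return (dist, prob) for an NCHW float tensor, as consumed by
# sliding_window_inference() defined below:
#
#     labels = sliding_window_inference_large(
#         img,                                  # HxWx3 numpy array
#         block_size=2048, min_overlap=128, context=128,
#         roi_size=(512, 512), sw_batch_size=4,
#         predictor=model, device="cuda",
#     )
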

def sliding_window_inference(
    inputs: torch.Tensor,
    roi_size: Union[Sequence[int], int],
    sw_batch_size: int,
    predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
    overlap: float = 0.25,
    mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    sigma_scale: Union[Sequence[float], float] = 0.125,
    padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
    cval: float = 0.0,
    sw_device: Union[torch.device, str, None] = None,
    device: Union[torch.device, str, None] = None,
    progress: bool = False,
    roi_weight_map: Union[torch.Tensor, None] = None,
    *args: Any,
    **kwargs: Any,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
    """
    Sliding window inference on `inputs` with `predictor`.

    The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
    Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
    e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
    could be ([128,64,256], [64,32,128]).
    In this case, the parameters `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
    an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
    so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).

    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            If the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            the outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameters `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Note:
        - input must be channel-first and have a batch dim, supports N-D sliding window.

    """
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
    # print('inputs', inputs.shape)
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
        except BaseException as e:
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]  # type: ignore
    # handle non-positive weights
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)

    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each patch
    for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        unravel_slice = [
            [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat(
            [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
        ).to(sw_device)
        seg_prob_out = predictor(window_data, *args, **kwargs)  # batched patch segmentation
        # print('seg_prob_out', seg_prob_out[0].shape)
        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out,)
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
            ):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count for overlapping parts
                output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                _initialized_ss += 1

            # resizing the importance_map
            resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)

            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
                importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
                # store results and weights
                output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
                count_map_list[ss][original_idx_zoom] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
                )

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)

    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn("Sliding window inference results contain NaN or Inf.")

        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
        ]

        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
                int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[final_slicing]

    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    final_output = final_output[0] if is_tensor_output else final_output

    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output, inputs, device=device)[0]  # type: ignore
    return final_output


def _get_scan_interval(
    image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
) -> Tuple[int, ...]:
    """
    Compute scan interval according to the image size, roi size and overlap.
    Scan interval will be `int((1 - overlap) * roi_size)`; if the interval is 0,
    use 1 instead to make sure the sliding window works.

    """
    if len(image_size) != num_spatial_dims:
        raise ValueError("image coord different from spatial dims.")
    if len(roi_size) != num_spatial_dims:
        raise ValueError("roi coord different from spatial dims.")

    scan_interval = []
    for i in range(num_spatial_dims):
        if roi_size[i] == image_size[i]:
            scan_interval.append(int(roi_size[i]))
        else:
            interval = int(roi_size[i] * (1 - overlap))
            scan_interval.append(interval if interval > 0 else 1)
    return tuple(scan_interval)
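
A quick worked check of the scan-interval arithmetic above (illustrative values, not part of the upload): with image_size=(512, 512), roi_size=(256, 256) and overlap=0.25, each interval is int(256 * (1 - 0.25)) = 192, so windows start every 192 pixels along both axes; when an ROI dimension equals the image dimension, the interval falls back to the full ROI size.

    _get_scan_interval((512, 512), (256, 256), num_spatial_dims=2, overlap=0.25)  # -> (192, 192)
    _get_scan_interval((256, 256), (256, 256), num_spatial_dims=2, overlap=0.25)  # -> (256, 256)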