Lewislou commited on
Commit
991881f
1 Parent(s): 7545ee8

Upload 40 files

Browse files
Files changed (40) hide show
  1. README.md +105 -9
  2. classifiers.py +261 -0
  3. config.json +25 -0
  4. model.pt +3 -0
  5. models/__init__.py +10 -0
  6. models/convnext.py +220 -0
  7. models/flexible_unet.py +312 -0
  8. models/flexible_unet_convnext.py +447 -0
  9. overlay.py +116 -0
  10. pytorch_model.bin +3 -0
  11. requirements.txt +37 -0
  12. sribd_cellseg_models.py +100 -0
  13. stardist_pkg/__init__.py +26 -0
  14. stardist_pkg/__pycache__/__init__.cpython-37.pyc +0 -0
  15. stardist_pkg/__pycache__/big.cpython-37.pyc +0 -0
  16. stardist_pkg/__pycache__/bioimageio_utils.cpython-37.pyc +0 -0
  17. stardist_pkg/__pycache__/matching.cpython-37.pyc +0 -0
  18. stardist_pkg/__pycache__/nms.cpython-37.pyc +0 -0
  19. stardist_pkg/__pycache__/sample_patches.cpython-37.pyc +0 -0
  20. stardist_pkg/__pycache__/utils.cpython-37.pyc +0 -0
  21. stardist_pkg/__pycache__/version.cpython-37.pyc +0 -0
  22. stardist_pkg/big.py +601 -0
  23. stardist_pkg/bioimageio_utils.py +472 -0
  24. stardist_pkg/geometry/__init__.py +9 -0
  25. stardist_pkg/geometry/__pycache__/__init__.cpython-37.pyc +0 -0
  26. stardist_pkg/geometry/__pycache__/geom2d.cpython-37.pyc +0 -0
  27. stardist_pkg/geometry/__pycache__/geom3d.cpython-37.pyc +0 -0
  28. stardist_pkg/geometry/geom2d.py +212 -0
  29. stardist_pkg/kernels/stardist2d.cl +51 -0
  30. stardist_pkg/kernels/stardist3d.cl +63 -0
  31. stardist_pkg/matching.py +483 -0
  32. stardist_pkg/models/__init__.py +27 -0
  33. stardist_pkg/models/base.py +1196 -0
  34. stardist_pkg/models/model2d.py +570 -0
  35. stardist_pkg/nms.py +387 -0
  36. stardist_pkg/rays3d.py +373 -0
  37. stardist_pkg/sample_patches.py +65 -0
  38. stardist_pkg/utils.py +394 -0
  39. stardist_pkg/version.py +1 -0
  40. utils_modify.py +743 -0
README.md CHANGED
@@ -1,13 +1,109 @@
1
  ---
2
- title: Lewislou Cell Seg Sribd
3
- emoji: ⚡
4
- colorFrom: blue
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 3.38.0
8
- app_file: app.py
9
- pinned: false
10
  license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
 
 
 
 
 
 
 
 
2
  license: apache-2.0
3
+ language:
4
+ - en
5
+ metrics:
6
+ - f1
7
+ tags:
8
+ - cell segmentation
9
+ - stardist
10
+ - hover-net
11
+ library_name: transformers
12
+ pipeline_tag: image-segmentation
13
+ datasets:
14
+ - Lewislou/cell_samples
15
  ---
16
 
17
+
18
+ # Model Card for cell-seg-sribd
19
+
20
+ <!-- Provide a quick summary of what the model is/does. -->
21
+
22
+ This repository provides the solution of team Sribd-med for NeurIPS-CellSeg Challenge. The details of our method are described in our paper [Multi-stream Cell Segmentation with Low-level Cues for Multi-modality Images]. Some parts of the codes are from the baseline codes of the NeurIPS-CellSeg-Baseline repository,
23
+
24
+ You can reproduce our method as follows step by step:
25
+
26
+
27
+ ### How to Get Started with the Model
28
+
29
+ Install requirements by python -m pip install -r requirements.txt
30
+
31
+ ## Training Details
32
+
33
+ ### Training Data
34
+
35
+ The competition training and tuning data can be downloaded from https://neurips22-cellseg.grand-challenge.org/dataset/ Besides, you can download three publiced data from the following link: Cellpose: https://www.cellpose.org/dataset Omnipose: http://www.cellpose.org/dataset_omnipose Sartorius: https://www.kaggle.com/competitions/sartorius-cell-instance-segmentation/overview
36
+
37
+ ## Environments and Requirements:
38
+ Install requirements by
39
+
40
+ ```shell
41
+ python -m pip install -r requirements.txt
42
+ ```
43
+
44
+ ### How to use
45
+
46
+ Here is how to use this model:
47
+
48
+ ```python
49
+
50
+
51
+ from skimage import io, segmentation, morphology, measure, exposure
52
+ from sribd_cellseg_models import MultiStreamCellSegModel,ModelConfig
53
+ import numpy as np
54
+ import tifffile as tif
55
+ import requests
56
+ import torch
57
+ from PIL import Image
58
+ from overlay import visualize_instances_map
59
+ import cv2
60
+ img_name = 'test_images/cell_00551.tiff'
61
+
62
+ def normalize_channel(img, lower=1, upper=99):
63
+ non_zero_vals = img[np.nonzero(img)]
64
+ percentiles = np.percentile(non_zero_vals, [lower, upper])
65
+ if percentiles[1] - percentiles[0] > 0.001:
66
+ img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
67
+ else:
68
+ img_norm = img
69
+ return img_norm.astype(np.uint8)
70
+ if img_name.endswith('.tif') or img_name.endswith('.tiff'):
71
+ img_data = tif.imread(img_name)
72
+ else:
73
+ img_data = io.imread(img_name)
74
+ # normalize image data
75
+ if len(img_data.shape) == 2:
76
+ img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
77
+ elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
78
+ img_data = img_data[:,:, :3]
79
+ else:
80
+ pass
81
+ pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
82
+ for i in range(3):
83
+ img_channel_i = img_data[:,:,i]
84
+ if len(img_channel_i[np.nonzero(img_channel_i)])>0:
85
+ pre_img_data[:,:,i] = normalize_channel(img_channel_i, lower=1, upper=99)
86
+ #dummy_input = np.zeros((512,512,3)).astype(np.uint8)
87
+ my_model = MultiStreamCellSegModel.from_pretrained("Lewislou/cellseg_sribd")
88
+ checkpoints = torch.load('model.pt')
89
+ my_model.__init__(ModelConfig())
90
+ my_model.load_checkpoints(checkpoints)
91
+ with torch.no_grad():
92
+ output = my_model(pre_img_data)
93
+ overlay = visualize_instances_map(pre_img_data,star_label)
94
+ cv2.imwrite('prediction.png', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
95
+
96
+
97
+ ```
98
+
99
+ ## Citation
100
+ If any part of this code is used, please acknowledge it appropriately and cite the paper:
101
+ ```bibtex
102
+ @misc{
103
+ lou2022multistream,
104
+ title={Multi-stream Cell Segmentation with Low-level Cues for Multi-modality Images},
105
+ author={WEI LOU and Xinyi Yu and Chenyu Liu and Xiang Wan and Guanbin Li and Siqi Liu and Haofeng Li},
106
+ year={2022},
107
+ url={https://openreview.net/forum?id=G24BybwKe9}
108
+ }
109
+ ```
classifiers.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import Any, Callable, List, Optional, Type, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+
8
+ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
9
+ """3x3 convolution with padding"""
10
+ return nn.Conv2d(
11
+ in_planes,
12
+ out_planes,
13
+ kernel_size=3,
14
+ stride=stride,
15
+ padding=dilation,
16
+ groups=groups,
17
+ bias=False,
18
+ dilation=dilation,
19
+ )
20
+
21
+
22
+ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
23
+ """1x1 convolution"""
24
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
25
+
26
+
27
+ class BasicBlock(nn.Module):
28
+ expansion: int = 1
29
+
30
+ def __init__(
31
+ self,
32
+ inplanes: int,
33
+ planes: int,
34
+ stride: int = 1,
35
+ downsample: Optional[nn.Module] = None,
36
+ groups: int = 1,
37
+ base_width: int = 64,
38
+ dilation: int = 1,
39
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
40
+ ) -> None:
41
+ super().__init__()
42
+ if norm_layer is None:
43
+ norm_layer = nn.BatchNorm2d
44
+ if groups != 1 or base_width != 64:
45
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
46
+ if dilation > 1:
47
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
48
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
49
+ self.conv1 = conv3x3(inplanes, planes, stride)
50
+ self.bn1 = norm_layer(planes)
51
+ self.relu = nn.ReLU(inplace=True)
52
+ self.conv2 = conv3x3(planes, planes)
53
+ self.bn2 = norm_layer(planes)
54
+ self.downsample = downsample
55
+ self.stride = stride
56
+
57
+ def forward(self, x: Tensor) -> Tensor:
58
+ identity = x
59
+
60
+ out = self.conv1(x)
61
+ out = self.bn1(out)
62
+ out = self.relu(out)
63
+
64
+ out = self.conv2(out)
65
+ out = self.bn2(out)
66
+
67
+ if self.downsample is not None:
68
+ identity = self.downsample(x)
69
+
70
+ out += identity
71
+ out = self.relu(out)
72
+
73
+ return out
74
+
75
+ class Bottleneck(nn.Module):
76
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
77
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
78
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
79
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
80
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
81
+
82
+ expansion: int = 4
83
+
84
+ def __init__(
85
+ self,
86
+ inplanes: int,
87
+ planes: int,
88
+ stride: int = 1,
89
+ downsample: Optional[nn.Module] = None,
90
+ groups: int = 1,
91
+ base_width: int = 64,
92
+ dilation: int = 1,
93
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
94
+ ) -> None:
95
+ super().__init__()
96
+ if norm_layer is None:
97
+ norm_layer = nn.BatchNorm2d
98
+ width = int(planes * (base_width / 64.0)) * groups
99
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
100
+ self.conv1 = conv1x1(inplanes, width)
101
+ self.bn1 = norm_layer(width)
102
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
103
+ self.bn2 = norm_layer(width)
104
+ self.conv3 = conv1x1(width, planes * self.expansion)
105
+ self.bn3 = norm_layer(planes * self.expansion)
106
+ self.relu = nn.ReLU(inplace=True)
107
+ self.downsample = downsample
108
+ self.stride = stride
109
+
110
+ def forward(self, x: Tensor) -> Tensor:
111
+ identity = x
112
+
113
+ out = self.conv1(x)
114
+ out = self.bn1(out)
115
+ out = self.relu(out)
116
+
117
+ out = self.conv2(out)
118
+ out = self.bn2(out)
119
+ out = self.relu(out)
120
+
121
+ out = self.conv3(out)
122
+ out = self.bn3(out)
123
+
124
+ if self.downsample is not None:
125
+ identity = self.downsample(x)
126
+
127
+ out += identity
128
+ out = self.relu(out)
129
+
130
+ return out
131
+
132
+ class ResNet(nn.Module):
133
+ def __init__(
134
+ self,
135
+ block: Type[Union[BasicBlock, Bottleneck]],
136
+ layers: List[int],
137
+ num_classes: int = 1000,
138
+ zero_init_residual: bool = False,
139
+ groups: int = 1,
140
+ width_per_group: int = 64,
141
+ replace_stride_with_dilation: Optional[List[bool]] = None,
142
+ norm_layer: Optional[Callable[..., nn.Module]] = None,
143
+ ) -> None:
144
+ super().__init__()
145
+ # _log_api_usage_once(self)
146
+ if norm_layer is None:
147
+ norm_layer = nn.BatchNorm2d
148
+ self._norm_layer = norm_layer
149
+
150
+ self.inplanes = 64
151
+ self.dilation = 1
152
+ if replace_stride_with_dilation is None:
153
+ # each element in the tuple indicates if we should replace
154
+ # the 2x2 stride with a dilated convolution instead
155
+ replace_stride_with_dilation = [False, False, False]
156
+ if len(replace_stride_with_dilation) != 3:
157
+ raise ValueError(
158
+ "replace_stride_with_dilation should be None "
159
+ f"or a 3-element tuple, got {replace_stride_with_dilation}"
160
+ )
161
+ self.groups = groups
162
+ self.base_width = width_per_group
163
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
164
+ self.bn1 = norm_layer(self.inplanes)
165
+ self.relu = nn.ReLU(inplace=True)
166
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
167
+ self.layer1 = self._make_layer(block, 64, layers[0])
168
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
169
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
170
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
171
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
172
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
173
+
174
+ for m in self.modules():
175
+ if isinstance(m, nn.Conv2d):
176
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
177
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
178
+ nn.init.constant_(m.weight, 1)
179
+ nn.init.constant_(m.bias, 0)
180
+
181
+ # Zero-initialize the last BN in each residual branch,
182
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
183
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
184
+ if zero_init_residual:
185
+ for m in self.modules():
186
+ if isinstance(m, Bottleneck) and m.bn3.weight is not None:
187
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
188
+ elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
189
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
190
+
191
+ def _make_layer(
192
+ self,
193
+ block: Type[Union[BasicBlock, Bottleneck]],
194
+ planes: int,
195
+ blocks: int,
196
+ stride: int = 1,
197
+ dilate: bool = False,
198
+ ) -> nn.Sequential:
199
+ norm_layer = self._norm_layer
200
+ downsample = None
201
+ previous_dilation = self.dilation
202
+ if dilate:
203
+ self.dilation *= stride
204
+ stride = 1
205
+ if stride != 1 or self.inplanes != planes * block.expansion:
206
+ downsample = nn.Sequential(
207
+ conv1x1(self.inplanes, planes * block.expansion, stride),
208
+ norm_layer(planes * block.expansion),
209
+ )
210
+
211
+ layers = []
212
+ layers.append(
213
+ block(
214
+ self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
215
+ )
216
+ )
217
+ self.inplanes = planes * block.expansion
218
+ for _ in range(1, blocks):
219
+ layers.append(
220
+ block(
221
+ self.inplanes,
222
+ planes,
223
+ groups=self.groups,
224
+ base_width=self.base_width,
225
+ dilation=self.dilation,
226
+ norm_layer=norm_layer,
227
+ )
228
+ )
229
+
230
+ return nn.Sequential(*layers)
231
+
232
+ def _forward_impl(self, x: Tensor) -> Tensor:
233
+ # See note [TorchScript super()]
234
+ x = self.conv1(x)
235
+ x = self.bn1(x)
236
+ x = self.relu(x)
237
+ x = self.maxpool(x)
238
+
239
+ x = self.layer1(x)
240
+ x = self.layer2(x)
241
+ x = self.layer3(x)
242
+ x = self.layer4(x)
243
+
244
+ x = self.avgpool(x)
245
+ x = torch.flatten(x, 1)
246
+ x = self.fc(x)
247
+
248
+ return x
249
+
250
+ def forward(self, x: Tensor) -> Tensor:
251
+ return self._forward_impl(x)
252
+
253
+ def resnet18(weights=None):
254
+ # weights: path
255
+ model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=4)
256
+ if weights is not None:
257
+ model.load_state_dict(torch.load(weights))
258
+ return model
259
+
260
+ def resnet10():
261
+ return ResNet(BasicBlock, [1, 1, 1, 1], num_classes=4)
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "MultiStreamCellSegModel"
4
+ ],
5
+ "block_size": 2048,
6
+ "context": 128,
7
+ "device": "cpu",
8
+ "input_channels": 3,
9
+ "ksize": 15,
10
+ "min_overlap": 128,
11
+ "model_type": "cell_sribd",
12
+ "n_rays": 32,
13
+ "np_thres": 0.6,
14
+ "num_classes": 4,
15
+ "obj_size_thres": 100,
16
+ "overall_thres": 0.4,
17
+ "overlap": 0.5,
18
+ "roi_size": [
19
+ 512,
20
+ 512
21
+ ],
22
+ "sw_batch_size": 4,
23
+ "torch_dtype": "float32",
24
+ "transformers_version": "4.27.1"
25
+ }
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:460f2c3a9168220ef03404df983b923c7f59d6db873536cd44d9e4b7e4354f6c
3
+ size 135
models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Sun Mar 20 14:23:55 2022
5
+
6
+ @author: jma
7
+ """
8
+
9
+ #from .unetr2d import UNETR2D
10
+ #from .swin_unetr import SwinUNETR
models/convnext.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from functools import partial
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from timm.models.layers import trunc_normal_, DropPath
13
+ from timm.models.registry import register_model
14
+ from monai.networks.layers.factories import Act, Conv, Pad, Pool
15
+ from monai.networks.layers.utils import get_norm_layer
16
+ from monai.utils.module import look_up_option
17
+ from typing import List, NamedTuple, Optional, Tuple, Type, Union
18
+ class Block(nn.Module):
19
+ r""" ConvNeXt Block. There are two equivalent implementations:
20
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
21
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
22
+ We use (2) as we find it slightly faster in PyTorch
23
+
24
+ Args:
25
+ dim (int): Number of input channels.
26
+ drop_path (float): Stochastic depth rate. Default: 0.0
27
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
28
+ """
29
+ def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
30
+ super().__init__()
31
+ self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
32
+ self.norm = LayerNorm(dim, eps=1e-6)
33
+ self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
34
+ self.act = nn.GELU()
35
+ self.pwconv2 = nn.Linear(4 * dim, dim)
36
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
37
+ requires_grad=True) if layer_scale_init_value > 0 else None
38
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
39
+
40
+ def forward(self, x):
41
+ input = x
42
+ x = self.dwconv(x)
43
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
44
+ x = self.norm(x)
45
+ x = self.pwconv1(x)
46
+ x = self.act(x)
47
+ x = self.pwconv2(x)
48
+ if self.gamma is not None:
49
+ x = self.gamma * x
50
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
51
+
52
+ x = input + self.drop_path(x)
53
+ return x
54
+
55
+ class ConvNeXt(nn.Module):
56
+ r""" ConvNeXt
57
+ A PyTorch impl of : `A ConvNet for the 2020s` -
58
+ https://arxiv.org/pdf/2201.03545.pdf
59
+
60
+ Args:
61
+ in_chans (int): Number of input image channels. Default: 3
62
+ num_classes (int): Number of classes for classification head. Default: 1000
63
+ depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
64
+ dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
65
+ drop_path_rate (float): Stochastic depth rate. Default: 0.
66
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
67
+ head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
68
+ """
69
+ def __init__(self, in_chans=3, num_classes=21841,
70
+ depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
71
+ layer_scale_init_value=1e-6, head_init_scale=1., out_indices=[0, 1, 2, 3],
72
+ ):
73
+ super().__init__()
74
+ # conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv["conv", 2]
75
+ # self._conv_stem = conv_type(self.in_channels, self.in_channels, kernel_size=3, stride=stride, bias=False)
76
+ # self._conv_stem_padding = _make_same_padder(self._conv_stem, current_image_size)
77
+
78
+ self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
79
+ stem = nn.Sequential(
80
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
81
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
82
+ )
83
+ self.downsample_layers.append(stem)
84
+ for i in range(3):
85
+ downsample_layer = nn.Sequential(
86
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
87
+ nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
88
+ )
89
+ self.downsample_layers.append(downsample_layer)
90
+
91
+ self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
92
+ dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
93
+ cur = 0
94
+ for i in range(4):
95
+ stage = nn.Sequential(
96
+ *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
97
+ layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
98
+ )
99
+ self.stages.append(stage)
100
+ cur += depths[i]
101
+
102
+
103
+ self.out_indices = out_indices
104
+
105
+ norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
106
+ for i_layer in range(4):
107
+ layer = norm_layer(dims[i_layer])
108
+ layer_name = f'norm{i_layer}'
109
+ self.add_module(layer_name, layer)
110
+ self.apply(self._init_weights)
111
+
112
+
113
+ def _init_weights(self, m):
114
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
115
+ trunc_normal_(m.weight, std=.02)
116
+ nn.init.constant_(m.bias, 0)
117
+
118
+ def forward_features(self, x):
119
+ outs = []
120
+
121
+ for i in range(4):
122
+ x = self.downsample_layers[i](x)
123
+ x = self.stages[i](x)
124
+ if i in self.out_indices:
125
+ norm_layer = getattr(self, f'norm{i}')
126
+ x_out = norm_layer(x)
127
+
128
+ outs.append(x_out)
129
+
130
+ return tuple(outs)
131
+
132
+ def forward(self, x):
133
+ x = self.forward_features(x)
134
+
135
+ return x
136
+
137
+ class LayerNorm(nn.Module):
138
+ r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
139
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
140
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
141
+ with shape (batch_size, channels, height, width).
142
+ """
143
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
144
+ super().__init__()
145
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
146
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
147
+ self.eps = eps
148
+ self.data_format = data_format
149
+ if self.data_format not in ["channels_last", "channels_first"]:
150
+ raise NotImplementedError
151
+ self.normalized_shape = (normalized_shape, )
152
+
153
+ def forward(self, x):
154
+ if self.data_format == "channels_last":
155
+ return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
156
+ elif self.data_format == "channels_first":
157
+ u = x.mean(1, keepdim=True)
158
+ s = (x - u).pow(2).mean(1, keepdim=True)
159
+ x = (x - u) / torch.sqrt(s + self.eps)
160
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
161
+ return x
162
+
163
+
164
+ model_urls = {
165
+ "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
166
+ "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
167
+ "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
168
+ "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
169
+ "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
170
+ "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
171
+ "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
172
+ "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
173
+ "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
174
+ }
175
+
176
+ @register_model
177
+ def convnext_tiny(pretrained=False,in_22k=False, **kwargs):
178
+ model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
179
+ if pretrained:
180
+ url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
181
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
182
+ model.load_state_dict(checkpoint["model"])
183
+ return model
184
+
185
+ @register_model
186
+ def convnext_small(pretrained=False,in_22k=False, **kwargs):
187
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
188
+ if pretrained:
189
+ url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
190
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
191
+ model.load_state_dict(checkpoint["model"], strict=False)
192
+ return model
193
+
194
+ @register_model
195
+ def convnext_base(pretrained=False, in_22k=False, **kwargs):
196
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
197
+ if pretrained:
198
+ url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
199
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
200
+ model.load_state_dict(checkpoint["model"], strict=False)
201
+ return model
202
+
203
+ @register_model
204
+ def convnext_large(pretrained=False, in_22k=False, **kwargs):
205
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
206
+ if pretrained:
207
+ url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
208
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
209
+ model.load_state_dict(checkpoint["model"])
210
+ return model
211
+
212
+ @register_model
213
+ def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
214
+ model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
215
+ if pretrained:
216
+ assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
217
+ url = model_urls['convnext_xlarge_22k']
218
+ checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
219
+ model.load_state_dict(checkpoint["model"])
220
+ return model
models/flexible_unet.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from typing import List, Optional, Sequence, Tuple, Union
13
+
14
+ import torch
15
+ from torch import nn
16
+
17
+ from monai.networks.blocks import UpSample
18
+ from monai.networks.layers.factories import Conv
19
+ from monai.networks.layers.utils import get_act_layer
20
+ from monai.networks.nets import EfficientNetBNFeatures
21
+ from monai.networks.nets.basic_unet import UpCat
22
+ from monai.utils import InterpolateMode
23
+
24
+ __all__ = ["FlexibleUNet"]
25
+
26
+ encoder_feature_channel = {
27
+ "efficientnet-b0": (16, 24, 40, 112, 320),
28
+ "efficientnet-b1": (16, 24, 40, 112, 320),
29
+ "efficientnet-b2": (16, 24, 48, 120, 352),
30
+ "efficientnet-b3": (24, 32, 48, 136, 384),
31
+ "efficientnet-b4": (24, 32, 56, 160, 448),
32
+ "efficientnet-b5": (24, 40, 64, 176, 512),
33
+ "efficientnet-b6": (32, 40, 72, 200, 576),
34
+ "efficientnet-b7": (32, 48, 80, 224, 640),
35
+ "efficientnet-b8": (32, 56, 88, 248, 704),
36
+ "efficientnet-l2": (72, 104, 176, 480, 1376),
37
+ }
38
+
39
+
40
+ def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple:
41
+ """
42
+ Get the encoder output channels by given backbone name.
43
+
44
+ Args:
45
+ backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b7].
46
+ in_channels: channel of input tensor, default to 3.
47
+
48
+ Returns:
49
+ A tuple of output feature map channels' length .
50
+ """
51
+ encoder_channel_tuple = encoder_feature_channel[backbone]
52
+ encoder_channel_list = [in_channels] + list(encoder_channel_tuple)
53
+ encoder_channel = tuple(encoder_channel_list)
54
+ return encoder_channel
55
+
56
+
57
+ class UNetDecoder(nn.Module):
58
+ """
59
+ UNet Decoder.
60
+ This class refers to `segmentation_models.pytorch
61
+ <https://github.com/qubvel/segmentation_models.pytorch>`_.
62
+
63
+ Args:
64
+ spatial_dims: number of spatial dimensions.
65
+ encoder_channels: number of output channels for all feature maps in encoder.
66
+ `len(encoder_channels)` should be no less than 2.
67
+ decoder_channels: number of output channels for all feature maps in decoder.
68
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.
69
+ act: activation type and arguments.
70
+ norm: feature normalization type and arguments.
71
+ dropout: dropout ratio.
72
+ bias: whether to have a bias term in convolution blocks in this decoder.
73
+ upsample: upsampling mode, available options are
74
+ ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
75
+ pre_conv: a conv block applied before upsampling.
76
+ Only used in the "nontrainable" or "pixelshuffle" mode.
77
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
78
+ Only used in the "nontrainable" mode.
79
+ align_corners: set the align_corners parameter for upsample. Defaults to True.
80
+ Only used in the "nontrainable" mode.
81
+ is_pad: whether to pad upsampling features to fit the encoder spatial dims.
82
+
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ spatial_dims: int,
88
+ encoder_channels: Sequence[int],
89
+ decoder_channels: Sequence[int],
90
+ act: Union[str, tuple],
91
+ norm: Union[str, tuple],
92
+ dropout: Union[float, tuple],
93
+ bias: bool,
94
+ upsample: str,
95
+ pre_conv: Optional[str],
96
+ interp_mode: str,
97
+ align_corners: Optional[bool],
98
+ is_pad: bool,
99
+ ):
100
+
101
+ super().__init__()
102
+ if len(encoder_channels) < 2:
103
+ raise ValueError("the length of `encoder_channels` should be no less than 2.")
104
+ if len(decoder_channels) != len(encoder_channels) - 1:
105
+ raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.")
106
+
107
+ in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
108
+ skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
109
+ halves = [True] * (len(skip_channels) - 1)
110
+ halves.append(False)
111
+ blocks = []
112
+ for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
113
+ blocks.append(
114
+ UpCat(
115
+ spatial_dims=spatial_dims,
116
+ in_chns=in_chn,
117
+ cat_chns=skip_chn,
118
+ out_chns=out_chn,
119
+ act=act,
120
+ norm=norm,
121
+ dropout=dropout,
122
+ bias=bias,
123
+ upsample=upsample,
124
+ pre_conv=pre_conv,
125
+ interp_mode=interp_mode,
126
+ align_corners=align_corners,
127
+ halves=halve,
128
+ is_pad=is_pad,
129
+ )
130
+ )
131
+ self.blocks = nn.ModuleList(blocks)
132
+
133
+ def forward(self, features: List[torch.Tensor], skip_connect: int = 4):
134
+ skips = features[:-1][::-1]
135
+ features = features[1:][::-1]
136
+
137
+ x = features[0]
138
+ for i, block in enumerate(self.blocks):
139
+ if i < skip_connect:
140
+ skip = skips[i]
141
+ else:
142
+ skip = None
143
+ x = block(x, skip)
144
+
145
+ return x
146
+
147
+
148
+ class SegmentationHead(nn.Sequential):
149
+ """
150
+ Segmentation head.
151
+ This class refers to `segmentation_models.pytorch
152
+ <https://github.com/qubvel/segmentation_models.pytorch>`_.
153
+
154
+ Args:
155
+ spatial_dims: number of spatial dimensions.
156
+ in_channels: number of input channels for the block.
157
+ out_channels: number of output channels for the block.
158
+ kernel_size: kernel size for the conv layer.
159
+ act: activation type and arguments.
160
+ scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.
161
+
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ spatial_dims: int,
167
+ in_channels: int,
168
+ out_channels: int,
169
+ kernel_size: int = 3,
170
+ act: Optional[Union[Tuple, str]] = None,
171
+ scale_factor: float = 1.0,
172
+ ):
173
+
174
+ conv_layer = Conv[Conv.CONV, spatial_dims](
175
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
176
+ )
177
+ up_layer: nn.Module = nn.Identity()
178
+ if scale_factor > 1.0:
179
+ up_layer = UpSample(
180
+ spatial_dims=spatial_dims,
181
+ scale_factor=scale_factor,
182
+ mode="nontrainable",
183
+ pre_conv=None,
184
+ interp_mode=InterpolateMode.LINEAR,
185
+ )
186
+ if act is not None:
187
+ act_layer = get_act_layer(act)
188
+ else:
189
+ act_layer = nn.Identity()
190
+ super().__init__(conv_layer, up_layer, act_layer)
191
+
192
+
193
+ class FlexibleUNet(nn.Module):
194
+ """
195
+ A flexible implementation of UNet-like encoder-decoder architecture.
196
+ """
197
+
198
+ def __init__(
199
+ self,
200
+ in_channels: int,
201
+ out_channels: int,
202
+ backbone: str,
203
+ pretrained: bool = False,
204
+ decoder_channels: Tuple = (256, 128, 64, 32, 16),
205
+ spatial_dims: int = 2,
206
+ norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
207
+ act: Union[str, tuple] = ("relu", {"inplace": True}),
208
+ dropout: Union[float, tuple] = 0.0,
209
+ decoder_bias: bool = False,
210
+ upsample: str = "nontrainable",
211
+ interp_mode: str = "nearest",
212
+ is_pad: bool = True,
213
+ ) -> None:
214
+ """
215
+ A flexible implement of UNet, in which the backbone/encoder can be replaced with
216
+ any efficient network. Currently the input must have a 2 or 3 spatial dimension
217
+ and the spatial size of each dimension must be a multiple of 32 if is pad parameter
218
+ is False
219
+
220
+ Args:
221
+ in_channels: number of input channels.
222
+ out_channels: number of output channels.
223
+ backbone: name of backbones to initialize, only support efficientnet right now,
224
+ can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
225
+ pretrained: whether to initialize pretrained ImageNet weights, only available
226
+ for spatial_dims=2 and batch norm is used, default to False.
227
+ decoder_channels: number of output channels for all feature maps in decoder.
228
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
229
+ to (256, 128, 64, 32, 16).
230
+ spatial_dims: number of spatial dimensions, default to 2.
231
+ norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
232
+ "momentum": 0.1}).
233
+ act: activation type and arguments, default to ("relu", {"inplace": True}).
234
+ dropout: dropout ratio, default to 0.0.
235
+ decoder_bias: whether to have a bias term in decoder's convolution blocks.
236
+ upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``,
237
+ ``"nontrainable"``.
238
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
239
+ Only used in the "nontrainable" mode.
240
+ is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
241
+ If this parameter is set to "True", the spatial dim of network input can be arbitary
242
+ size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
243
+ """
244
+ super().__init__()
245
+
246
+ if backbone not in encoder_feature_channel:
247
+ raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")
248
+
249
+ if spatial_dims not in (2, 3):
250
+ raise ValueError("spatial_dims can only be 2 or 3.")
251
+
252
+ adv_prop = "ap" in backbone
253
+
254
+ self.backbone = backbone
255
+ self.spatial_dims = spatial_dims
256
+ model_name = backbone
257
+ encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
258
+ self.encoder = EfficientNetBNFeatures(
259
+ model_name=model_name,
260
+ pretrained=pretrained,
261
+ in_channels=in_channels,
262
+ spatial_dims=spatial_dims,
263
+ norm=norm,
264
+ adv_prop=adv_prop,
265
+ )
266
+ self.decoder = UNetDecoder(
267
+ spatial_dims=spatial_dims,
268
+ encoder_channels=encoder_channels,
269
+ decoder_channels=decoder_channels,
270
+ act=act,
271
+ norm=norm,
272
+ dropout=dropout,
273
+ bias=decoder_bias,
274
+ upsample=upsample,
275
+ interp_mode=interp_mode,
276
+ pre_conv=None,
277
+ align_corners=None,
278
+ is_pad=is_pad,
279
+ )
280
+ self.dist_head = SegmentationHead(
281
+ spatial_dims=spatial_dims,
282
+ in_channels=decoder_channels[-1],
283
+ out_channels=32,
284
+ kernel_size=1,
285
+ act='relu',
286
+ )
287
+ self.prob_head = SegmentationHead(
288
+ spatial_dims=spatial_dims,
289
+ in_channels=decoder_channels[-1],
290
+ out_channels=1,
291
+ kernel_size=1,
292
+ act='sigmoid',
293
+ )
294
+
295
+ def forward(self, inputs: torch.Tensor):
296
+ """
297
+ Do a typical encoder-decoder-header inference.
298
+
299
+ Args:
300
+ inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
301
+ N is defined by `dimensions`.
302
+
303
+ Returns:
304
+ A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
305
+
306
+ """
307
+ x = inputs
308
+ enc_out = self.encoder(x)
309
+ decoder_out = self.decoder(enc_out)
310
+ dist = self.dist_head(decoder_out)
311
+ prob = self.prob_head(decoder_out)
312
+ return dist,prob
models/flexible_unet_convnext.py ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from typing import List, Optional, Sequence, Tuple, Union
13
+
14
+ import torch
15
+ from torch import nn
16
+ from . import convnext
17
+ from monai.networks.blocks import UpSample
18
+ from monai.networks.layers.factories import Conv
19
+ from monai.networks.layers.utils import get_act_layer
20
+ from monai.networks.nets import EfficientNetBNFeatures
21
+ from monai.networks.nets.basic_unet import UpCat
22
+ from monai.utils import InterpolateMode
23
+
24
+ __all__ = ["FlexibleUNet"]
25
+
26
+ encoder_feature_channel = {
27
+ "efficientnet-b0": (16, 24, 40, 112, 320),
28
+ "efficientnet-b1": (16, 24, 40, 112, 320),
29
+ "efficientnet-b2": (16, 24, 48, 120, 352),
30
+ "efficientnet-b3": (24, 32, 48, 136, 384),
31
+ "efficientnet-b4": (24, 32, 56, 160, 448),
32
+ "efficientnet-b5": (24, 40, 64, 176, 512),
33
+ "efficientnet-b6": (32, 40, 72, 200, 576),
34
+ "efficientnet-b7": (32, 48, 80, 224, 640),
35
+ "efficientnet-b8": (32, 56, 88, 248, 704),
36
+ "efficientnet-l2": (72, 104, 176, 480, 1376),
37
+ "convnext_small": (96, 192, 384, 768),
38
+ "convnext_base": (128, 256, 512, 1024),
39
+ "van_b2": (64, 128, 320, 512),
40
+ "van_b1": (64, 128, 320, 512),
41
+ }
42
+
43
+
44
+ def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple:
45
+ """
46
+ Get the encoder output channels by given backbone name.
47
+
48
+ Args:
49
+ backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b7].
50
+ in_channels: channel of input tensor, default to 3.
51
+
52
+ Returns:
53
+ A tuple of output feature map channels' length .
54
+ """
55
+ encoder_channel_tuple = encoder_feature_channel[backbone]
56
+ encoder_channel_list = [in_channels] + list(encoder_channel_tuple)
57
+ encoder_channel = tuple(encoder_channel_list)
58
+ return encoder_channel
59
+
60
+
61
+ class UNetDecoder(nn.Module):
62
+ """
63
+ UNet Decoder.
64
+ This class refers to `segmentation_models.pytorch
65
+ <https://github.com/qubvel/segmentation_models.pytorch>`_.
66
+
67
+ Args:
68
+ spatial_dims: number of spatial dimensions.
69
+ encoder_channels: number of output channels for all feature maps in encoder.
70
+ `len(encoder_channels)` should be no less than 2.
71
+ decoder_channels: number of output channels for all feature maps in decoder.
72
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.
73
+ act: activation type and arguments.
74
+ norm: feature normalization type and arguments.
75
+ dropout: dropout ratio.
76
+ bias: whether to have a bias term in convolution blocks in this decoder.
77
+ upsample: upsampling mode, available options are
78
+ ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
79
+ pre_conv: a conv block applied before upsampling.
80
+ Only used in the "nontrainable" or "pixelshuffle" mode.
81
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
82
+ Only used in the "nontrainable" mode.
83
+ align_corners: set the align_corners parameter for upsample. Defaults to True.
84
+ Only used in the "nontrainable" mode.
85
+ is_pad: whether to pad upsampling features to fit the encoder spatial dims.
86
+
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ spatial_dims: int,
92
+ encoder_channels: Sequence[int],
93
+ decoder_channels: Sequence[int],
94
+ act: Union[str, tuple],
95
+ norm: Union[str, tuple],
96
+ dropout: Union[float, tuple],
97
+ bias: bool,
98
+ upsample: str,
99
+ pre_conv: Optional[str],
100
+ interp_mode: str,
101
+ align_corners: Optional[bool],
102
+ is_pad: bool,
103
+ ):
104
+
105
+ super().__init__()
106
+ if len(encoder_channels) < 2:
107
+ raise ValueError("the length of `encoder_channels` should be no less than 2.")
108
+ if len(decoder_channels) != len(encoder_channels) - 1:
109
+ raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.")
110
+
111
+ in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
112
+ skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
113
+ halves = [True] * (len(skip_channels) - 1)
114
+ halves.append(False)
115
+ blocks = []
116
+ for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
117
+ blocks.append(
118
+ UpCat(
119
+ spatial_dims=spatial_dims,
120
+ in_chns=in_chn,
121
+ cat_chns=skip_chn,
122
+ out_chns=out_chn,
123
+ act=act,
124
+ norm=norm,
125
+ dropout=dropout,
126
+ bias=bias,
127
+ upsample=upsample,
128
+ pre_conv=pre_conv,
129
+ interp_mode=interp_mode,
130
+ align_corners=align_corners,
131
+ halves=halve,
132
+ is_pad=is_pad,
133
+ )
134
+ )
135
+ self.blocks = nn.ModuleList(blocks)
136
+
137
+ def forward(self, features: List[torch.Tensor], skip_connect: int = 3):
138
+ skips = features[:-1][::-1]
139
+ features = features[1:][::-1]
140
+
141
+ x = features[0]
142
+ for i, block in enumerate(self.blocks):
143
+ if i < skip_connect:
144
+ skip = skips[i]
145
+ else:
146
+ skip = None
147
+ x = block(x, skip)
148
+
149
+ return x
150
+
151
+
152
+ class SegmentationHead(nn.Sequential):
153
+ """
154
+ Segmentation head.
155
+ This class refers to `segmentation_models.pytorch
156
+ <https://github.com/qubvel/segmentation_models.pytorch>`_.
157
+
158
+ Args:
159
+ spatial_dims: number of spatial dimensions.
160
+ in_channels: number of input channels for the block.
161
+ out_channels: number of output channels for the block.
162
+ kernel_size: kernel size for the conv layer.
163
+ act: activation type and arguments.
164
+ scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.
165
+
166
+ """
167
+
168
+ def __init__(
169
+ self,
170
+ spatial_dims: int,
171
+ in_channels: int,
172
+ out_channels: int,
173
+ kernel_size: int = 3,
174
+ act: Optional[Union[Tuple, str]] = None,
175
+ scale_factor: float = 1.0,
176
+ ):
177
+
178
+ conv_layer = Conv[Conv.CONV, spatial_dims](
179
+ in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
180
+ )
181
+ up_layer: nn.Module = nn.Identity()
182
+ # if scale_factor > 1.0:
183
+ # up_layer = UpSample(
184
+ # in_channels=out_channels,
185
+ # spatial_dims=spatial_dims,
186
+ # scale_factor=scale_factor,
187
+ # mode="deconv",
188
+ # pre_conv=None,
189
+ # interp_mode=InterpolateMode.LINEAR,
190
+ # )
191
+ if scale_factor > 1.0:
192
+ up_layer = UpSample(
193
+ spatial_dims=spatial_dims,
194
+ scale_factor=scale_factor,
195
+ mode="nontrainable",
196
+ pre_conv=None,
197
+ interp_mode=InterpolateMode.LINEAR,
198
+ )
199
+ if act is not None:
200
+ act_layer = get_act_layer(act)
201
+ else:
202
+ act_layer = nn.Identity()
203
+ super().__init__(conv_layer, up_layer, act_layer)
204
+
205
+
206
+ class FlexibleUNet_star(nn.Module):
207
+ """
208
+ A flexible implementation of UNet-like encoder-decoder architecture.
209
+ """
210
+
211
+ def __init__(
212
+ self,
213
+ in_channels: int,
214
+ out_channels: int,
215
+ backbone: str,
216
+ pretrained: bool = False,
217
+ decoder_channels: Tuple = (256, 128, 64, 32),
218
+ #decoder_channels: Tuple = (1024, 512, 256, 128),
219
+ spatial_dims: int = 2,
220
+ norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
221
+ act: Union[str, tuple] = ("relu", {"inplace": True}),
222
+ dropout: Union[float, tuple] = 0.0,
223
+ decoder_bias: bool = False,
224
+ upsample: str = "nontrainable",
225
+ interp_mode: str = "nearest",
226
+ is_pad: bool = True,
227
+ n_rays: int = 32,
228
+ prob_out_channels: int = 1,
229
+ ) -> None:
230
+ """
231
+ A flexible implement of UNet, in which the backbone/encoder can be replaced with
232
+ any efficient network. Currently the input must have a 2 or 3 spatial dimension
233
+ and the spatial size of each dimension must be a multiple of 32 if is pad parameter
234
+ is False
235
+
236
+ Args:
237
+ in_channels: number of input channels.
238
+ out_channels: number of output channels.
239
+ backbone: name of backbones to initialize, only support efficientnet right now,
240
+ can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
241
+ pretrained: whether to initialize pretrained ImageNet weights, only available
242
+ for spatial_dims=2 and batch norm is used, default to False.
243
+ decoder_channels: number of output channels for all feature maps in decoder.
244
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
245
+ to (256, 128, 64, 32, 16).
246
+ spatial_dims: number of spatial dimensions, default to 2.
247
+ norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
248
+ "momentum": 0.1}).
249
+ act: activation type and arguments, default to ("relu", {"inplace": True}).
250
+ dropout: dropout ratio, default to 0.0.
251
+ decoder_bias: whether to have a bias term in decoder's convolution blocks.
252
+ upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``,
253
+ ``"nontrainable"``.
254
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
255
+ Only used in the "nontrainable" mode.
256
+ is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
257
+ If this parameter is set to "True", the spatial dim of network input can be arbitary
258
+ size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
259
+ """
260
+ super().__init__()
261
+
262
+ if backbone not in encoder_feature_channel:
263
+ raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")
264
+
265
+ if spatial_dims not in (2, 3):
266
+ raise ValueError("spatial_dims can only be 2 or 3.")
267
+
268
+ adv_prop = "ap" in backbone
269
+
270
+ self.backbone = backbone
271
+ self.spatial_dims = spatial_dims
272
+ model_name = backbone
273
+ encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
274
+
275
+ self.encoder = convnext.convnext_small(pretrained=False,in_22k=True)
276
+
277
+ self.decoder = UNetDecoder(
278
+ spatial_dims=spatial_dims,
279
+ encoder_channels=encoder_channels,
280
+ decoder_channels=decoder_channels,
281
+ act=act,
282
+ norm=norm,
283
+ dropout=dropout,
284
+ bias=decoder_bias,
285
+ upsample=upsample,
286
+ interp_mode=interp_mode,
287
+ pre_conv=None,
288
+ align_corners=None,
289
+ is_pad=is_pad,
290
+ )
291
+ self.dist_head = SegmentationHead(
292
+ spatial_dims=spatial_dims,
293
+ in_channels=decoder_channels[-1],
294
+ out_channels=n_rays,
295
+ kernel_size=1,
296
+ act='relu',
297
+ scale_factor = 2,
298
+ )
299
+ self.prob_head = SegmentationHead(
300
+ spatial_dims=spatial_dims,
301
+ in_channels=decoder_channels[-1],
302
+ out_channels=prob_out_channels,
303
+ kernel_size=1,
304
+ act='sigmoid',
305
+ scale_factor = 2,
306
+ )
307
+
308
+ def forward(self, inputs: torch.Tensor):
309
+ """
310
+ Do a typical encoder-decoder-header inference.
311
+
312
+ Args:
313
+ inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
314
+ N is defined by `dimensions`.
315
+
316
+ Returns:
317
+ A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
318
+
319
+ """
320
+ x = inputs
321
+ enc_out = self.encoder(x)
322
+ decoder_out = self.decoder(enc_out)
323
+
324
+ dist = self.dist_head(decoder_out)
325
+ prob = self.prob_head(decoder_out)
326
+
327
+ return dist,prob
328
+
329
+
330
+
331
+ class FlexibleUNet_hv(nn.Module):
332
+ """
333
+ A flexible implementation of UNet-like encoder-decoder architecture.
334
+ """
335
+
336
+ def __init__(
337
+ self,
338
+ in_channels: int,
339
+ out_channels: int,
340
+ backbone: str,
341
+ pretrained: bool = False,
342
+ decoder_channels: Tuple = (1024, 512, 256, 128),
343
+ spatial_dims: int = 2,
344
+ norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
345
+ act: Union[str, tuple] = ("relu", {"inplace": True}),
346
+ dropout: Union[float, tuple] = 0.0,
347
+ decoder_bias: bool = False,
348
+ upsample: str = "nontrainable",
349
+ interp_mode: str = "nearest",
350
+ is_pad: bool = True,
351
+ n_rays: int = 32,
352
+ prob_out_channels: int = 1,
353
+ ) -> None:
354
+ """
355
+ A flexible implement of UNet, in which the backbone/encoder can be replaced with
356
+ any efficient network. Currently the input must have a 2 or 3 spatial dimension
357
+ and the spatial size of each dimension must be a multiple of 32 if is pad parameter
358
+ is False
359
+
360
+ Args:
361
+ in_channels: number of input channels.
362
+ out_channels: number of output channels.
363
+ backbone: name of backbones to initialize, only support efficientnet right now,
364
+ can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
365
+ pretrained: whether to initialize pretrained ImageNet weights, only available
366
+ for spatial_dims=2 and batch norm is used, default to False.
367
+ decoder_channels: number of output channels for all feature maps in decoder.
368
+ `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
369
+ to (256, 128, 64, 32, 16).
370
+ spatial_dims: number of spatial dimensions, default to 2.
371
+ norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
372
+ "momentum": 0.1}).
373
+ act: activation type and arguments, default to ("relu", {"inplace": True}).
374
+ dropout: dropout ratio, default to 0.0.
375
+ decoder_bias: whether to have a bias term in decoder's convolution blocks.
376
+ upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``,
377
+ ``"nontrainable"``.
378
+ interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
379
+ Only used in the "nontrainable" mode.
380
+ is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
381
+ If this parameter is set to "True", the spatial dim of network input can be arbitary
382
+ size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
383
+ """
384
+ super().__init__()
385
+
386
+ if backbone not in encoder_feature_channel:
387
+ raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")
388
+
389
+ if spatial_dims not in (2, 3):
390
+ raise ValueError("spatial_dims can only be 2 or 3.")
391
+
392
+ adv_prop = "ap" in backbone
393
+
394
+ self.backbone = backbone
395
+ self.spatial_dims = spatial_dims
396
+ model_name = backbone
397
+ encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
398
+ self.encoder = convnext.convnext_small(pretrained=False,in_22k=True)
399
+ self.decoder = UNetDecoder(
400
+ spatial_dims=spatial_dims,
401
+ encoder_channels=encoder_channels,
402
+ decoder_channels=decoder_channels,
403
+ act=act,
404
+ norm=norm,
405
+ dropout=dropout,
406
+ bias=decoder_bias,
407
+ upsample=upsample,
408
+ interp_mode=interp_mode,
409
+ pre_conv=None,
410
+ align_corners=None,
411
+ is_pad=is_pad,
412
+ )
413
+ self.dist_head = SegmentationHead(
414
+ spatial_dims=spatial_dims,
415
+ in_channels=decoder_channels[-1],
416
+ out_channels=n_rays,
417
+ kernel_size=1,
418
+ act=None,
419
+ scale_factor = 2,
420
+ )
421
+ self.prob_head = SegmentationHead(
422
+ spatial_dims=spatial_dims,
423
+ in_channels=decoder_channels[-1],
424
+ out_channels=prob_out_channels,
425
+ kernel_size=1,
426
+ act='sigmoid',
427
+ scale_factor = 2,
428
+ )
429
+
430
+ def forward(self, inputs: torch.Tensor):
431
+ """
432
+ Do a typical encoder-decoder-header inference.
433
+
434
+ Args:
435
+ inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
436
+ N is defined by `dimensions`.
437
+
438
+ Returns:
439
+ A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
440
+
441
+ """
442
+ x = inputs
443
+ enc_out = self.encoder(x)
444
+ decoder_out = self.decoder(enc_out)
445
+ dist = self.dist_head(decoder_out)
446
+ prob = self.prob_head(decoder_out)
447
+ return dist,prob
overlay.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ ###overlay
5
+ import cv2
6
+ import math
7
+ import random
8
+ import colorsys
9
+ import numpy as np
10
+ import itertools
11
+ import matplotlib.pyplot as plt
12
+ from matplotlib import cm
13
+ import os
14
+ import scipy.io as io
15
+ def get_bounding_box(img):
16
+ """Get bounding box coordinate information."""
17
+ rows = np.any(img, axis=1)
18
+ cols = np.any(img, axis=0)
19
+ rmin, rmax = np.where(rows)[0][[0, -1]]
20
+ cmin, cmax = np.where(cols)[0][[0, -1]]
21
+ # due to python indexing, need to add 1 to max
22
+ # else accessing will be 1px in the box, not out
23
+ rmax += 1
24
+ cmax += 1
25
+ return [rmin, rmax, cmin, cmax]
26
+ ####
27
+ def colorize(ch, vmin, vmax):
28
+ """Will clamp value value outside the provided range to vmax and vmin."""
29
+ cmap = plt.get_cmap("jet")
30
+ ch = np.squeeze(ch.astype("float32"))
31
+ vmin = vmin if vmin is not None else ch.min()
32
+ vmax = vmax if vmax is not None else ch.max()
33
+ ch[ch > vmax] = vmax # clamp value
34
+ ch[ch < vmin] = vmin
35
+ ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
36
+ # take RGB from RGBA heat map
37
+ ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
38
+ return ch_cmap
39
+
40
+
41
+ ####
42
+ def random_colors(N, bright=True):
43
+ """Generate random colors.
44
+
45
+ To get visually distinct colors, generate them in HSV space then
46
+ convert to RGB.
47
+ """
48
+ brightness = 1.0 if bright else 0.7
49
+ hsv = [(i / N, 1, brightness) for i in range(N)]
50
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
51
+ random.shuffle(colors)
52
+ return colors
53
+
54
+
55
+ ####
56
+ def visualize_instances_map(
57
+ input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
58
+ ):
59
+ """Overlays segmentation results on image as contours.
60
+
61
+ Args:
62
+ input_image: input image
63
+ inst_map: instance mask with unique value for every object
64
+ type_map: type mask with unique value for every class
65
+ type_colour: a dict of {type : colour} , `type` is from 0-N
66
+ and `colour` is a tuple of (R, G, B)
67
+ line_thickness: line thickness of contours
68
+
69
+ Returns:
70
+ overlay: output image with segmentation overlay as contours
71
+ """
72
+ overlay = np.copy((input_image).astype(np.uint8))
73
+
74
+ inst_list = list(np.unique(inst_map)) # get list of instances
75
+ inst_list.remove(0) # remove background
76
+
77
+ inst_rng_colors = random_colors(len(inst_list))
78
+ inst_rng_colors = np.array(inst_rng_colors) * 255
79
+ inst_rng_colors = inst_rng_colors.astype(np.uint8)
80
+
81
+ for inst_idx, inst_id in enumerate(inst_list):
82
+ inst_map_mask = np.array(inst_map == inst_id, np.uint8) # get single object
83
+ y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
84
+ y1 = y1 - 2 if y1 - 2 >= 0 else y1
85
+ x1 = x1 - 2 if x1 - 2 >= 0 else x1
86
+ x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
87
+ y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
88
+ inst_map_crop = inst_map_mask[y1:y2, x1:x2]
89
+ contours_crop = cv2.findContours(
90
+ inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
91
+ )
92
+ # only has 1 instance per map, no need to check #contour detected by opencv
93
+ #print(contours_crop)
94
+ contours_crop = np.squeeze(
95
+ contours_crop[0][0].astype("int32")
96
+ ) # * opencv protocol format may break
97
+
98
+ if len(contours_crop.shape) == 1:
99
+ contours_crop = contours_crop.reshape(1,-1)
100
+ #print(contours_crop.shape)
101
+ contours_crop += np.asarray([[x1, y1]]) # index correction
102
+ if type_map is not None:
103
+ type_map_crop = type_map[y1:y2, x1:x2]
104
+ type_id = np.unique(type_map_crop).max() # non-zero
105
+ inst_colour = type_colour[type_id]
106
+ else:
107
+ inst_colour = (inst_rng_colors[inst_idx]).tolist()
108
+ cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
109
+ return overlay
110
+
111
+
112
+ # In[ ]:
113
+
114
+
115
+
116
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d978b42e9e63e949f0dcd3685be14146b6c5b5bfb48f703bfbc308b4ac190b64
3
+ size 135
requirements.txt ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gputools==0.2.13
2
+ h5py==3.7.0
3
+ huggingface-hub==0.10.1
4
+ imagecodecs
5
+ imageio==2.22.2
6
+ importlib-metadata==5.0.0
7
+ kiwisolver==1.4.4
8
+ llvmlite==0.39.1
9
+ Mako==1.2.3
10
+ Markdown==3.4.1
11
+ MarkupSafe==2.1.1
12
+ matplotlib==3.6.1
13
+ mkl-fft==1.3.1
14
+ mkl-service==2.4.0
15
+ monai==1.0.0
16
+ networkx==2.8.7
17
+ numba==0.56.3
18
+ numexpr
19
+ numpy
20
+ oauthlib==3.2.2
21
+ opencv-python==4.6.0.66
22
+ packaging
23
+ pandas==1.4.4
24
+ Pillow==9.2.0
25
+ scikit-image==0.19.3
26
+ scipy==1.9.2
27
+ stardist==0.8.3
28
+ tensorboard==2.10.1
29
+ tensorboard-data-server==0.6.1
30
+ tensorboard-plugin-wit==1.8.1
31
+ tifffile==2022.10.10
32
+ timm==0.6.11
33
+ torch==1.12.1
34
+ torchaudio==0.12.1
35
+ torchvision==0.13.1
36
+ tqdm==4.64.1
37
+
sribd_cellseg_models.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ join = os.path.join
4
+ import argparse
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from collections import OrderedDict
9
+ from torchvision import datasets, models, transforms
10
+ from classifiers import resnet10, resnet18
11
+
12
+ from utils_modify import sliding_window_inference,sliding_window_inference_large,__proc_np_hv
13
+ from PIL import Image
14
+ import torch.nn.functional as F
15
+ from skimage import io, segmentation, morphology, measure, exposure
16
+ import tifffile as tif
17
+ from models.flexible_unet_convnext import FlexibleUNet_star,FlexibleUNet_hv
18
+ from transformers import PretrainedConfig
19
+ from typing import List
20
+ from transformers import PreTrainedModel
21
+ from huggingface_hub import PyTorchModelHubMixin
22
+ from torch import nn
23
+ class ModelConfig(PretrainedConfig):
24
+ model_type = "cell_sribd"
25
+ def __init__(
26
+ self,
27
+ version = 1,
28
+ input_channels: int = 3,
29
+ roi_size: int = 512,
30
+ overlap: float = 0.5,
31
+ device: str = 'cpu',
32
+ **kwargs,
33
+ ):
34
+
35
+ self.device = device
36
+ self.roi_size = (roi_size, roi_size)
37
+ self.input_channels = input_channels
38
+ self.overlap = overlap
39
+ self.np_thres, self.ksize, self.overall_thres, self.obj_size_thres = 0.6, 15, 0.4, 100
40
+ self.n_rays = 32
41
+ self.sw_batch_size = 4
42
+ self.num_classes= 4
43
+ self.block_size = 2048
44
+ self.min_overlap = 128
45
+ self.context = 128
46
+ super().__init__(**kwargs)
47
+
48
+
49
+ class MultiStreamCellSegModel(PreTrainedModel):
50
+ config_class = ModelConfig
51
+ #print(config.input_channels)
52
+ def __init__(self, config):
53
+ super().__init__(config)
54
+ #print(config.input_channels)
55
+ self.config = config
56
+ self.cls_model = resnet18()
57
+ self.model0 = FlexibleUNet_star(in_channels=config.input_channels,out_channels=config.n_rays+1,backbone='convnext_small',pretrained=False,n_rays=config.n_rays,prob_out_channels=1,)
58
+ self.model1 = FlexibleUNet_star(in_channels=config.input_channels,out_channels=config.n_rays+1,backbone='convnext_small',pretrained=False,n_rays=config.n_rays,prob_out_channels=1,)
59
+ self.model2 = FlexibleUNet_star(in_channels=config.input_channels,out_channels=config.n_rays+1,backbone='convnext_small',pretrained=False,n_rays=config.n_rays,prob_out_channels=1,)
60
+ self.model3 = FlexibleUNet_hv(in_channels=config.input_channels,out_channels=2+2,backbone='convnext_small',pretrained=False,n_rays=2,prob_out_channels=2,)
61
+ self.preprocess=transforms.Compose([
62
+ transforms.Resize(size=256),
63
+ transforms.CenterCrop(size=224),
64
+ transforms.ToTensor(),
65
+ transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
66
+ def load_checkpoints(self,checkpoints):
67
+ self.cls_model.load_state_dict(checkpoints['cls_model'])
68
+ self.model0.load_state_dict(checkpoints['class1_model']['model_state_dict'])
69
+ self.model1.load_state_dict(checkpoints['class2_model']['model_state_dict'])
70
+ self.model2.load_state_dict(checkpoints['class3_model']['model_state_dict'])
71
+ self.model3.load_state_dict(checkpoints['class4_model'])
72
+
73
+ def forward(self, pre_img_data):
74
+ inputs=self.preprocess(Image.fromarray(pre_img_data)).unsqueeze(0)
75
+ outputs = self.cls_model(inputs)
76
+ _, preds = torch.max(outputs, 1)
77
+ label=preds[0].cpu().numpy()
78
+ test_npy01 = pre_img_data
79
+ if label in [0,1,2]:
80
+ if label == 0:
81
+ output_label = sliding_window_inference_large(test_npy01,self.config.block_size,self.config.min_overlap,self.config.context, self.config.roi_size,self.config.sw_batch_size,predictor=self.model0,device=self.config.device)
82
+ elif label == 1:
83
+ output_label = sliding_window_inference_large(test_npy01,self.config.block_size,self.config.min_overlap,self.config.context, self.config.roi_size,self.config.sw_batch_size,predictor=self.model1,device=self.config.device)
84
+ elif label == 2:
85
+ output_label = sliding_window_inference_large(test_npy01,self.config.block_size,self.config.min_overlap,self.config.context, self.config.roi_size,self.config.sw_batch_size,predictor=self.model2,device=self.config.device)
86
+ else:
87
+ test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor)
88
+
89
+ output_hv, output_np = sliding_window_inference(test_tensor, self.config.roi, self.config.sw_batch_size, self.model3, overlap=self.config.overlap,device=self.config.device)
90
+ pred_dict = {'np': output_np, 'hv': output_hv}
91
+ pred_dict = OrderedDict(
92
+ [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()] # NHWC
93
+ )
94
+ pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:]
95
+ pred_output = torch.cat(list(pred_dict.values()), -1).cpu().numpy() # NHW3
96
+ pred_map = np.squeeze(pred_output) # HW3
97
+ pred_inst = __proc_np_hv(pred_map, self.config.np_thres, self.config.ksize, self.config.overall_thres, self.config.obj_size_thres)
98
+ raw_pred_shape = pred_inst.shape[:2]
99
+ output_label = pred_inst
100
+ return output_label
stardist_pkg/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, print_function
2
+
3
+ import warnings
4
+ def format_warning(message, category, filename, lineno, line=''):
5
+ import pathlib
6
+ return f"{pathlib.Path(filename).name} ({lineno}): {message}\n"
7
+ warnings.formatwarning = format_warning
8
+ del warnings
9
+
10
+ from .version import __version__
11
+
12
+ # TODO: which functions to expose here? all?
13
+ from .nms import non_maximum_suppression
14
+ from .utils import edt_prob, fill_label_holes, sample_points, calculate_extents, export_imagej_rois, gputools_available
15
+ from .geometry import star_dist, polygons_to_label, relabel_image_stardist, ray_angles, dist_to_coord
16
+ from .sample_patches import sample_patches
17
+ from .bioimageio_utils import export_bioimageio, import_bioimageio
18
+
19
+ def _py_deprecation(ver_python=(3,6), ver_stardist='0.9.0'):
20
+ import sys
21
+ from distutils.version import LooseVersion
22
+ if sys.version_info[:2] == ver_python and LooseVersion(__version__) < LooseVersion(ver_stardist):
23
+ print(f"You are using Python {ver_python[0]}.{ver_python[1]}, which will no longer be supported in StarDist {ver_stardist}.\n"
24
+ f"→ Please upgrade to Python {ver_python[0]}.{ver_python[1]+1} or later.", file=sys.stderr, flush=True)
25
+ _py_deprecation()
26
+ del _py_deprecation
stardist_pkg/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.62 kB). View file
 
stardist_pkg/__pycache__/big.cpython-37.pyc ADDED
Binary file (20.7 kB). View file
 
stardist_pkg/__pycache__/bioimageio_utils.cpython-37.pyc ADDED
Binary file (15.2 kB). View file
 
stardist_pkg/__pycache__/matching.cpython-37.pyc ADDED
Binary file (16.9 kB). View file
 
stardist_pkg/__pycache__/nms.cpython-37.pyc ADDED
Binary file (9.59 kB). View file
 
stardist_pkg/__pycache__/sample_patches.cpython-37.pyc ADDED
Binary file (4.22 kB). View file
 
stardist_pkg/__pycache__/utils.cpython-37.pyc ADDED
Binary file (15.4 kB). View file
 
stardist_pkg/__pycache__/version.cpython-37.pyc ADDED
Binary file (199 Bytes). View file
 
stardist_pkg/big.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import warnings
3
+ import math
4
+ from tqdm import tqdm
5
+ from skimage.measure import regionprops
6
+ from skimage.draw import polygon
7
+ from csbdeep.utils import _raise, axes_check_and_normalize, axes_dict
8
+ from itertools import product
9
+
10
+
11
+
12
+
13
+ OBJECT_KEYS = set(('prob', 'points', 'coord', 'dist', 'class_prob', 'class_id'))
14
+ COORD_KEYS = set(('points', 'coord'))
15
+
16
+
17
+
18
+ class Block:
19
+ """One-dimensional block as part of a chain.
20
+
21
+ There are no explicit start and end positions. Instead, each block is
22
+ aware of its predecessor and successor and derives such things (recursively)
23
+ based on its neighbors.
24
+
25
+ Blocks overlap with one another (at least min_overlap + 2*context) and
26
+ have a read region (the entire block) and a write region (ignoring context).
27
+ Given a query interval, Block.is_responsible will return true for only one
28
+ block of a chain (or raise an exception if the interval is larger than
29
+ min_overlap or even the entire block without context).
30
+
31
+ """
32
+ def __init__(self, size, min_overlap, context, pred):
33
+ self.size = int(size)
34
+ self.min_overlap = int(min_overlap)
35
+ self.context = int(context)
36
+ self.pred = pred
37
+ self.succ = None
38
+ assert 0 <= self.min_overlap + 2*self.context < self.size
39
+ self.stride = self.size - (self.min_overlap + 2*self.context)
40
+ self._start = 0
41
+ self._frozen = False
42
+
43
+ @property
44
+ def start(self):
45
+ return self._start if (self.frozen or self.at_begin) else self.pred.succ_start
46
+
47
+ @property
48
+ def end(self):
49
+ return self.start + self.size
50
+
51
+ @property
52
+ def succ_start(self):
53
+ return self.start + self.stride
54
+
55
+ def add_succ(self):
56
+ assert self.succ is None and not self.frozen
57
+ self.succ = Block(self.size, self.min_overlap, self.context, self)
58
+ return self.succ
59
+
60
+ def decrease_stride(self, amount):
61
+ amount = int(amount)
62
+ assert 0 <= amount < self.stride and not self.frozen
63
+ self.stride -= amount
64
+
65
+ def freeze(self):
66
+ """Call on first block to freeze entire chain (after construction is done)"""
67
+ assert not self.frozen and (self.at_begin or self.pred.frozen)
68
+ self._start = self.start
69
+ self._frozen = True
70
+ if not self.at_end:
71
+ self.succ.freeze()
72
+
73
+ @property
74
+ def slice_read(self):
75
+ return slice(self.start, self.end)
76
+
77
+ @property
78
+ def slice_crop_context(self):
79
+ """Crop context relative to read region"""
80
+ return slice(self.context_start, self.size - self.context_end)
81
+
82
+ @property
83
+ def slice_write(self):
84
+ return slice(self.start + self.context_start, self.end - self.context_end)
85
+
86
+ def is_responsible(self, bbox):
87
+ """Responsibility for query interval bbox, which is assumed to be smaller than min_overlap.
88
+
89
+ If the assumption is met, only one block of a chain will return true.
90
+ If violated, one or more blocks of a chain may raise a NotFullyVisible exception.
91
+ The exception will have an argument that is
92
+ False if bbox is larger than min_overlap, and
93
+ True if bbox is even larger than the entire block without context.
94
+
95
+ bbox: (int,int)
96
+ 1D bounding box interval with coordinates relative to size without context
97
+
98
+ """
99
+ bmin, bmax = bbox
100
+
101
+ r_start = 0 if self.at_begin else (self.pred.overlap - self.pred.context_end - self.context_start)
102
+ r_end = self.size - self.context_start - self.context_end
103
+ assert 0 <= bmin < bmax <= r_end
104
+
105
+ # assert not (bmin == 0 and bmax >= r_start and not self.at_begin), [(r_start,r_end), bbox, self]
106
+
107
+ if bmin == 0 and bmax >= r_start:
108
+ if bmax == r_end:
109
+ # object spans the entire block, i.e. is probably larger than size (minus the context)
110
+ raise NotFullyVisible(True)
111
+ if not self.at_begin:
112
+ # object spans the entire overlap region, i.e. is only partially visible here and also by the predecessor block
113
+ raise NotFullyVisible(False)
114
+
115
+ # object ends before responsible region start
116
+ if bmax < r_start: return False
117
+ # object touches the end of the responsible region (only take if at end)
118
+ if bmax == r_end and not self.at_end: return False
119
+ return True
120
+
121
+ # ------------------------
122
+
123
+ @property
124
+ def frozen(self):
125
+ return self._frozen
126
+
127
+ @property
128
+ def at_begin(self):
129
+ return self.pred is None
130
+
131
+ @property
132
+ def at_end(self):
133
+ return self.succ is None
134
+
135
+ @property
136
+ def overlap(self):
137
+ return self.size - self.stride
138
+
139
+ @property
140
+ def context_start(self):
141
+ return 0 if self.at_begin else self.context
142
+
143
+ @property
144
+ def context_end(self):
145
+ return 0 if self.at_end else self.context
146
+
147
+ def __repr__(self):
148
+ shared = f'{self.start:03}:{self.end:03}'
149
+ shared += f', size={self.context_start}-{self.size-self.context_start-self.context_end}-{self.context_end}'
150
+ if self.at_end:
151
+ return f'{self.__class__.__name__}({shared})'
152
+ else:
153
+ return f'{self.__class__.__name__}({shared}, overlap={self.overlap}/{self.overlap-self.context_start-self.context_end})'
154
+
155
+ @property
156
+ def chain(self):
157
+ blocks = [self]
158
+ while not blocks[-1].at_end:
159
+ blocks.append(blocks[-1].succ)
160
+ return blocks
161
+
162
+ def __iter__(self):
163
+ return iter(self.chain)
164
+
165
+ # ------------------------
166
+
167
+ @staticmethod
168
+ def cover(size, block_size, min_overlap, context, grid=1, verbose=True):
169
+ """Return chain of grid-aligned blocks to cover the interval [0,size].
170
+
171
+ Parameters block_size, min_overlap, and context will be used
172
+ for all blocks of the chain. Only the size of the last block
173
+ may differ.
174
+
175
+ Except for the last block, start and end positions of all blocks will
176
+ be multiples of grid. To that end, the provided block parameters may
177
+ be increased to achieve that.
178
+
179
+ Note that parameters must be chosen such that the write regions of only
180
+ neighboring blocks are overlapping.
181
+
182
+ """
183
+ assert 0 <= min_overlap+2*context < block_size <= size
184
+ assert 0 < grid <= block_size
185
+ block_size = _grid_divisible(grid, block_size, name='block_size', verbose=verbose)
186
+ min_overlap = _grid_divisible(grid, min_overlap, name='min_overlap', verbose=verbose)
187
+ context = _grid_divisible(grid, context, name='context', verbose=verbose)
188
+
189
+ # allow size not to be divisible by grid
190
+ size_orig = size
191
+ size = _grid_divisible(grid, size, name='size', verbose=False)
192
+
193
+ # divide all sizes by grid
194
+ assert all(v % grid == 0 for v in (size, block_size, min_overlap, context))
195
+ size //= grid
196
+ block_size //= grid
197
+ min_overlap //= grid
198
+ context //= grid
199
+
200
+ # compute cover in grid-multiples
201
+ t = first = Block(block_size, min_overlap, context, None)
202
+ while t.end < size:
203
+ t = t.add_succ()
204
+ last = t
205
+
206
+ # [print(t) for t in first]
207
+
208
+ # move blocks around to make it fit
209
+ excess = last.end - size
210
+ t = first
211
+ while excess > 0:
212
+ t.decrease_stride(1)
213
+ excess -= 1
214
+ t = t.succ
215
+ if (t == last): t = first
216
+
217
+ # make a copy of the cover and multiply sizes by grid
218
+ if grid > 1:
219
+ size *= grid
220
+ block_size *= grid
221
+ min_overlap *= grid
222
+ context *= grid
223
+ #
224
+ _t = _first = first
225
+ t = first = Block(block_size, min_overlap, context, None)
226
+ t.stride = _t.stride*grid
227
+ while not _t.at_end:
228
+ _t = _t.succ
229
+ t = t.add_succ()
230
+ t.stride = _t.stride*grid
231
+ last = t
232
+
233
+ # change size of last block
234
+ # will be padded internally to the same size
235
+ # as the others by model.predict_instances
236
+ size_delta = size - size_orig
237
+ last.size -= size_delta
238
+ assert 0 <= size_delta < grid
239
+
240
+ # for efficiency (to not determine starts recursively from now on)
241
+ first.freeze()
242
+
243
+ blocks = first.chain
244
+
245
+ # sanity checks
246
+ assert first.start == 0 and last.end == size_orig
247
+ assert all(t.overlap-2*context >= min_overlap for t in blocks if t != last)
248
+ assert all(t.start % grid == 0 and t.end % grid == 0 for t in blocks if t != last)
249
+ # print(); [print(t) for t in first]
250
+
251
+ # only neighboring blocks should be overlapping
252
+ if len(blocks) >= 3:
253
+ for t in blocks[:-2]:
254
+ assert t.slice_write.stop <= t.succ.succ.slice_write.start
255
+
256
+ return blocks
257
+
258
+
259
+
260
+ class BlockND:
261
+ """N-dimensional block.
262
+
263
+ Each BlockND simply consists of a 1-dimensional Block per axis and also
264
+ has an id (which should be unique). The n-dimensional region represented
265
+ by each BlockND is the intersection of all 1D Blocks per axis.
266
+
267
+ Also see `Block`.
268
+
269
+ """
270
+ def __init__(self, id, blocks, axes):
271
+ self.id = id
272
+ self.blocks = tuple(blocks)
273
+ self.axes = axes_check_and_normalize(axes, length=len(self.blocks))
274
+ self.axis_to_block = dict(zip(self.axes,self.blocks))
275
+
276
+ def blocks_for_axes(self, axes=None):
277
+ axes = self.axes if axes is None else axes_check_and_normalize(axes)
278
+ return tuple(self.axis_to_block[a] for a in axes)
279
+
280
+ def slice_read(self, axes=None):
281
+ return tuple(t.slice_read for t in self.blocks_for_axes(axes))
282
+
283
+ def slice_crop_context(self, axes=None):
284
+ return tuple(t.slice_crop_context for t in self.blocks_for_axes(axes))
285
+
286
+ def slice_write(self, axes=None):
287
+ return tuple(t.slice_write for t in self.blocks_for_axes(axes))
288
+
289
+ def read(self, x, axes=None):
290
+ """Read block "read region" from x (numpy.ndarray or similar)"""
291
+ return x[self.slice_read(axes)]
292
+
293
+ def crop_context(self, labels, axes=None):
294
+ return labels[self.slice_crop_context(axes)]
295
+
296
+ def write(self, x, labels, axes=None):
297
+ """Write (only entries > 0 of) labels to block "write region" of x (numpy.ndarray or similar)"""
298
+ s = self.slice_write(axes)
299
+ mask = labels > 0
300
+ # x[s][mask] = labels[mask] # doesn't work with zarr
301
+ region = x[s]
302
+ region[mask] = labels[mask]
303
+ x[s] = region
304
+
305
+ def is_responsible(self, slices, axes=None):
306
+ return all(t.is_responsible((s.start,s.stop)) for t,s in zip(self.blocks_for_axes(axes),slices))
307
+
308
+ def __repr__(self):
309
+ slices = ','.join(f'{a}={t.start:03}:{t.end:03}' for t,a in zip(self.blocks,self.axes))
310
+ return f'{self.__class__.__name__}({self.id}|{slices})'
311
+
312
+ def __iter__(self):
313
+ return iter(self.blocks)
314
+
315
+ # ------------------------
316
+
317
+ def filter_objects(self, labels, polys, axes=None):
318
+ """Filter out objects that block is not responsible for.
319
+
320
+ Given label image 'labels' and dictionary 'polys' of polygon/polyhedron objects,
321
+ only retain those objects that this block is responsible for.
322
+
323
+ This function will return a pair (labels, polys) of the modified label image and dictionary.
324
+ It will raise a RuntimeError if an object is found in the overlap area
325
+ of neighboring blocks that violates the assumption to be smaller than 'min_overlap'.
326
+
327
+ If parameter 'polys' is None, only the filtered label image will be returned.
328
+
329
+ Notes
330
+ -----
331
+ - Important: It is assumed that the object label ids in 'labels' and
332
+ the entries in 'polys' are sorted in the same way.
333
+ - Does not modify 'labels' and 'polys', but returns modified copies.
334
+
335
+ Example
336
+ -------
337
+ >>> labels, polys = model.predict_instances(block.read(img))
338
+ >>> labels = block.crop_context(labels)
339
+ >>> labels, polys = block.filter_objects(labels, polys)
340
+
341
+ """
342
+ # TODO: option to update labels in-place
343
+ assert np.issubdtype(labels.dtype, np.integer)
344
+ ndim = len(self.blocks_for_axes(axes))
345
+ assert ndim in (2,3)
346
+ assert labels.ndim == ndim and labels.shape == tuple(s.stop-s.start for s in self.slice_crop_context(axes))
347
+
348
+ labels_filtered = np.zeros_like(labels)
349
+ # problem_ids = []
350
+ for r in regionprops(labels):
351
+ slices = tuple(slice(r.bbox[i],r.bbox[i+labels.ndim]) for i in range(labels.ndim))
352
+ try:
353
+ if self.is_responsible(slices, axes):
354
+ labels_filtered[slices][r.image] = r.label
355
+ except NotFullyVisible as e:
356
+ # shape_block_write = tuple(s.stop-s.start for s in self.slice_write(axes))
357
+ shape_object = tuple(s.stop-s.start for s in slices)
358
+ shape_min_overlap = tuple(t.min_overlap for t in self.blocks_for_axes(axes))
359
+ raise RuntimeError(f"Found object of shape {shape_object}, which violates the assumption of being smaller than 'min_overlap' {shape_min_overlap}. Increase 'min_overlap' to avoid this problem.")
360
+
361
+ # if e.args[0]: # object larger than block write region
362
+ # assert any(o >= b for o,b in zip(shape_object,shape_block_write))
363
+ # # problem, since this object will probably be saved by another block too
364
+ # raise RuntimeError(f"Found object of shape {shape_object}, larger than an entire block's write region of shape {shape_block_write}. Increase 'block_size' to avoid this problem.")
365
+ # # print("found object larger than 'block_size'")
366
+ # else:
367
+ # assert any(o >= b for o,b in zip(shape_object,shape_min_overlap))
368
+ # # print("found object larger than 'min_overlap'")
369
+
370
+ # # keep object, because will be dealt with later, i.e.
371
+ # # render the poly again into the label image, but this is not
372
+ # # ideal since the assumption is that the object outside that
373
+ # # region is not reliable because it's in the context
374
+ # labels_filtered[slices][r.image] = r.label
375
+ # problem_ids.append(r.label)
376
+
377
+ if polys is None:
378
+ # assert len(problem_ids) == 0
379
+ return labels_filtered
380
+ else:
381
+ # it is assumed that ids in 'labels' map to entries in 'polys'
382
+ assert isinstance(polys,dict) and any(k in polys for k in COORD_KEYS)
383
+ filtered_labels = np.unique(labels_filtered)
384
+ filtered_ind = [i-1 for i in filtered_labels if i > 0]
385
+ polys_out = {k: (v[filtered_ind] if k in OBJECT_KEYS else v) for k,v in polys.items()}
386
+ for k in COORD_KEYS:
387
+ if k in polys_out.keys():
388
+ polys_out[k] = self.translate_coordinates(polys_out[k], axes=axes)
389
+
390
+ return labels_filtered, polys_out#, tuple(problem_ids)
391
+
392
+ def translate_coordinates(self, coordinates, axes=None):
393
+ """Translate local block coordinates (of read region) to global ones based on block position"""
394
+ ndim = len(self.blocks_for_axes(axes))
395
+ assert isinstance(coordinates, np.ndarray) and coordinates.ndim >= 2 and coordinates.shape[1] == ndim
396
+ start = [s.start for s in self.slice_read(axes)]
397
+ shape = tuple(1 if d!=1 else ndim for d in range(coordinates.ndim))
398
+ start = np.array(start).reshape(shape)
399
+ return coordinates + start
400
+
401
+ # ------------------------
402
+
403
+ @staticmethod
404
+ def cover(shape, axes, block_size, min_overlap, context, grid=1):
405
+ """Return grid-aligned n-dimensional blocks to cover region
406
+ of the given shape with axes semantics.
407
+
408
+ Parameters block_size, min_overlap, and context can be different per
409
+ dimension/axis (if provided as list) or the same (if provided as
410
+ scalar value).
411
+
412
+ Also see `Block.cover`.
413
+
414
+ """
415
+ shape = tuple(shape)
416
+ n = len(shape)
417
+ axes = axes_check_and_normalize(axes, length=n)
418
+ if np.isscalar(block_size): block_size = n*[block_size]
419
+ if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
420
+ if np.isscalar(context): context = n*[context]
421
+ if np.isscalar(grid): grid = n*[grid]
422
+ assert n == len(block_size) == len(min_overlap) == len(context) == len(grid)
423
+
424
+ # compute cover for each dimension
425
+ cover_1d = [Block.cover(*args) for args in zip(shape, block_size, min_overlap, context, grid)]
426
+ # return cover as Cartesian product of 1-dimensional blocks
427
+ return tuple(BlockND(i,blocks,axes) for i,blocks in enumerate(product(*cover_1d)))
428
+
429
+
430
+
431
+ class Polygon:
432
+
433
+ def __init__(self, coord, bbox=None, shape_max=None):
434
+ self.bbox = self.coords_bbox(coord, shape_max=shape_max) if bbox is None else bbox
435
+ self.coord = coord - np.array([r[0] for r in self.bbox]).reshape(2,1)
436
+ self.slice = tuple(slice(*r) for r in self.bbox)
437
+ self.shape = tuple(r[1]-r[0] for r in self.bbox)
438
+ rr,cc = polygon(*self.coord, self.shape)
439
+ self.mask = np.zeros(self.shape, bool)
440
+ self.mask[rr,cc] = True
441
+
442
+ @staticmethod
443
+ def coords_bbox(*coords, shape_max=None):
444
+ assert all(isinstance(c, np.ndarray) and c.ndim==2 and c.shape[0]==2 for c in coords)
445
+ if shape_max is None:
446
+ shape_max = (np.inf, np.inf)
447
+ coord = np.concatenate(coords, axis=1)
448
+ mins = np.maximum(0, np.floor(np.min(coord,axis=1))).astype(int)
449
+ maxs = np.minimum(shape_max, np.ceil (np.max(coord,axis=1))).astype(int)
450
+ return tuple(zip(tuple(mins),tuple(maxs)))
451
+
452
+
453
+
454
+ class Polyhedron:
455
+
456
+ def __init__(self, dist, origin, rays, bbox=None, shape_max=None):
457
+ self.bbox = self.coords_bbox((dist, origin), rays=rays, shape_max=shape_max) if bbox is None else bbox
458
+ self.slice = tuple(slice(*r) for r in self.bbox)
459
+ self.shape = tuple(r[1]-r[0] for r in self.bbox)
460
+ _origin = origin.reshape(1,3) - np.array([r[0] for r in self.bbox]).reshape(1,3)
461
+ self.mask = polyhedron_to_label(dist[np.newaxis], _origin, rays, shape=self.shape, verbose=False).astype(bool)
462
+
463
+ @staticmethod
464
+ def coords_bbox(*dist_origin, rays, shape_max=None):
465
+ dists, points = zip(*dist_origin)
466
+ assert all(isinstance(d, np.ndarray) and d.ndim==1 and len(d)==len(rays) for d in dists)
467
+ assert all(isinstance(p, np.ndarray) and p.ndim==1 and len(p)==3 for p in points)
468
+ dists, points, verts = np.stack(dists)[...,np.newaxis], np.stack(points)[:,np.newaxis], rays.vertices[np.newaxis]
469
+ coord = dists * verts + points
470
+ coord = np.concatenate(coord, axis=0)
471
+ if shape_max is None:
472
+ shape_max = (np.inf, np.inf, np.inf)
473
+ mins = np.maximum(0, np.floor(np.min(coord,axis=0))).astype(int)
474
+ maxs = np.minimum(shape_max, np.ceil (np.max(coord,axis=0))).astype(int)
475
+ return tuple(zip(tuple(mins),tuple(maxs)))
476
+
477
+
478
+
479
+ # def repaint_labels(output, labels, polys, show_progress=True):
480
+ # """Repaint object instances in correct order based on probability scores.
481
+
482
+ # Does modify 'output' and 'polys' in-place, but will only write sparsely to 'output' where needed.
483
+
484
+ # output: numpy.ndarray or similar
485
+ # Label image (integer-valued)
486
+ # labels: iterable of int
487
+ # List of integer label ids that occur in output
488
+ # polys: dict
489
+ # Dictionary of polygon/polyhedra properties.
490
+ # Assumption is that the label id (-1) corresponds to the index in the polys dict
491
+
492
+ # """
493
+ # assert output.ndim in (2,3)
494
+
495
+ # if show_progress:
496
+ # labels = tqdm(labels, leave=True)
497
+
498
+ # labels_eliminated = set()
499
+
500
+ # # TODO: inelegant to have so much duplicated code here
501
+ # if output.ndim == 2:
502
+ # coord = lambda i: polys['coord'][i-1]
503
+ # prob = lambda i: polys['prob'][i-1]
504
+
505
+ # for i in labels:
506
+ # if i in labels_eliminated: continue
507
+ # poly_i = Polygon(coord(i), shape_max=output.shape)
508
+
509
+ # # find all labels that overlap with i (including i)
510
+ # overlapping = set(np.unique(output[poly_i.slice][poly_i.mask])) - {0}
511
+ # assert i in overlapping
512
+ # # compute bbox union to find area to crop/replace in large output label image
513
+ # bbox_union = Polygon.coords_bbox(*[coord(j) for j in overlapping], shape_max=output.shape)
514
+
515
+ # # crop out label i, including the region that include all overlapping labels
516
+ # poly_i = Polygon(coord(i), bbox=bbox_union)
517
+ # mask = poly_i.mask.copy()
518
+
519
+ # # remove pixels from mask that belong to labels with higher probability
520
+ # for j in [j for j in overlapping if prob(j) > prob(i)]:
521
+ # mask[ Polygon(coord(j), bbox=bbox_union).mask ] = False
522
+
523
+ # crop = output[poly_i.slice]
524
+ # crop[crop==i] = 0 # delete all remnants of i in crop
525
+ # crop[mask] = i # paint i where mask still active
526
+
527
+ # labels_remaining = set(np.unique(output[poly_i.slice][poly_i.mask])) - {0}
528
+ # labels_eliminated.update(overlapping - labels_remaining)
529
+ # else:
530
+
531
+ # dist = lambda i: polys['dist'][i-1]
532
+ # origin = lambda i: polys['points'][i-1]
533
+ # prob = lambda i: polys['prob'][i-1]
534
+ # rays = polys['rays']
535
+
536
+ # for i in labels:
537
+ # if i in labels_eliminated: continue
538
+ # poly_i = Polyhedron(dist(i), origin(i), rays, shape_max=output.shape)
539
+
540
+ # # find all labels that overlap with i (including i)
541
+ # overlapping = set(np.unique(output[poly_i.slice][poly_i.mask])) - {0}
542
+ # assert i in overlapping
543
+ # # compute bbox union to find area to crop/replace in large output label image
544
+ # bbox_union = Polyhedron.coords_bbox(*[(dist(j),origin(j)) for j in overlapping], rays=rays, shape_max=output.shape)
545
+
546
+ # # crop out label i, including the region that include all overlapping labels
547
+ # poly_i = Polyhedron(dist(i), origin(i), rays, bbox=bbox_union)
548
+ # mask = poly_i.mask.copy()
549
+
550
+ # # remove pixels from mask that belong to labels with higher probability
551
+ # for j in [j for j in overlapping if prob(j) > prob(i)]:
552
+ # mask[ Polyhedron(dist(j), origin(j), rays, bbox=bbox_union).mask ] = False
553
+
554
+ # crop = output[poly_i.slice]
555
+ # crop[crop==i] = 0 # delete all remnants of i in crop
556
+ # crop[mask] = i # paint i where mask still active
557
+
558
+ # labels_remaining = set(np.unique(output[poly_i.slice][poly_i.mask])) - {0}
559
+ # labels_eliminated.update(overlapping - labels_remaining)
560
+
561
+ # if len(labels_eliminated) > 0:
562
+ # ind = [i-1 for i in labels_eliminated]
563
+ # for k,v in polys.items():
564
+ # if k in OBJECT_KEYS:
565
+ # polys[k] = np.delete(v, ind, axis=0)
566
+
567
+
568
+
569
+ ############
570
+
571
+
572
+
573
+ def predict_big(model, *args, **kwargs):
574
+ from .models import StarDist2D, StarDist3D
575
+ if isinstance(model,(StarDist2D,StarDist3D)):
576
+ dst = model.__class__.__name__
577
+ else:
578
+ dst = '{StarDist2D, StarDist3D}'
579
+ raise RuntimeError(f"This function has moved to {dst}.predict_instances_big.")
580
+
581
+
582
+
583
+ class NotFullyVisible(Exception):
584
+ pass
585
+
586
+
587
+
588
+ def _grid_divisible(grid, size, name=None, verbose=True):
589
+ if size % grid == 0:
590
+ return size
591
+ _size = size
592
+ size = math.ceil(size / grid) * grid
593
+ if bool(verbose):
594
+ print(f"{verbose if isinstance(verbose,str) else ''}increasing '{'value' if name is None else name}' from {_size} to {size} to be evenly divisible by {grid} (grid)", flush=True)
595
+ assert size % grid == 0
596
+ return size
597
+
598
+
599
+
600
+ # def render_polygons(polys, shape):
601
+ # return polygons_to_label_coord(polys['coord'], shape=shape)
stardist_pkg/bioimageio_utils.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from pkg_resources import get_distribution
3
+ from zipfile import ZipFile
4
+ import numpy as np
5
+ import tempfile
6
+ from distutils.version import LooseVersion
7
+ from csbdeep.utils import axes_check_and_normalize, normalize, _raise
8
+
9
+
10
+ DEEPIMAGEJ_MACRO = \
11
+ """
12
+ //*******************************************************************
13
+ // Date: July-2021
14
+ // Credits: StarDist, DeepImageJ
15
+ // URL:
16
+ // https://github.com/stardist/stardist
17
+ // https://deepimagej.github.io/deepimagej
18
+ // This macro was adapted from
19
+ // https://github.com/deepimagej/imagej-macros/blob/648caa867f6ccb459649d4d3799efa1e2e0c5204/StarDist2D_Post-processing.ijm
20
+ // Please cite the respective contributions when using this code.
21
+ //*******************************************************************
22
+ // Macro to run StarDist postprocessing on 2D images.
23
+ // StarDist and deepImageJ plugins need to be installed.
24
+ // The macro assumes that the image to process is a stack in which
25
+ // the first channel corresponds to the object probability map
26
+ // and the remaining channels are the radial distances from each
27
+ // pixel to the object boundary.
28
+ //*******************************************************************
29
+
30
+ // Get the name of the image to call it
31
+ getDimensions(width, height, channels, slices, frames);
32
+ name=getTitle();
33
+
34
+ probThresh={probThresh};
35
+ nmsThresh={nmsThresh};
36
+
37
+ // Isolate the detection probability scores
38
+ run("Make Substack...", "channels=1");
39
+ rename("scores");
40
+
41
+ // Isolate the oriented distances
42
+ run("Fire");
43
+ selectWindow(name);
44
+ run("Delete Slice", "delete=channel");
45
+ selectWindow(name);
46
+ run("Properties...", "channels=" + maxOf(channels, slices) - 1 + " slices=1 frames=1 pixel_width=1.0000 pixel_height=1.0000 voxel_depth=1.0000");
47
+ rename("distances");
48
+ run("royal");
49
+
50
+ // Run StarDist plugin
51
+ run("Command From Macro", "command=[de.csbdresden.stardist.StarDist2DNMS], args=['prob':'scores', 'dist':'distances', 'probThresh':'" + probThresh + "', 'nmsThresh':'" + nmsThresh + "', 'outputType':'Both', 'excludeBoundary':'2', 'roiPosition':'Stack', 'verbose':'false'], process=[false]");
52
+ """
53
+
54
+
55
+ def _import(error=True):
56
+ try:
57
+ from importlib_metadata import metadata
58
+ from bioimageio.core.build_spec import build_model # type: ignore
59
+ import xarray as xr
60
+ import bioimageio.core # type: ignore
61
+ except ImportError:
62
+ if error:
63
+ raise RuntimeError(
64
+ "Required libraries are missing for bioimage.io model export.\n"
65
+ "Please install StarDist as follows: pip install 'stardist[bioimageio]'\n"
66
+ "(You do not need to uninstall StarDist first.)"
67
+ )
68
+ else:
69
+ return None
70
+ return metadata, build_model, bioimageio.core, xr
71
+
72
+
73
+ def _create_stardist_dependencies(outdir):
74
+ from ruamel.yaml import YAML
75
+ from tensorflow import __version__ as tf_version
76
+ from . import __version__ as stardist_version
77
+ pkg_info = get_distribution("stardist")
78
+ # dependencies that start with the name "bioimageio" will be added as conda dependencies
79
+ reqs_conda = [str(req) for req in pkg_info.requires(extras=['bioimageio']) if str(req).startswith('bioimageio')]
80
+ # only stardist and tensorflow as pip dependencies
81
+ tf_major, tf_minor = LooseVersion(tf_version).version[:2]
82
+ reqs_pip = (f"stardist>={stardist_version}", f"tensorflow>={tf_major}.{tf_minor},<{tf_major+1}")
83
+ # conda environment
84
+ env = dict(
85
+ name = 'stardist',
86
+ channels = ['defaults', 'conda-forge'],
87
+ dependencies = [
88
+ ('python>=3.7,<3.8' if tf_major == 1 else 'python>=3.7'),
89
+ *reqs_conda,
90
+ 'pip', {'pip': reqs_pip},
91
+ ],
92
+ )
93
+ yaml = YAML(typ='safe')
94
+ path = outdir / "environment.yaml"
95
+ with open(path, "w") as f:
96
+ yaml.dump(env, f)
97
+ return f"conda:{path}"
98
+
99
+
100
+ def _create_stardist_doc(outdir):
101
+ doc_path = outdir / "README.md"
102
+ text = (
103
+ "# StarDist Model\n"
104
+ "This is a model for object detection with star-convex shapes.\n"
105
+ "Please see the [StarDist repository](https://github.com/stardist/stardist) for details."
106
+ )
107
+ with open(doc_path, "w") as f:
108
+ f.write(text)
109
+ return doc_path
110
+
111
+
112
+ def _get_stardist_metadata(outdir, model):
113
+ metadata, *_ = _import()
114
+ package_data = metadata("stardist")
115
+ doi_2d = "https://doi.org/10.1007/978-3-030-00934-2_30"
116
+ doi_3d = "https://doi.org/10.1109/WACV45572.2020.9093435"
117
+ authors = {
118
+ 'Martin Weigert': dict(name='Martin Weigert', github_user='maweigert'),
119
+ 'Uwe Schmidt': dict(name='Uwe Schmidt', github_user='uschmidt83'),
120
+ }
121
+ data = dict(
122
+ description=package_data["Summary"],
123
+ authors=list(authors.get(name.strip(),dict(name=name.strip())) for name in package_data["Author"].split(",")),
124
+ git_repo=package_data["Home-Page"],
125
+ license=package_data["License"],
126
+ dependencies=_create_stardist_dependencies(outdir),
127
+ cite=[{"text": "Cell Detection with Star-Convex Polygons", "doi": doi_2d},
128
+ {"text": "Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy", "doi": doi_3d}],
129
+ tags=[
130
+ 'fluorescence-light-microscopy', 'whole-slide-imaging', 'other', # modality
131
+ f'{model.config.n_dim}d', # dims
132
+ 'cells', 'nuclei', # content
133
+ 'tensorflow', # framework
134
+ 'fiji', # software
135
+ 'unet', # network
136
+ 'instance-segmentation', 'object-detection', # task
137
+ 'stardist',
138
+ ],
139
+ covers=["https://raw.githubusercontent.com/stardist/stardist/master/images/stardist_logo.jpg"],
140
+ documentation=_create_stardist_doc(outdir),
141
+ )
142
+ return data
143
+
144
+
145
+ def _predict_tf(model_path, test_input):
146
+ import tensorflow as tf
147
+ from csbdeep.utils.tf import IS_TF_1
148
+ # need to unzip the model assets
149
+ model_assets = model_path.parent / "tf_model"
150
+ with ZipFile(model_path, "r") as f:
151
+ f.extractall(model_assets)
152
+ if IS_TF_1:
153
+ # make a new graph, i.e. don't use the global default graph
154
+ with tf.Graph().as_default():
155
+ with tf.Session() as sess:
156
+ tf_model = tf.saved_model.load_v2(str(model_assets))
157
+ x = tf.convert_to_tensor(test_input, dtype=tf.float32)
158
+ model = tf_model.signatures["serving_default"]
159
+ y = model(x)
160
+ sess.run(tf.global_variables_initializer())
161
+ output = sess.run(y["output"])
162
+ else:
163
+ tf_model = tf.saved_model.load(str(model_assets))
164
+ x = tf.convert_to_tensor(test_input, dtype=tf.float32)
165
+ model = tf_model.signatures["serving_default"]
166
+ y = model(x)
167
+ output = y["output"].numpy()
168
+ return output
169
+
170
+
171
+ def _get_weights_and_model_metadata(outdir, model, test_input, test_input_axes, test_input_norm_axes, mode, min_percentile, max_percentile):
172
+
173
+ # get the path to the exported model assets (saved in outdir)
174
+ if mode == "keras_hdf5":
175
+ raise NotImplementedError("Export to keras format is not supported yet")
176
+ elif mode == "tensorflow_saved_model_bundle":
177
+ assets_uri = outdir / "TF_SavedModel.zip"
178
+ model_csbdeep = model.export_TF(assets_uri, single_output=True, upsample_grid=True)
179
+ else:
180
+ raise ValueError(f"Unsupported mode: {mode}")
181
+
182
+ # to force "inputs.data_type: float32" in the spec (bonus: disables normalization warning in model._predict_setup)
183
+ test_input = test_input.astype(np.float32)
184
+
185
+ # convert test_input to axes_net semantics and shape, also resize if necessary (to adhere to axes_net_div_by)
186
+ test_input, axes_img, axes_net, axes_net_div_by, *_ = model._predict_setup(
187
+ img=test_input,
188
+ axes=test_input_axes,
189
+ normalizer=None,
190
+ n_tiles=None,
191
+ show_tile_progress=False,
192
+ predict_kwargs={},
193
+ )
194
+
195
+ # normalization axes string and numeric indices
196
+ axes_norm = set(axes_net).intersection(set(axes_check_and_normalize(test_input_norm_axes, disallowed='S')))
197
+ axes_norm = "".join(a for a in axes_net if a in axes_norm) # preserve order of axes_net
198
+ axes_norm_num = tuple(axes_net.index(a) for a in axes_norm)
199
+
200
+ # normalize input image
201
+ test_input_norm = normalize(test_input, pmin=min_percentile, pmax=max_percentile, axis=axes_norm_num)
202
+
203
+ net_axes_in = axes_net.lower()
204
+ net_axes_out = axes_check_and_normalize(model._axes_out).lower()
205
+ ndim_tensor = len(net_axes_out) + 1
206
+
207
+ input_min_shape = list(axes_net_div_by)
208
+ input_min_shape[axes_net.index('C')] = model.config.n_channel_in
209
+ input_step = list(axes_net_div_by)
210
+ input_step[axes_net.index('C')] = 0
211
+
212
+ # add the batch axis to shape and step
213
+ input_min_shape = [1] + input_min_shape
214
+ input_step = [0] + input_step
215
+
216
+ # the axes strings in bioimageio convention
217
+ input_axes = "b" + net_axes_in.lower()
218
+ output_axes = "b" + net_axes_out.lower()
219
+
220
+ if mode == "keras_hdf5":
221
+ output_names = ("prob", "dist") + (("class_prob",) if model._is_multiclass() else ())
222
+ output_n_channels = (1, model.config.n_rays,) + ((1,) if model._is_multiclass() else ())
223
+ # the output shape is computed from the input shape using
224
+ # output_shape[i] = output_scale[i] * input_shape[i] + 2 * output_offset[i]
225
+ output_scale = [1]+list(1/g for g in model.config.grid) + [0]
226
+ output_offset = [0]*(ndim_tensor)
227
+
228
+ elif mode == "tensorflow_saved_model_bundle":
229
+ if model._is_multiclass():
230
+ raise NotImplementedError("Tensorflow SavedModel not supported for multiclass models yet")
231
+ # regarding input/output names: https://github.com/CSBDeep/CSBDeep/blob/b0d2f5f344ebe65a9b4c3007f4567fe74268c813/csbdeep/utils/tf.py#L193-L194
232
+ input_names = ["input"]
233
+ output_names = ["output"]
234
+ output_n_channels = (1 + model.config.n_rays,)
235
+ # the output shape is computed from the input shape using
236
+ # output_shape[i] = output_scale[i] * input_shape[i] + 2 * output_offset[i]
237
+ # same shape as input except for the channel dimension
238
+ output_scale = [1]*(ndim_tensor)
239
+ output_scale[output_axes.index("c")] = 0
240
+ # no offset, except for the input axes, where it is output channel / 2
241
+ output_offset = [0.0]*(ndim_tensor)
242
+ output_offset[output_axes.index("c")] = output_n_channels[0] / 2.0
243
+
244
+ assert all(s in (0, 1) for s in output_scale), "halo computation assumption violated"
245
+ halo = model._axes_tile_overlap(output_axes.replace('b', 's'))
246
+ halo = [int(np.ceil(v/8)*8) for v in halo] # optional: round up to be divisible by 8
247
+
248
+ # the output shape needs to be valid after cropping the halo, so we add the halo to the input min shape
249
+ input_min_shape = [ms + 2 * ha for ms, ha in zip(input_min_shape, halo)]
250
+
251
+ # make sure the input min shape is still divisible by the min axis divisor
252
+ input_min_shape = input_min_shape[:1] + [ms + (-ms % div_by) for ms, div_by in zip(input_min_shape[1:], axes_net_div_by)]
253
+ assert all(ms % div_by == 0 for ms, div_by in zip(input_min_shape[1:], axes_net_div_by))
254
+
255
+ metadata, *_ = _import()
256
+ package_data = metadata("stardist")
257
+ is_2D = model.config.n_dim == 2
258
+
259
+ weights_file = outdir / "stardist_weights.h5"
260
+ model.keras_model.save_weights(str(weights_file))
261
+
262
+ config = dict(
263
+ stardist=dict(
264
+ python_version=package_data["Version"],
265
+ thresholds=dict(model.thresholds._asdict()),
266
+ weights=weights_file.name,
267
+ config=vars(model.config),
268
+ )
269
+ )
270
+
271
+ if is_2D:
272
+ macro_file = outdir / "stardist_postprocessing.ijm"
273
+ with open(str(macro_file), 'w', encoding='utf-8') as f:
274
+ f.write(DEEPIMAGEJ_MACRO.format(probThresh=model.thresholds.prob, nmsThresh=model.thresholds.nms))
275
+ config['stardist'].update(postprocessing_macro=macro_file.name)
276
+
277
+ n_inputs = len(input_names)
278
+ assert n_inputs == 1
279
+ input_config = dict(
280
+ input_names=input_names,
281
+ input_min_shape=[input_min_shape],
282
+ input_step=[input_step],
283
+ input_axes=[input_axes],
284
+ input_data_range=[["-inf", "inf"]],
285
+ preprocessing=[[dict(
286
+ name="scale_range",
287
+ kwargs=dict(
288
+ mode="per_sample",
289
+ axes=axes_norm.lower(),
290
+ min_percentile=min_percentile,
291
+ max_percentile=max_percentile,
292
+ ))]]
293
+ )
294
+
295
+ n_outputs = len(output_names)
296
+ output_config = dict(
297
+ output_names=output_names,
298
+ output_data_range=[["-inf", "inf"]] * n_outputs,
299
+ output_axes=[output_axes] * n_outputs,
300
+ output_reference=[input_names[0]] * n_outputs,
301
+ output_scale=[output_scale] * n_outputs,
302
+ output_offset=[output_offset] * n_outputs,
303
+ halo=[halo] * n_outputs
304
+ )
305
+
306
+ in_path = outdir / "test_input.npy"
307
+ np.save(in_path, test_input[np.newaxis])
308
+
309
+ if mode == "tensorflow_saved_model_bundle":
310
+ test_outputs = _predict_tf(assets_uri, test_input_norm[np.newaxis])
311
+ else:
312
+ test_outputs = model.predict(test_input_norm)
313
+
314
+ # out_paths = []
315
+ # for i, out in enumerate(test_outputs):
316
+ # p = outdir / f"test_output{i}.npy"
317
+ # np.save(p, out)
318
+ # out_paths.append(p)
319
+ assert n_outputs == 1
320
+ out_paths = [outdir / "test_output.npy"]
321
+ np.save(out_paths[0], test_outputs)
322
+
323
+ from tensorflow import __version__ as tf_version
324
+ data = dict(weight_uri=assets_uri, test_inputs=[in_path], test_outputs=out_paths,
325
+ config=config, tensorflow_version=tf_version)
326
+ data.update(input_config)
327
+ data.update(output_config)
328
+ _files = [str(weights_file)]
329
+ if is_2D:
330
+ _files.append(str(macro_file))
331
+ data.update(attachments=dict(files=_files))
332
+
333
+ return data
334
+
335
+
336
+ def export_bioimageio(
337
+ model,
338
+ outpath,
339
+ test_input,
340
+ test_input_axes=None,
341
+ test_input_norm_axes='ZYX',
342
+ name=None,
343
+ mode="tensorflow_saved_model_bundle",
344
+ min_percentile=1.0,
345
+ max_percentile=99.8,
346
+ overwrite_spec_kwargs=None,
347
+ ):
348
+ """Export stardist model into bioimage.io format, https://github.com/bioimage-io/spec-bioimage-io.
349
+
350
+ Parameters
351
+ ----------
352
+ model: StarDist2D, StarDist3D
353
+ the model to convert
354
+ outpath: str, Path
355
+ where to save the model
356
+ test_input: np.ndarray
357
+ input image for generating test data
358
+ test_input_axes: str or None
359
+ the axes of the test input, for example 'YX' for a 2d image or 'ZYX' for a 3d volume
360
+ using None assumes that axes of test_input are the same as those of model
361
+ test_input_norm_axes: str
362
+ the axes of the test input which will be jointly normalized, for example 'ZYX' for all spatial dimensions ('Z' ignored for 2D input)
363
+ use 'ZYXC' to also jointly normalize channels (e.g. for RGB input images)
364
+ name: str
365
+ the name of this model (default: None)
366
+ if None, uses the (folder) name of the model (i.e. `model.name`)
367
+ mode: str
368
+ the export type for this model (default: "tensorflow_saved_model_bundle")
369
+ min_percentile: float
370
+ min percentile to be used for image normalization (default: 1.0)
371
+ max_percentile: float
372
+ max percentile to be used for image normalization (default: 99.8)
373
+ overwrite_spec_kwargs: dict or None
374
+ spec keywords that should be overloaded (default: None)
375
+ """
376
+ _, build_model, *_ = _import()
377
+ from .models import StarDist2D, StarDist3D
378
+ isinstance(model, (StarDist2D, StarDist3D)) or _raise(ValueError("not a valid model"))
379
+ 0 <= min_percentile < max_percentile <= 100 or _raise(ValueError("invalid percentile values"))
380
+
381
+ if name is None:
382
+ name = model.name
383
+ name = str(name)
384
+
385
+ outpath = Path(outpath)
386
+ if outpath.suffix == "":
387
+ outdir = outpath
388
+ zip_path = outdir / f"{name}.zip"
389
+ elif outpath.suffix == ".zip":
390
+ outdir = outpath.parent
391
+ zip_path = outpath
392
+ else:
393
+ raise ValueError(f"outpath has to be a folder or zip file, got {outpath}")
394
+ outdir.mkdir(exist_ok=True, parents=True)
395
+
396
+ with tempfile.TemporaryDirectory() as _tmp_dir:
397
+ tmp_dir = Path(_tmp_dir)
398
+ kwargs = _get_stardist_metadata(tmp_dir, model)
399
+ model_kwargs = _get_weights_and_model_metadata(tmp_dir, model, test_input, test_input_axes, test_input_norm_axes, mode,
400
+ min_percentile=min_percentile, max_percentile=max_percentile)
401
+ kwargs.update(model_kwargs)
402
+ if overwrite_spec_kwargs is not None:
403
+ kwargs.update(overwrite_spec_kwargs)
404
+
405
+ build_model(name=name, output_path=zip_path, add_deepimagej_config=(model.config.n_dim==2), root=tmp_dir, **kwargs)
406
+ print(f"\nbioimage.io model with name '{name}' exported to '{zip_path}'")
407
+
408
+
409
+ def import_bioimageio(source, outpath):
410
+ """Import stardist model from bioimage.io format, https://github.com/bioimage-io/spec-bioimage-io.
411
+
412
+ Load a model in bioimage.io format from the given `source` (e.g. path to zip file, URL)
413
+ and convert it to a regular stardist model, which will be saved in the folder `outpath`.
414
+
415
+ Parameters
416
+ ----------
417
+ source: str, Path
418
+ bioimage.io resource (e.g. path, URL)
419
+ outpath: str, Path
420
+ folder to save the stardist model (must not exist previously)
421
+
422
+ Returns
423
+ -------
424
+ StarDist2D or StarDist3D
425
+ stardist model loaded from `outpath`
426
+
427
+ """
428
+ import shutil, uuid
429
+ from csbdeep.utils import save_json
430
+ from .models import StarDist2D, StarDist3D
431
+ *_, bioimageio_core, _ = _import()
432
+
433
+ outpath = Path(outpath)
434
+ not outpath.exists() or _raise(FileExistsError(f"'{outpath}' already exists"))
435
+
436
+ with tempfile.TemporaryDirectory() as _tmp_dir:
437
+ tmp_dir = Path(_tmp_dir)
438
+ # download the full model content to a temporary folder
439
+ zip_path = tmp_dir / f"{str(uuid.uuid4())}.zip"
440
+ bioimageio_core.export_resource_package(source, output_path=zip_path)
441
+ with ZipFile(zip_path, "r") as zip_ref:
442
+ zip_ref.extractall(tmp_dir)
443
+ zip_path.unlink()
444
+ rdf_path = tmp_dir / "rdf.yaml"
445
+ biomodel = bioimageio_core.load_resource_description(rdf_path)
446
+
447
+ # read the stardist specific content
448
+ 'stardist' in biomodel.config or _raise(RuntimeError("bioimage.io model not compatible"))
449
+ config = biomodel.config['stardist']['config']
450
+ thresholds = biomodel.config['stardist']['thresholds']
451
+ weights = biomodel.config['stardist']['weights']
452
+
453
+ # make sure that the keras weights are in the attachments
454
+ weights_file = None
455
+ for f in biomodel.attachments.files:
456
+ if f.name == weights and f.exists():
457
+ weights_file = f
458
+ break
459
+ weights_file is not None or _raise(FileNotFoundError(f"couldn't find weights file '{weights}'"))
460
+
461
+ # save the config and threshold to json, and weights to hdf5 to enable loading as stardist model
462
+ # copy bioimageio files to separate sub-folder
463
+ outpath.mkdir(parents=True)
464
+ save_json(config, str(outpath / 'config.json'))
465
+ save_json(thresholds, str(outpath / 'thresholds.json'))
466
+ shutil.copy(str(weights_file), str(outpath / "weights_bioimageio.h5"))
467
+ shutil.copytree(str(tmp_dir), str(outpath / "bioimageio"))
468
+
469
+ model_class = (StarDist2D if config['n_dim'] == 2 else StarDist3D)
470
+ model = model_class(None, outpath.name, basedir=str(outpath.parent))
471
+
472
+ return model
stardist_pkg/geometry/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, print_function
2
+
3
+ # TODO: rethink naming for 2D/3D functions
4
+
5
+ from .geom2d import star_dist, relabel_image_stardist, ray_angles, dist_to_coord, polygons_to_label, polygons_to_label_coord
6
+
7
+ from .geom2d import _dist_to_coord_old, _polygons_to_label_old
8
+
9
+ #, dist_to_volume, dist_to_centroid
stardist_pkg/geometry/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (522 Bytes). View file
 
stardist_pkg/geometry/__pycache__/geom2d.cpython-37.pyc ADDED
Binary file (7.23 kB). View file
 
stardist_pkg/geometry/__pycache__/geom3d.cpython-37.pyc ADDED
Binary file (11 kB). View file
 
stardist_pkg/geometry/geom2d.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ import numpy as np
3
+ import warnings
4
+
5
+ from skimage.measure import regionprops
6
+ from skimage.draw import polygon
7
+ from csbdeep.utils import _raise
8
+
9
+ from ..utils import path_absolute, _is_power_of_2, _normalize_grid
10
+ from ..matching import _check_label_array
11
+ from stardist.lib.stardist2d import c_star_dist
12
+
13
+
14
+
15
+ def _ocl_star_dist(lbl, n_rays=32, grid=(1,1)):
16
+ from gputools import OCLProgram, OCLArray, OCLImage
17
+ (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError())
18
+ n_rays = int(n_rays)
19
+ # slicing with grid is done with tuple(slice(0, None, g) for g in grid)
20
+ res_shape = tuple((s-1)//g+1 for s, g in zip(lbl.shape, grid))
21
+
22
+ src = OCLImage.from_array(lbl.astype(np.uint16,copy=False))
23
+ dst = OCLArray.empty(res_shape+(n_rays,), dtype=np.float32)
24
+ program = OCLProgram(path_absolute("kernels/stardist2d.cl"), build_options=['-D', 'N_RAYS=%d' % n_rays])
25
+ program.run_kernel('star_dist', res_shape[::-1], None, dst.data, src, np.int32(grid[0]),np.int32(grid[1]))
26
+ return dst.get()
27
+
28
+
29
+ def _cpp_star_dist(lbl, n_rays=32, grid=(1,1)):
30
+ (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError())
31
+ return c_star_dist(lbl.astype(np.uint16,copy=False), np.int32(n_rays), np.int32(grid[0]),np.int32(grid[1]))
32
+
33
+
34
+ def _py_star_dist(a, n_rays=32, grid=(1,1)):
35
+ (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError())
36
+ if grid != (1,1):
37
+ raise NotImplementedError(grid)
38
+
39
+ n_rays = int(n_rays)
40
+ a = a.astype(np.uint16,copy=False)
41
+ dst = np.empty(a.shape+(n_rays,),np.float32)
42
+
43
+ for i in range(a.shape[0]):
44
+ for j in range(a.shape[1]):
45
+ value = a[i,j]
46
+ if value == 0:
47
+ dst[i,j] = 0
48
+ else:
49
+ st_rays = np.float32((2*np.pi) / n_rays)
50
+ for k in range(n_rays):
51
+ phi = np.float32(k*st_rays)
52
+ dy = np.cos(phi)
53
+ dx = np.sin(phi)
54
+ x, y = np.float32(0), np.float32(0)
55
+ while True:
56
+ x += dx
57
+ y += dy
58
+ ii = int(round(i+x))
59
+ jj = int(round(j+y))
60
+ if (ii < 0 or ii >= a.shape[0] or
61
+ jj < 0 or jj >= a.shape[1] or
62
+ value != a[ii,jj]):
63
+ # small correction as we overshoot the boundary
64
+ t_corr = 1-.5/max(np.abs(dx),np.abs(dy))
65
+ x -= t_corr*dx
66
+ y -= t_corr*dy
67
+ dist = np.sqrt(x**2+y**2)
68
+ dst[i,j,k] = dist
69
+ break
70
+ return dst
71
+
72
+
73
+ def star_dist(a, n_rays=32, grid=(1,1), mode='cpp'):
74
+ """'a' assumbed to be a label image with integer values that encode object ids. id 0 denotes background."""
75
+
76
+ n_rays >= 3 or _raise(ValueError("need 'n_rays' >= 3"))
77
+
78
+ if mode == 'python':
79
+ return _py_star_dist(a, n_rays, grid=grid)
80
+ elif mode == 'cpp':
81
+ return _cpp_star_dist(a, n_rays, grid=grid)
82
+ elif mode == 'opencl':
83
+ return _ocl_star_dist(a, n_rays, grid=grid)
84
+ else:
85
+ _raise(ValueError("Unknown mode %s" % mode))
86
+
87
+
88
+ def _dist_to_coord_old(rhos, grid=(1,1)):
89
+ """convert from polar to cartesian coordinates for a single image (3-D array) or multiple images (4-D array)"""
90
+
91
+ grid = _normalize_grid(grid,2)
92
+ is_single_image = rhos.ndim == 3
93
+ if is_single_image:
94
+ rhos = np.expand_dims(rhos,0)
95
+ assert rhos.ndim == 4
96
+
97
+ n_images,h,w,n_rays = rhos.shape
98
+ coord = np.empty((n_images,h,w,2,n_rays),dtype=rhos.dtype)
99
+
100
+ start = np.indices((h,w))
101
+ for i in range(2):
102
+ coord[...,i,:] = grid[i] * np.broadcast_to(start[i].reshape(1,h,w,1), (n_images,h,w,n_rays))
103
+
104
+ phis = ray_angles(n_rays).reshape(1,1,1,n_rays)
105
+
106
+ coord[...,0,:] += rhos * np.sin(phis) # row coordinate
107
+ coord[...,1,:] += rhos * np.cos(phis) # col coordinate
108
+
109
+ return coord[0] if is_single_image else coord
110
+
111
+
112
+ def _polygons_to_label_old(coord, prob, points, shape=None, thr=-np.inf):
113
+ sh = coord.shape[:2] if shape is None else shape
114
+ lbl = np.zeros(sh,np.int32)
115
+ # sort points with increasing probability
116
+ ind = np.argsort([ prob[p[0],p[1]] for p in points ])
117
+ points = points[ind]
118
+
119
+ i = 1
120
+ for p in points:
121
+ if prob[p[0],p[1]] < thr:
122
+ continue
123
+ rr,cc = polygon(coord[p[0],p[1],0], coord[p[0],p[1],1], sh)
124
+ lbl[rr,cc] = i
125
+ i += 1
126
+
127
+ return lbl
128
+
129
+
130
+ def dist_to_coord(dist, points, scale_dist=(1,1)):
131
+ """convert from polar to cartesian coordinates for a list of distances and center points
132
+ dist.shape = (n_polys, n_rays)
133
+ points.shape = (n_polys, 2)
134
+ len(scale_dist) = 2
135
+ return coord.shape = (n_polys,2,n_rays)
136
+ """
137
+ dist = np.asarray(dist)
138
+ points = np.asarray(points)
139
+ assert dist.ndim==2 and points.ndim==2 and len(dist)==len(points) \
140
+ and points.shape[1]==2 and len(scale_dist)==2
141
+ n_rays = dist.shape[1]
142
+ phis = ray_angles(n_rays)
143
+ coord = (dist[:,np.newaxis]*np.array([np.sin(phis),np.cos(phis)])).astype(np.float32)
144
+ coord *= np.asarray(scale_dist).reshape(1,2,1)
145
+ coord += points[...,np.newaxis]
146
+ return coord
147
+
148
+
149
+ def polygons_to_label_coord(coord, shape, labels=None):
150
+ """renders polygons to image of given shape
151
+
152
+ coord.shape = (n_polys, n_rays)
153
+ """
154
+ coord = np.asarray(coord)
155
+ if labels is None: labels = np.arange(len(coord))
156
+
157
+ _check_label_array(labels, "labels")
158
+ assert coord.ndim==3 and coord.shape[1]==2 and len(coord)==len(labels)
159
+
160
+ lbl = np.zeros(shape,np.int32)
161
+
162
+ for i,c in zip(labels,coord):
163
+ rr,cc = polygon(*c, shape)
164
+ lbl[rr,cc] = i+1
165
+
166
+ return lbl
167
+
168
+
169
+ def polygons_to_label(dist, points, shape, prob=None, thr=-np.inf, scale_dist=(1,1)):
170
+ """converts distances and center points to label image
171
+
172
+ dist.shape = (n_polys, n_rays)
173
+ points.shape = (n_polys, 2)
174
+
175
+ label ids will be consecutive and adhere to the order given
176
+ """
177
+ dist = np.asarray(dist)
178
+ points = np.asarray(points)
179
+ prob = np.inf*np.ones(len(points)) if prob is None else np.asarray(prob)
180
+
181
+ assert dist.ndim==2 and points.ndim==2 and len(dist)==len(points)
182
+ assert len(points)==len(prob) and points.shape[1]==2 and prob.ndim==1
183
+
184
+ n_rays = dist.shape[1]
185
+
186
+ ind = prob>thr
187
+ points = points[ind]
188
+ dist = dist[ind]
189
+ prob = prob[ind]
190
+
191
+ ind = np.argsort(prob, kind='stable')
192
+ points = points[ind]
193
+ dist = dist[ind]
194
+
195
+ coord = dist_to_coord(dist, points, scale_dist=scale_dist)
196
+
197
+ return polygons_to_label_coord(coord, shape=shape, labels=ind)
198
+
199
+
200
+ def relabel_image_stardist(lbl, n_rays, **kwargs):
201
+ """relabel each label region in `lbl` with its star representation"""
202
+ _check_label_array(lbl, "lbl")
203
+ if not lbl.ndim==2:
204
+ raise ValueError("lbl image should be 2 dimensional")
205
+ dist = star_dist(lbl, n_rays, **kwargs)
206
+ points = np.array(tuple(np.array(r.centroid).astype(int) for r in regionprops(lbl)))
207
+ dist = dist[tuple(points.T)]
208
+ return polygons_to_label(dist, points, shape=lbl.shape)
209
+
210
+
211
+ def ray_angles(n_rays=32):
212
+ return np.linspace(0,2*np.pi,n_rays,endpoint=False)
stardist_pkg/kernels/stardist2d.cl ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef M_PI
2
+ #define M_PI 3.141592653589793
3
+ #endif
4
+
5
+ __constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
6
+
7
+ inline float2 pol2cart(const float rho, const float phi) {
8
+ const float x = rho * cos(phi);
9
+ const float y = rho * sin(phi);
10
+ return (float2)(x,y);
11
+ }
12
+
13
+ __kernel void star_dist(__global float* dst, read_only image2d_t src, const int grid_y, const int grid_x) {
14
+
15
+ const int i = get_global_id(0), j = get_global_id(1);
16
+ const int Nx = get_global_size(0), Ny = get_global_size(1);
17
+ const float2 grid = (float2)(grid_x, grid_y);
18
+
19
+ const float2 origin = (float2)(i,j) * grid;
20
+ const int value = read_imageui(src,sampler,origin).x;
21
+
22
+ if (value == 0) {
23
+ // background pixel -> nothing to do, write all zeros
24
+ for (int k = 0; k < N_RAYS; k++) {
25
+ dst[k + i*N_RAYS + j*N_RAYS*Nx] = 0;
26
+ }
27
+ } else {
28
+ float st_rays = (2*M_PI) / N_RAYS; // step size for ray angles
29
+ // for all rays
30
+ for (int k = 0; k < N_RAYS; k++) {
31
+ const float phi = k*st_rays; // current ray angle phi
32
+ const float2 dir = pol2cart(1,phi); // small vector in direction of ray
33
+ float2 offset = 0; // offset vector to be added to origin
34
+ // find radius that leaves current object
35
+ while (1) {
36
+ offset += dir;
37
+ const int offset_value = read_imageui(src,sampler,round(origin+offset)).x;
38
+ if (offset_value != value) {
39
+ // small correction as we overshoot the boundary
40
+ const float t_corr = .5f/fmax(fabs(dir.x),fabs(dir.y));
41
+ offset += (t_corr-1.f)*dir;
42
+
43
+ const float dist = sqrt(offset.x*offset.x + offset.y*offset.y);
44
+ dst[k + i*N_RAYS + j*N_RAYS*Nx] = dist;
45
+ break;
46
+ }
47
+ }
48
+ }
49
+ }
50
+
51
+ }
stardist_pkg/kernels/stardist3d.cl ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #ifndef M_PI
2
+ #define M_PI 3.141592653589793
3
+ #endif
4
+
5
+ __constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
6
+
7
+ inline int round_to_int(float r) {
8
+ return (int)rint(r);
9
+ }
10
+
11
+
12
+ __kernel void stardist3d(read_only image3d_t lbl, __constant float * rays, __global float* dist, const int grid_z, const int grid_y, const int grid_x) {
13
+
14
+ const int i = get_global_id(0);
15
+ const int j = get_global_id(1);
16
+ const int k = get_global_id(2);
17
+
18
+ const int Nx = get_global_size(0);
19
+ const int Ny = get_global_size(1);
20
+ const int Nz = get_global_size(2);
21
+
22
+ const float4 grid = (float4)(grid_x, grid_y, grid_z, 1);
23
+ const float4 origin = (float4)(i,j,k,0) * grid;
24
+ const int value = read_imageui(lbl,sampler,origin).x;
25
+
26
+ if (value == 0) {
27
+ // background pixel -> nothing to do, write all zeros
28
+ for (int m = 0; m < N_RAYS; m++) {
29
+ dist[m + i*N_RAYS + j*N_RAYS*Nx+k*N_RAYS*Nx*Ny] = 0;
30
+ }
31
+
32
+ }
33
+ else {
34
+ for (int m = 0; m < N_RAYS; m++) {
35
+
36
+ const float4 dx = (float4)(rays[3*m+2],rays[3*m+1],rays[3*m],0);
37
+ // if ((i==Nx/2)&&(j==Ny/2)&(k==Nz/2)){
38
+ // printf("kernel: %.2f %.2f %.2f \n",dx.x,dx.y,dx.z);
39
+ // }
40
+ float4 x = (float4)(0,0,0,0);
41
+
42
+ // move along ray
43
+ while (1) {
44
+ x += dx;
45
+ // if ((i==10)&&(j==10)&(k==10)){
46
+ // printf("kernel run: %.2f %.2f %.2f value %d \n",x.x,x.y,x.z, read_imageui(lbl,sampler,origin+x).x);
47
+ // }
48
+
49
+ // to make it equivalent to the cpp version...
50
+ const float4 x_int = (float4)(round_to_int(x.x),
51
+ round_to_int(x.y),
52
+ round_to_int(x.z),
53
+ 0);
54
+
55
+ if (value != read_imageui(lbl,sampler,origin+x_int).x){
56
+
57
+ dist[m + i*N_RAYS + j*N_RAYS*Nx+k*N_RAYS*Nx*Ny] = length(x_int);
58
+ break;
59
+ }
60
+ }
61
+ }
62
+ }
63
+ }
stardist_pkg/matching.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from numba import jit
4
+ from tqdm import tqdm
5
+ from scipy.optimize import linear_sum_assignment
6
+ from skimage.measure import regionprops
7
+ from collections import namedtuple
8
+ from csbdeep.utils import _raise
9
+
10
+ matching_criteria = dict()
11
+
12
+
13
+ def label_are_sequential(y):
14
+ """ returns true if y has only sequential labels from 1... """
15
+ labels = np.unique(y)
16
+ return (set(labels)-{0}) == set(range(1,1+labels.max()))
17
+
18
+
19
+ def is_array_of_integers(y):
20
+ return isinstance(y,np.ndarray) and np.issubdtype(y.dtype, np.integer)
21
+
22
+
23
+ def _check_label_array(y, name=None, check_sequential=False):
24
+ err = ValueError("{label} must be an array of {integers}.".format(
25
+ label = 'labels' if name is None else name,
26
+ integers = ('sequential ' if check_sequential else '') + 'non-negative integers',
27
+ ))
28
+ is_array_of_integers(y) or _raise(err)
29
+ if len(y) == 0:
30
+ return True
31
+ if check_sequential:
32
+ label_are_sequential(y) or _raise(err)
33
+ else:
34
+ y.min() >= 0 or _raise(err)
35
+ return True
36
+
37
+
38
+ def label_overlap(x, y, check=True):
39
+ if check:
40
+ _check_label_array(x,'x',True)
41
+ _check_label_array(y,'y',True)
42
+ x.shape == y.shape or _raise(ValueError("x and y must have the same shape"))
43
+ return _label_overlap(x, y)
44
+
45
+ @jit(nopython=True)
46
+ def _label_overlap(x, y):
47
+ x = x.ravel()
48
+ y = y.ravel()
49
+ overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)
50
+ for i in range(len(x)):
51
+ overlap[x[i],y[i]] += 1
52
+ return overlap
53
+
54
+
55
+ def _safe_divide(x,y, eps=1e-10):
56
+ """computes a safe divide which returns 0 if y is zero"""
57
+ if np.isscalar(x) and np.isscalar(y):
58
+ return x/y if np.abs(y)>eps else 0.0
59
+ else:
60
+ out = np.zeros(np.broadcast(x,y).shape, np.float32)
61
+ np.divide(x,y, out=out, where=np.abs(y)>eps)
62
+ return out
63
+
64
+
65
+ def intersection_over_union(overlap):
66
+ _check_label_array(overlap,'overlap')
67
+ if np.sum(overlap) == 0:
68
+ return overlap
69
+ n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
70
+ n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
71
+ return _safe_divide(overlap, (n_pixels_pred + n_pixels_true - overlap))
72
+
73
+ matching_criteria['iou'] = intersection_over_union
74
+
75
+
76
+ def intersection_over_true(overlap):
77
+ _check_label_array(overlap,'overlap')
78
+ if np.sum(overlap) == 0:
79
+ return overlap
80
+ n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
81
+ return _safe_divide(overlap, n_pixels_true)
82
+
83
+ matching_criteria['iot'] = intersection_over_true
84
+
85
+
86
+ def intersection_over_pred(overlap):
87
+ _check_label_array(overlap,'overlap')
88
+ if np.sum(overlap) == 0:
89
+ return overlap
90
+ n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
91
+ return _safe_divide(overlap, n_pixels_pred)
92
+
93
+ matching_criteria['iop'] = intersection_over_pred
94
+
95
+
96
+ def precision(tp,fp,fn):
97
+ return tp/(tp+fp) if tp > 0 else 0
98
+ def recall(tp,fp,fn):
99
+ return tp/(tp+fn) if tp > 0 else 0
100
+ def accuracy(tp,fp,fn):
101
+ # also known as "average precision" (?)
102
+ # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation
103
+ return tp/(tp+fp+fn) if tp > 0 else 0
104
+ def f1(tp,fp,fn):
105
+ # also known as "dice coefficient"
106
+ return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0
107
+
108
+
109
+ def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):
110
+ """Calculate detection/instance segmentation metrics between ground truth and predicted label images.
111
+
112
+ Currently, the following metrics are implemented:
113
+
114
+ 'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'
115
+
116
+ Corresponding objects of y_true and y_pred are counted as true positives (tp), false positives (fp), and false negatives (fn)
117
+ whether their intersection over union (IoU) >= thresh (for criterion='iou', which can be changed)
118
+
119
+ * mean_matched_score is the mean IoUs of matched true positives
120
+
121
+ * mean_true_score is the mean IoUs of matched true positives but normalized by the total number of GT objects
122
+
123
+ * panoptic_quality defined as in Eq. 1 of Kirillov et al. "Panoptic Segmentation", CVPR 2019
124
+
125
+ Parameters
126
+ ----------
127
+ y_true: ndarray
128
+ ground truth label image (integer valued)
129
+ y_pred: ndarray
130
+ predicted label image (integer valued)
131
+ thresh: float
132
+ threshold for matching criterion (default 0.5)
133
+ criterion: string
134
+ matching criterion (default IoU)
135
+ report_matches: bool
136
+ if True, additionally calculate matched_pairs and matched_scores (note, that this returns even gt-pred pairs whose scores are below 'thresh')
137
+
138
+ Returns
139
+ -------
140
+ Matching object with different metrics as attributes
141
+
142
+ Examples
143
+ --------
144
+ >>> y_true = np.zeros((100,100), np.uint16)
145
+ >>> y_true[10:20,10:20] = 1
146
+ >>> y_pred = np.roll(y_true,5,axis = 0)
147
+
148
+ >>> stats = matching(y_true, y_pred)
149
+ >>> print(stats)
150
+ Matching(criterion='iou', thresh=0.5, fp=1, tp=0, fn=1, precision=0, recall=0, accuracy=0, f1=0, n_true=1, n_pred=1, mean_true_score=0.0, mean_matched_score=0.0, panoptic_quality=0.0)
151
+
152
+ """
153
+ _check_label_array(y_true,'y_true')
154
+ _check_label_array(y_pred,'y_pred')
155
+ y_true.shape == y_pred.shape or _raise(ValueError("y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes".format(y_true=y_true, y_pred=y_pred)))
156
+ criterion in matching_criteria or _raise(ValueError("Matching criterion '%s' not supported." % criterion))
157
+ if thresh is None: thresh = 0
158
+ thresh = float(thresh) if np.isscalar(thresh) else map(float,thresh)
159
+
160
+ y_true, _, map_rev_true = relabel_sequential(y_true)
161
+ y_pred, _, map_rev_pred = relabel_sequential(y_pred)
162
+
163
+ overlap = label_overlap(y_true, y_pred, check=False)
164
+ scores = matching_criteria[criterion](overlap)
165
+ assert 0 <= np.min(scores) <= np.max(scores) <= 1
166
+
167
+ # ignoring background
168
+ scores = scores[1:,1:]
169
+ n_true, n_pred = scores.shape
170
+ n_matched = min(n_true, n_pred)
171
+
172
+ def _single(thr):
173
+ # not_trivial = n_matched > 0 and np.any(scores >= thr)
174
+ not_trivial = n_matched > 0
175
+ if not_trivial:
176
+ # compute optimal matching with scores as tie-breaker
177
+ costs = -(scores >= thr).astype(float) - scores / (2*n_matched)
178
+ true_ind, pred_ind = linear_sum_assignment(costs)
179
+ assert n_matched == len(true_ind) == len(pred_ind)
180
+ match_ok = scores[true_ind,pred_ind] >= thr
181
+ tp = np.count_nonzero(match_ok)
182
+ else:
183
+ tp = 0
184
+ fp = n_pred - tp
185
+ fn = n_true - tp
186
+ # assert tp+fp == n_pred
187
+ # assert tp+fn == n_true
188
+
189
+ # the score sum over all matched objects (tp)
190
+ sum_matched_score = np.sum(scores[true_ind,pred_ind][match_ok]) if not_trivial else 0.0
191
+
192
+ # the score average over all matched objects (tp)
193
+ mean_matched_score = _safe_divide(sum_matched_score, tp)
194
+ # the score average over all gt/true objects
195
+ mean_true_score = _safe_divide(sum_matched_score, n_true)
196
+ panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)
197
+
198
+ stats_dict = dict (
199
+ criterion = criterion,
200
+ thresh = thr,
201
+ fp = fp,
202
+ tp = tp,
203
+ fn = fn,
204
+ precision = precision(tp,fp,fn),
205
+ recall = recall(tp,fp,fn),
206
+ accuracy = accuracy(tp,fp,fn),
207
+ f1 = f1(tp,fp,fn),
208
+ n_true = n_true,
209
+ n_pred = n_pred,
210
+ mean_true_score = mean_true_score,
211
+ mean_matched_score = mean_matched_score,
212
+ panoptic_quality = panoptic_quality,
213
+ )
214
+ if bool(report_matches):
215
+ if not_trivial:
216
+ stats_dict.update (
217
+ # int() to be json serializable
218
+ matched_pairs = tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)),
219
+ matched_scores = tuple(scores[true_ind,pred_ind]),
220
+ matched_tps = tuple(map(int,np.flatnonzero(match_ok))),
221
+ )
222
+ else:
223
+ stats_dict.update (
224
+ matched_pairs = (),
225
+ matched_scores = (),
226
+ matched_tps = (),
227
+ )
228
+ return namedtuple('Matching',stats_dict.keys())(*stats_dict.values())
229
+
230
+ return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh))
231
+
232
+
233
+
234
+ def matching_dataset(y_true, y_pred, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):
235
+ """matching metrics for list of images, see `stardist.matching.matching`
236
+ """
237
+ len(y_true) == len(y_pred) or _raise(ValueError("y_true and y_pred must have the same length."))
238
+ return matching_dataset_lazy (
239
+ tuple(zip(y_true,y_pred)), thresh=thresh, criterion=criterion, by_image=by_image, show_progress=show_progress, parallel=parallel,
240
+ )
241
+
242
+
243
+
244
+ def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):
245
+
246
+ expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'))
247
+
248
+ single_thresh = False
249
+ if np.isscalar(thresh):
250
+ single_thresh = True
251
+ thresh = (thresh,)
252
+
253
+ tqdm_kwargs = {}
254
+ tqdm_kwargs['disable'] = not bool(show_progress)
255
+ if int(show_progress) > 1:
256
+ tqdm_kwargs['total'] = int(show_progress)
257
+
258
+ # compute matching stats for every pair of label images
259
+ if parallel:
260
+ from concurrent.futures import ThreadPoolExecutor
261
+ fn = lambda pair: matching(*pair, thresh=thresh, criterion=criterion, report_matches=False)
262
+ with ThreadPoolExecutor() as pool:
263
+ stats_all = tuple(pool.map(fn, tqdm(y_gen,**tqdm_kwargs)))
264
+ else:
265
+ stats_all = tuple (
266
+ matching(y_t, y_p, thresh=thresh, criterion=criterion, report_matches=False)
267
+ for y_t,y_p in tqdm(y_gen,**tqdm_kwargs)
268
+ )
269
+
270
+ # accumulate results over all images for each threshold separately
271
+ n_images, n_threshs = len(stats_all), len(thresh)
272
+ accumulate = [{} for _ in range(n_threshs)]
273
+ for stats in stats_all:
274
+ for i,s in enumerate(stats):
275
+ acc = accumulate[i]
276
+ for k,v in s._asdict().items():
277
+ if k == 'mean_true_score' and not bool(by_image):
278
+ # convert mean_true_score to "sum_matched_score"
279
+ acc[k] = acc.setdefault(k,0) + v * s.n_true
280
+ else:
281
+ try:
282
+ acc[k] = acc.setdefault(k,0) + v
283
+ except TypeError:
284
+ pass
285
+
286
+ # normalize/compute 'precision', 'recall', 'accuracy', 'f1'
287
+ for thr,acc in zip(thresh,accumulate):
288
+ set(acc.keys()) == expected_keys or _raise(ValueError("unexpected keys"))
289
+ acc['criterion'] = criterion
290
+ acc['thresh'] = thr
291
+ acc['by_image'] = bool(by_image)
292
+ if bool(by_image):
293
+ for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):
294
+ acc[k] /= n_images
295
+ else:
296
+ tp, fp, fn, n_true = acc['tp'], acc['fp'], acc['fn'], acc['n_true']
297
+ sum_matched_score = acc['mean_true_score']
298
+
299
+ mean_matched_score = _safe_divide(sum_matched_score, tp)
300
+ mean_true_score = _safe_divide(sum_matched_score, n_true)
301
+ panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)
302
+
303
+ acc.update(
304
+ precision = precision(tp,fp,fn),
305
+ recall = recall(tp,fp,fn),
306
+ accuracy = accuracy(tp,fp,fn),
307
+ f1 = f1(tp,fp,fn),
308
+ mean_true_score = mean_true_score,
309
+ mean_matched_score = mean_matched_score,
310
+ panoptic_quality = panoptic_quality,
311
+ )
312
+
313
+ accumulate = tuple(namedtuple('DatasetMatching',acc.keys())(*acc.values()) for acc in accumulate)
314
+ return accumulate[0] if single_thresh else accumulate
315
+
316
+
317
+
318
+ # copied from scikit-image master for now (remove when part of a release)
319
+ def relabel_sequential(label_field, offset=1):
320
+ """Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}.
321
+
322
+ This function also returns the forward map (mapping the original labels to
323
+ the reduced labels) and the inverse map (mapping the reduced labels back
324
+ to the original ones).
325
+
326
+ Parameters
327
+ ----------
328
+ label_field : numpy array of int, arbitrary shape
329
+ An array of labels, which must be non-negative integers.
330
+ offset : int, optional
331
+ The return labels will start at `offset`, which should be
332
+ strictly positive.
333
+
334
+ Returns
335
+ -------
336
+ relabeled : numpy array of int, same shape as `label_field`
337
+ The input label field with labels mapped to
338
+ {offset, ..., number_of_labels + offset - 1}.
339
+ The data type will be the same as `label_field`, except when
340
+ offset + number_of_labels causes overflow of the current data type.
341
+ forward_map : numpy array of int, shape ``(label_field.max() + 1,)``
342
+ The map from the original label space to the returned label
343
+ space. Can be used to re-apply the same mapping. See examples
344
+ for usage. The data type will be the same as `relabeled`.
345
+ inverse_map : 1D numpy array of int, of length offset + number of labels
346
+ The map from the new label space to the original space. This
347
+ can be used to reconstruct the original label field from the
348
+ relabeled one. The data type will be the same as `relabeled`.
349
+
350
+ Notes
351
+ -----
352
+ The label 0 is assumed to denote the background and is never remapped.
353
+
354
+ The forward map can be extremely big for some inputs, since its
355
+ length is given by the maximum of the label field. However, in most
356
+ situations, ``label_field.max()`` is much smaller than
357
+ ``label_field.size``, and in these cases the forward map is
358
+ guaranteed to be smaller than either the input or output images.
359
+
360
+ Examples
361
+ --------
362
+ >>> from skimage.segmentation import relabel_sequential
363
+ >>> label_field = np.array([1, 1, 5, 5, 8, 99, 42])
364
+ >>> relab, fw, inv = relabel_sequential(label_field)
365
+ >>> relab
366
+ array([1, 1, 2, 2, 3, 5, 4])
367
+ >>> fw
368
+ array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
369
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,
370
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
371
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
372
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5])
373
+ >>> inv
374
+ array([ 0, 1, 5, 8, 42, 99])
375
+ >>> (fw[label_field] == relab).all()
376
+ True
377
+ >>> (inv[relab] == label_field).all()
378
+ True
379
+ >>> relab, fw, inv = relabel_sequential(label_field, offset=5)
380
+ >>> relab
381
+ array([5, 5, 6, 6, 7, 9, 8])
382
+ """
383
+ offset = int(offset)
384
+ if offset <= 0:
385
+ raise ValueError("Offset must be strictly positive.")
386
+ if np.min(label_field) < 0:
387
+ raise ValueError("Cannot relabel array that contains negative values.")
388
+ max_label = int(label_field.max()) # Ensure max_label is an integer
389
+ if not np.issubdtype(label_field.dtype, np.integer):
390
+ new_type = np.min_scalar_type(max_label)
391
+ label_field = label_field.astype(new_type)
392
+ labels = np.unique(label_field)
393
+ labels0 = labels[labels != 0]
394
+ new_max_label = offset - 1 + len(labels0)
395
+ new_labels0 = np.arange(offset, new_max_label + 1)
396
+ output_type = label_field.dtype
397
+ required_type = np.min_scalar_type(new_max_label)
398
+ if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize:
399
+ output_type = required_type
400
+ forward_map = np.zeros(max_label + 1, dtype=output_type)
401
+ forward_map[labels0] = new_labels0
402
+ inverse_map = np.zeros(new_max_label + 1, dtype=output_type)
403
+ inverse_map[offset:] = labels0
404
+ relabeled = forward_map[label_field]
405
+ return relabeled, forward_map, inverse_map
406
+
407
+
408
+
409
+ def group_matching_labels(ys, thresh=1e-10, criterion='iou'):
410
+ """
411
+ Group matching objects (i.e. assign the same label id) in a
412
+ list of label images (e.g. consecutive frames of a time-lapse).
413
+
414
+ Uses function `matching` (with provided `criterion` and `thresh`) to
415
+ iteratively/greedily match and group objects/labels in consecutive images of `ys`.
416
+ To that end, matching objects are grouped together by assigning the same label id,
417
+ whereas unmatched objects are assigned a new label id.
418
+ At the end of this process, each label group will have been assigned a unique id.
419
+
420
+ Note that the label images `ys` will not be modified. Instead, they will initially
421
+ be duplicated and converted to data type `np.int32` before objects are grouped and the result
422
+ is returned. (Note that `np.int32` limits the number of label groups to at most 2147483647.)
423
+
424
+ Example
425
+ -------
426
+ import numpy as np
427
+ from stardist.data import test_image_nuclei_2d
428
+ from stardist.matching import group_matching_labels
429
+
430
+ _y = test_image_nuclei_2d(return_mask=True)[1]
431
+ labels = np.stack([_y, 2*np.roll(_y,10)], axis=0)
432
+
433
+ labels_new = group_matching_labels(labels)
434
+
435
+ Parameters
436
+ ----------
437
+ ys : np.ndarray or list/tuple of np.ndarray
438
+ list/array of integer labels (2D or 3D)
439
+
440
+ """
441
+ # check 'ys' without making a copy
442
+ len(ys) > 1 or _raise(ValueError("'ys' must have 2 or more entries"))
443
+ if isinstance(ys, np.ndarray):
444
+ _check_label_array(ys, 'ys')
445
+ ys.ndim > 1 or _raise(ValueError("'ys' must be at least 2-dimensional"))
446
+ ys_grouped = np.empty_like(ys, dtype=np.int32)
447
+ else:
448
+ all(_check_label_array(y, 'ys') for y in ys) or _raise(ValueError("'ys' must be a list of label images"))
449
+ all(y.shape==ys[0].shape for y in ys) or _raise(ValueError("all label images must have the same shape"))
450
+ ys_grouped = np.empty((len(ys),)+ys[0].shape, dtype=np.int32)
451
+
452
+ def _match_single(y_prev, y, next_id):
453
+ y = y.astype(np.int32, copy=False)
454
+ res = matching(y_prev, y, report_matches=True, thresh=thresh, criterion=criterion)
455
+ # relabel dict (for matching labels) that maps label ids from y -> y_prev
456
+ relabel = dict(reversed(res.matched_pairs[i]) for i in res.matched_tps)
457
+ y_grouped = np.zeros_like(y)
458
+ for r in regionprops(y):
459
+ m = (y[r.slice] == r.label)
460
+ if r.label in relabel:
461
+ y_grouped[r.slice][m] = relabel[r.label]
462
+ else:
463
+ y_grouped[r.slice][m] = next_id
464
+ next_id += 1
465
+ return y_grouped, next_id
466
+
467
+ ys_grouped[0] = ys[0]
468
+ next_id = ys_grouped[0].max() + 1
469
+ for i in range(len(ys)-1):
470
+ ys_grouped[i+1], next_id = _match_single(ys_grouped[i], ys[i+1], next_id)
471
+ return ys_grouped
472
+
473
+
474
+
475
+ def _shuffle_labels(y):
476
+ _check_label_array(y, 'y')
477
+ y2 = np.zeros_like(y)
478
+ ids = tuple(set(np.unique(y)) - {0})
479
+ relabel = dict(zip(ids,np.random.permutation(ids)))
480
+ for r in regionprops(y):
481
+ m = (y[r.slice] == r.label)
482
+ y2[r.slice][m] = relabel[r.label]
483
+ return y2
stardist_pkg/models/__init__.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import absolute_import, print_function
2
+
3
+ from .model2d import Config2D, StarDist2D, StarDistData2D
4
+
5
+ from csbdeep.utils import backend_channels_last
6
+ from csbdeep.utils.tf import keras_import
7
+ K = keras_import('backend')
8
+ if not backend_channels_last():
9
+ raise NotImplementedError(
10
+ "Keras is configured to use the '%s' image data format, which is currently not supported. "
11
+ "Please change it to use 'channels_last' instead: "
12
+ "https://keras.io/getting-started/faq/#where-is-the-keras-configuration-file-stored" % K.image_data_format()
13
+ )
14
+ del backend_channels_last, K
15
+
16
+ from csbdeep.models import register_model, register_aliases, clear_models_and_aliases
17
+ # register pre-trained models and aliases (TODO: replace with updatable solution)
18
+ clear_models_and_aliases(StarDist2D, StarDist3D)
19
+ register_model(StarDist2D, '2D_versatile_fluo', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_fluo.zip', '8db40dacb5a1311b8d2c447ad934fb8a')
20
+ register_model(StarDist2D, '2D_versatile_he', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_he.zip', 'bf34cb3c0e5b3435971e18d66778a4ec')
21
+ register_model(StarDist2D, '2D_paper_dsb2018', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_paper_dsb2018.zip', '6287bf283f85c058ec3e7094b41039b5')
22
+ register_model(StarDist2D, '2D_demo', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_demo.zip', '31f70402f58c50dd231ec31b4375ea2c')
23
+
24
+ register_aliases(StarDist2D, '2D_paper_dsb2018', 'DSB 2018 (from StarDist 2D paper)')
25
+ register_aliases(StarDist2D, '2D_versatile_fluo', 'Versatile (fluorescent nuclei)')
26
+ register_aliases(StarDist2D, '2D_versatile_he', 'Versatile (H&E nuclei)')
27
+ del register_model, register_aliases, clear_models_and_aliases
stardist_pkg/models/base.py ADDED
@@ -0,0 +1,1196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+
3
+ import numpy as np
4
+ import sys
5
+ import warnings
6
+ import math
7
+ from tqdm import tqdm
8
+ from collections import namedtuple
9
+ from pathlib import Path
10
+ import threading
11
+ import functools
12
+ import scipy.ndimage as ndi
13
+ import numbers
14
+
15
+ from csbdeep.models.base_model import BaseModel
16
+ from csbdeep.utils.tf import export_SavedModel, keras_import, IS_TF_1, CARETensorBoard
17
+
18
+ import tensorflow as tf
19
+ K = keras_import('backend')
20
+ Sequence = keras_import('utils', 'Sequence')
21
+ Adam = keras_import('optimizers', 'Adam')
22
+ ReduceLROnPlateau, TensorBoard = keras_import('callbacks', 'ReduceLROnPlateau', 'TensorBoard')
23
+
24
+ from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict, load_json, save_json
25
+ from csbdeep.internals.predict import tile_iterator, total_n_tiles
26
+ from csbdeep.internals.train import RollingSequence
27
+ from csbdeep.data import Resizer
28
+
29
+ from ..sample_patches import get_valid_inds
30
+ from ..nms import _ind_prob_thresh
31
+ from ..utils import _is_power_of_2, _is_floatarray, optimize_threshold
32
+
33
+ # TODO: helper function to check if receptive field of cnn is sufficient for object sizes in GT
34
+
35
+ def generic_masked_loss(mask, loss, weights=1, norm_by_mask=True, reg_weight=0, reg_penalty=K.abs):
36
+ def _loss(y_true, y_pred):
37
+ actual_loss = K.mean(mask * weights * loss(y_true, y_pred), axis=-1)
38
+ norm_mask = (K.mean(mask) + K.epsilon()) if norm_by_mask else 1
39
+ if reg_weight > 0:
40
+ reg_loss = K.mean((1-mask) * reg_penalty(y_pred), axis=-1)
41
+ return actual_loss / norm_mask + reg_weight * reg_loss
42
+ else:
43
+ return actual_loss / norm_mask
44
+ return _loss
45
+
46
+ def masked_loss(mask, penalty, reg_weight, norm_by_mask):
47
+ loss = lambda y_true, y_pred: penalty(y_true - y_pred)
48
+ return generic_masked_loss(mask, loss, reg_weight=reg_weight, norm_by_mask=norm_by_mask)
49
+
50
+ # TODO: should we use norm_by_mask=True in the loss or only in a metric?
51
+ # previous 2D behavior was norm_by_mask=False
52
+ # same question for reg_weight? use 1e-4 (as in 3D) or 0 (as in 2D)?
53
+
54
+ def masked_loss_mae(mask, reg_weight=0, norm_by_mask=True):
55
+ return masked_loss(mask, K.abs, reg_weight=reg_weight, norm_by_mask=norm_by_mask)
56
+
57
+ def masked_loss_mse(mask, reg_weight=0, norm_by_mask=True):
58
+ return masked_loss(mask, K.square, reg_weight=reg_weight, norm_by_mask=norm_by_mask)
59
+
60
+ def masked_metric_mae(mask):
61
+ def relevant_mae(y_true, y_pred):
62
+ return masked_loss(mask, K.abs, reg_weight=0, norm_by_mask=True)(y_true, y_pred)
63
+ return relevant_mae
64
+
65
+ def masked_metric_mse(mask):
66
+ def relevant_mse(y_true, y_pred):
67
+ return masked_loss(mask, K.square, reg_weight=0, norm_by_mask=True)(y_true, y_pred)
68
+ return relevant_mse
69
+
70
+ def kld(y_true, y_pred):
71
+ y_true = K.clip(y_true, K.epsilon(), 1)
72
+ y_pred = K.clip(y_pred, K.epsilon(), 1)
73
+ return K.mean(K.binary_crossentropy(y_true, y_pred) - K.binary_crossentropy(y_true, y_true), axis=-1)
74
+
75
+
76
+ def masked_loss_iou(mask, reg_weight=0, norm_by_mask=True):
77
+ def iou_loss(y_true, y_pred):
78
+ axis = -1 if backend_channels_last() else 1
79
+ # y_pred can be negative (since not constrained) -> 'inter' can be very large for y_pred << 0
80
+ # - clipping y_pred values at 0 can lead to vanishing gradients
81
+ # - 'K.sign(y_pred)' term fixes issue by enforcing that y_pred values >= 0 always lead to larger 'inter' (lower loss)
82
+ inter = K.mean(K.sign(y_pred)*K.square(K.minimum(y_true,y_pred)), axis=axis)
83
+ union = K.mean(K.square(K.maximum(y_true,y_pred)), axis=axis)
84
+ iou = inter/(union+K.epsilon())
85
+ iou = K.expand_dims(iou,axis)
86
+ loss = 1. - iou # + 0.005*K.abs(y_true-y_pred)
87
+ return loss
88
+ return generic_masked_loss(mask, iou_loss, reg_weight=reg_weight, norm_by_mask=norm_by_mask)
89
+
90
+ def masked_metric_iou(mask, reg_weight=0, norm_by_mask=True):
91
+ def iou_metric(y_true, y_pred):
92
+ axis = -1 if backend_channels_last() else 1
93
+ y_pred = K.maximum(0., y_pred)
94
+ inter = K.mean(K.square(K.minimum(y_true,y_pred)), axis=axis)
95
+ union = K.mean(K.square(K.maximum(y_true,y_pred)), axis=axis)
96
+ iou = inter/(union+K.epsilon())
97
+ loss = K.expand_dims(iou,axis)
98
+ return loss
99
+ return generic_masked_loss(mask, iou_metric, reg_weight=reg_weight, norm_by_mask=norm_by_mask)
100
+
101
+
102
+ def weighted_categorical_crossentropy(weights, ndim):
103
+ """ ndim = (2,3) """
104
+
105
+ axis = -1 if backend_channels_last() else 1
106
+ shape = [1]*(ndim+2)
107
+ shape[axis] = len(weights)
108
+ weights = np.broadcast_to(weights, shape)
109
+ weights = K.constant(weights)
110
+
111
+ def weighted_cce(y_true, y_pred):
112
+ # ignore pixels that have y_true (prob_class) < 0
113
+ mask = K.cast(y_true>=0, K.floatx())
114
+ y_pred /= K.sum(y_pred+K.epsilon(), axis=axis, keepdims=True)
115
+ y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
116
+ loss = - K.sum(weights*mask*y_true*K.log(y_pred), axis = axis)
117
+ return loss
118
+
119
+ return weighted_cce
120
+
121
+
122
+ class StarDistDataBase(RollingSequence):
123
+
124
+ def __init__(self, X, Y, n_rays, grid, batch_size, patch_size, length,
125
+ n_classes=None, classes=None,
126
+ use_gpu=False, sample_ind_cache=True, maxfilter_patch_size=None, augmenter=None, foreground_prob=0):
127
+
128
+ super().__init__(data_size=len(X), batch_size=batch_size, length=length, shuffle=True)
129
+
130
+ if isinstance(X, (np.ndarray, tuple, list)):
131
+ X = [x.astype(np.float32, copy=False) for x in X]
132
+
133
+ # sanity checks
134
+ len(X)==len(Y) and len(X)>0 or _raise(ValueError("X and Y can't be empty and must have same length"))
135
+
136
+ if classes is None:
137
+ # set classes to None for all images (i.e. defaults to every object instance assigned the same class)
138
+ classes = (None,)*len(X)
139
+ else:
140
+ n_classes is not None or warnings.warn("Ignoring classes since n_classes is None")
141
+
142
+ len(classes)==len(X) or _raise(ValueError("X and classes must have same length"))
143
+
144
+ self.n_classes, self.classes = n_classes, classes
145
+
146
+ nD = len(patch_size)
147
+ assert nD in (2,3)
148
+ x_ndim = X[0].ndim
149
+ assert x_ndim in (nD,nD+1)
150
+
151
+ if isinstance(X, (np.ndarray, tuple, list)) and \
152
+ isinstance(Y, (np.ndarray, tuple, list)):
153
+ all(y.ndim==nD and x.ndim==x_ndim and x.shape[:nD]==y.shape for x,y in zip(X,Y)) or _raise(ValueError("images and masks should have corresponding shapes/dimensions"))
154
+ all(x.shape[:nD]>=tuple(patch_size) for x in X) or _raise(ValueError("Some images are too small for given patch_size {patch_size}".format(patch_size=patch_size)))
155
+
156
+ if x_ndim == nD:
157
+ self.n_channel = None
158
+ else:
159
+ self.n_channel = X[0].shape[-1]
160
+ if isinstance(X, (np.ndarray, tuple, list)):
161
+ assert all(x.shape[-1]==self.n_channel for x in X)
162
+
163
+ assert 0 <= foreground_prob <= 1
164
+
165
+ self.X, self.Y = X, Y
166
+ # self.batch_size = batch_size
167
+ self.n_rays = n_rays
168
+ self.patch_size = patch_size
169
+ self.ss_grid = (slice(None),) + tuple(slice(0, None, g) for g in grid)
170
+ self.grid = tuple(grid)
171
+ self.use_gpu = bool(use_gpu)
172
+ if augmenter is None:
173
+ augmenter = lambda *args: args
174
+ callable(augmenter) or _raise(ValueError("augmenter must be None or callable"))
175
+ self.augmenter = augmenter
176
+ self.foreground_prob = foreground_prob
177
+
178
+ if self.use_gpu:
179
+ from gputools import max_filter
180
+ self.max_filter = lambda y, patch_size: max_filter(y.astype(np.float32), patch_size)
181
+ else:
182
+ from scipy.ndimage.filters import maximum_filter
183
+ self.max_filter = lambda y, patch_size: maximum_filter(y, patch_size, mode='constant')
184
+
185
+ self.maxfilter_patch_size = maxfilter_patch_size if maxfilter_patch_size is not None else self.patch_size
186
+
187
+ self.sample_ind_cache = sample_ind_cache
188
+ self._ind_cache_fg = {}
189
+ self._ind_cache_all = {}
190
+ self.lock = threading.Lock()
191
+
192
+
193
+ def get_valid_inds(self, k, foreground_prob=None):
194
+ if foreground_prob is None:
195
+ foreground_prob = self.foreground_prob
196
+ foreground_only = np.random.uniform() < foreground_prob
197
+ _ind_cache = self._ind_cache_fg if foreground_only else self._ind_cache_all
198
+ if k in _ind_cache:
199
+ inds = _ind_cache[k]
200
+ else:
201
+ patch_filter = (lambda y,p: self.max_filter(y, self.maxfilter_patch_size) > 0) if foreground_only else None
202
+ inds = get_valid_inds(self.Y[k], self.patch_size, patch_filter=patch_filter)
203
+ if self.sample_ind_cache:
204
+ with self.lock:
205
+ _ind_cache[k] = inds
206
+ if foreground_only and len(inds[0])==0:
207
+ # no foreground pixels available
208
+ return self.get_valid_inds(k, foreground_prob=0)
209
+ return inds
210
+
211
+
212
+ def channels_as_tuple(self, x):
213
+ if self.n_channel is None:
214
+ return (x,)
215
+ else:
216
+ return tuple(x[...,i] for i in range(self.n_channel))
217
+
218
+
219
+
220
+ class StarDistBase(BaseModel):
221
+
222
+ def __init__(self, config, name=None, basedir='.'):
223
+ super().__init__(config=config, name=name, basedir=basedir)
224
+ threshs = dict(prob=None, nms=None)
225
+ if basedir is not None:
226
+ try:
227
+ threshs = load_json(str(self.logdir / 'thresholds.json'))
228
+ print("Loading thresholds from 'thresholds.json'.")
229
+ if threshs.get('prob') is None or not (0 < threshs.get('prob') < 1):
230
+ print("- Invalid 'prob' threshold (%s), using default value." % str(threshs.get('prob')))
231
+ threshs['prob'] = None
232
+ if threshs.get('nms') is None or not (0 < threshs.get('nms') < 1):
233
+ print("- Invalid 'nms' threshold (%s), using default value." % str(threshs.get('nms')))
234
+ threshs['nms'] = None
235
+ except FileNotFoundError:
236
+ if config is None and len(tuple(self.logdir.glob('*.h5'))) > 0:
237
+ print("Couldn't load thresholds from 'thresholds.json', using default values. "
238
+ "(Call 'optimize_thresholds' to change that.)")
239
+
240
+ self.thresholds = dict (
241
+ prob = 0.5 if threshs['prob'] is None else threshs['prob'],
242
+ nms = 0.4 if threshs['nms'] is None else threshs['nms'],
243
+ )
244
+ print("Using default values: prob_thresh={prob:g}, nms_thresh={nms:g}.".format(prob=self.thresholds.prob, nms=self.thresholds.nms))
245
+
246
+
247
+ @property
248
+ def thresholds(self):
249
+ return self._thresholds
250
+
251
+ def _is_multiclass(self):
252
+ return (self.config.n_classes is not None)
253
+
254
+ def _parse_classes_arg(self, classes, length):
255
+ """ creates a proper classes tuple from different possible "classes" arguments in model.train()
256
+
257
+ classes can be
258
+ "auto" -> all objects will be assigned to the first foreground class (unless n_classes is None)
259
+ single integer -> all objects will be assigned that class
260
+ tuple, list, ndarray -> do nothing (needs to be of given length)
261
+
262
+ returns a tuple of given length
263
+ """
264
+ if isinstance(classes, str):
265
+ classes == "auto" or _raise(ValueError(f"classes = '{classes}': only 'auto' supported as string argument for classes"))
266
+ if self.config.n_classes is None:
267
+ classes = None
268
+ elif self.config.n_classes == 1:
269
+ classes = (1,)*length
270
+ else:
271
+ raise ValueError("using classes = 'auto' for n_classes > 1 not supported")
272
+ elif isinstance(classes, (tuple, list, np.ndarray)):
273
+ len(classes) == length or _raise(ValueError(f"len(classes) should be {length}!"))
274
+ else:
275
+ raise ValueError("classes should either be 'auto' or a list of scalars/label dicts")
276
+ return classes
277
+
278
+ @thresholds.setter
279
+ def thresholds(self, d):
280
+ self._thresholds = namedtuple('Thresholds',d.keys())(*d.values())
281
+
282
+
283
+ def prepare_for_training(self, optimizer=None):
284
+ """Prepare for neural network training.
285
+
286
+ Compiles the model and creates
287
+ `Keras Callbacks <https://keras.io/callbacks/>`_ to be used for training.
288
+
289
+ Note that this method will be implicitly called once by :func:`train`
290
+ (with default arguments) if not done so explicitly beforehand.
291
+
292
+ Parameters
293
+ ----------
294
+ optimizer : obj or None
295
+ Instance of a `Keras Optimizer <https://keras.io/optimizers/>`_ to be used for training.
296
+ If ``None`` (default), uses ``Adam`` with the learning rate specified in ``config``.
297
+
298
+ """
299
+ if optimizer is None:
300
+ optimizer = Adam(self.config.train_learning_rate)
301
+
302
+ masked_dist_loss = {'mse': masked_loss_mse,
303
+ 'mae': masked_loss_mae,
304
+ 'iou': masked_loss_iou,
305
+ }[self.config.train_dist_loss]
306
+ prob_loss = 'binary_crossentropy'
307
+
308
+
309
+ def split_dist_true_mask(dist_true_mask):
310
+ return tf.split(dist_true_mask, num_or_size_splits=[self.config.n_rays,-1], axis=-1)
311
+
312
+ def dist_loss(dist_true_mask, dist_pred):
313
+ dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
314
+ return masked_dist_loss(dist_mask, reg_weight=self.config.train_background_reg)(dist_true, dist_pred)
315
+
316
+ def dist_iou_metric(dist_true_mask, dist_pred):
317
+ dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
318
+ return masked_metric_iou(dist_mask, reg_weight=0)(dist_true, dist_pred)
319
+
320
+ def relevant_mae(dist_true_mask, dist_pred):
321
+ dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
322
+ return masked_metric_mae(dist_mask)(dist_true, dist_pred)
323
+
324
+ def relevant_mse(dist_true_mask, dist_pred):
325
+ dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
326
+ return masked_metric_mse(dist_mask)(dist_true, dist_pred)
327
+
328
+
329
+ if self._is_multiclass():
330
+ prob_class_loss = weighted_categorical_crossentropy(self.config.train_class_weights, ndim=self.config.n_dim)
331
+ loss = [prob_loss, dist_loss, prob_class_loss]
332
+ else:
333
+ loss = [prob_loss, dist_loss]
334
+
335
+ self.keras_model.compile(optimizer, loss = loss,
336
+ loss_weights = list(self.config.train_loss_weights),
337
+ metrics = {'prob': kld,
338
+ 'dist': [relevant_mae, relevant_mse, dist_iou_metric]})
339
+
340
+ self.callbacks = []
341
+ if self.basedir is not None:
342
+ self.callbacks += self._checkpoint_callbacks()
343
+
344
+ if self.config.train_tensorboard:
345
+ if IS_TF_1:
346
+ self.callbacks.append(CARETensorBoard(log_dir=str(self.logdir), prefix_with_timestamp=False, n_images=3, write_images=True, prob_out=False))
347
+ else:
348
+ self.callbacks.append(TensorBoard(log_dir=str(self.logdir/'logs'), write_graph=False, profile_batch=0))
349
+
350
+ if self.config.train_reduce_lr is not None:
351
+ rlrop_params = self.config.train_reduce_lr
352
+ if 'verbose' not in rlrop_params:
353
+ rlrop_params['verbose'] = True
354
+ # TF2: add as first callback to put 'lr' in the logs for TensorBoard
355
+ self.callbacks.insert(0,ReduceLROnPlateau(**rlrop_params))
356
+
357
+ self._model_prepared = True
358
+
359
+
360
+ def _predict_setup(self, img, axes, normalizer, n_tiles, show_tile_progress, predict_kwargs):
361
+ """ Shared setup code between `predict` and `predict_sparse` """
362
+ if n_tiles is None:
363
+ n_tiles = [1]*img.ndim
364
+ try:
365
+ n_tiles = tuple(n_tiles)
366
+ img.ndim == len(n_tiles) or _raise(TypeError())
367
+ except TypeError:
368
+ raise ValueError("n_tiles must be an iterable of length %d" % img.ndim)
369
+ all(np.isscalar(t) and 1<=t and int(t)==t for t in n_tiles) or _raise(
370
+ ValueError("all values of n_tiles must be integer values >= 1"))
371
+
372
+ n_tiles = tuple(map(int,n_tiles))
373
+
374
+ axes = self._normalize_axes(img, axes)
375
+ axes_net = self.config.axes
376
+
377
+ _permute_axes = self._make_permute_axes(axes, axes_net)
378
+ x = _permute_axes(img) # x has axes_net semantics
379
+
380
+ channel = axes_dict(axes_net)['C']
381
+ self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
382
+ axes_net_div_by = self._axes_div_by(axes_net)
383
+
384
+ grid = tuple(self.config.grid)
385
+ len(grid) == len(axes_net)-1 or _raise(ValueError())
386
+ grid_dict = dict(zip(axes_net.replace('C',''),grid))
387
+
388
+ normalizer = self._check_normalizer_resizer(normalizer, None)[0]
389
+ resizer = StarDistPadAndCropResizer(grid=grid_dict)
390
+
391
+ x = normalizer.before(x, axes_net)
392
+ x = resizer.before(x, axes_net, axes_net_div_by)
393
+
394
+ if not _is_floatarray(x):
395
+ warnings.warn("Predicting on non-float input... ( forgot to normalize? )")
396
+
397
+ def predict_direct(x):
398
+ ys = self.keras_model.predict(x[np.newaxis], **predict_kwargs)
399
+ return tuple(y[0] for y in ys)
400
+
401
+ def tiling_setup():
402
+ assert np.prod(n_tiles) > 1
403
+ tiling_axes = axes_net.replace('C','') # axes eligible for tiling
404
+ x_tiling_axis = tuple(axes_dict(axes_net)[a] for a in tiling_axes) # numerical axis ids for x
405
+ axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
406
+ # hack: permute tiling axis in the same way as img -> x was permuted
407
+ _n_tiles = _permute_axes(np.empty(n_tiles,bool)).shape
408
+ (all(_n_tiles[i] == 1 for i in range(x.ndim) if i not in x_tiling_axis) or
409
+ _raise(ValueError("entry of n_tiles > 1 only allowed for axes '%s'" % tiling_axes)))
410
+
411
+ sh = [s//grid_dict.get(a,1) for a,s in zip(axes_net,x.shape)]
412
+ sh[channel] = None
413
+ def create_empty_output(n_channel, dtype=np.float32):
414
+ sh[channel] = n_channel
415
+ return np.empty(sh,dtype)
416
+
417
+ if callable(show_tile_progress):
418
+ progress, _show_tile_progress = show_tile_progress, True
419
+ else:
420
+ progress, _show_tile_progress = tqdm, show_tile_progress
421
+
422
+ n_block_overlaps = [int(np.ceil(overlap/blocksize)) for overlap, blocksize
423
+ in zip(axes_net_tile_overlaps, axes_net_div_by)]
424
+
425
+ num_tiles_used = total_n_tiles(x, _n_tiles, block_sizes=axes_net_div_by, n_block_overlaps=n_block_overlaps)
426
+
427
+ tile_generator = progress(tile_iterator(x, _n_tiles, block_sizes=axes_net_div_by, n_block_overlaps=n_block_overlaps),
428
+ disable=(not _show_tile_progress), total=num_tiles_used)
429
+
430
+ return tile_generator, tuple(sh), create_empty_output
431
+
432
+ return x, axes, axes_net, axes_net_div_by, _permute_axes, resizer, n_tiles, grid, grid_dict, channel, predict_direct, tiling_setup
433
+
434
+
435
+ def _predict_generator(self, img, axes=None, normalizer=None, n_tiles=None, show_tile_progress=True, **predict_kwargs):
436
+ """Predict.
437
+
438
+ Parameters
439
+ ----------
440
+ img : :class:`numpy.ndarray`
441
+ Input image
442
+ axes : str or None
443
+ Axes of the input ``img``.
444
+ ``None`` denotes that axes of img are the same as denoted in the config.
445
+ normalizer : :class:`csbdeep.data.Normalizer` or None
446
+ (Optional) normalization of input image before prediction.
447
+ Note that the default (``None``) assumes ``img`` to be already normalized.
448
+ n_tiles : iterable or None
449
+ Out of memory (OOM) errors can occur if the input image is too large.
450
+ To avoid this problem, the input image is broken up into (overlapping) tiles
451
+ that are processed independently and re-assembled.
452
+ This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
453
+ ``None`` denotes that no tiling should be used.
454
+ show_tile_progress: bool or callable
455
+ If boolean, indicates whether to show progress (via tqdm) during tiled prediction.
456
+ If callable, must be a drop-in replacement for tqdm.
457
+ show_tile_progress: bool
458
+ Whether to show progress during tiled prediction.
459
+ predict_kwargs: dict
460
+ Keyword arguments for ``predict`` function of Keras model.
461
+
462
+ Returns
463
+ -------
464
+ (:class:`numpy.ndarray`, :class:`numpy.ndarray`, [:class:`numpy.ndarray`])
465
+ Returns the tuple (`prob`, `dist`, [`prob_class`]) of per-pixel object probabilities and star-convex polygon/polyhedra distances.
466
+ In multiclass prediction mode, `prob_class` is the probability map for each of the 1+'n_classes' classes (first class is background)
467
+
468
+ """
469
+
470
+ x, axes, axes_net, axes_net_div_by, _permute_axes, resizer, n_tiles, grid, grid_dict, channel, predict_direct, tiling_setup = \
471
+ self._predict_setup(img, axes, normalizer, n_tiles, show_tile_progress, predict_kwargs)
472
+
473
+ if np.prod(n_tiles) > 1:
474
+ tile_generator, output_shape, create_empty_output = tiling_setup()
475
+
476
+ prob = create_empty_output(1)
477
+ dist = create_empty_output(self.config.n_rays)
478
+ if self._is_multiclass():
479
+ prob_class = create_empty_output(self.config.n_classes+1)
480
+ result = (prob, dist, prob_class)
481
+ else:
482
+ result = (prob, dist)
483
+
484
+ for tile, s_src, s_dst in tile_generator:
485
+ # predict_direct -> prob, dist, [prob_class if multi_class]
486
+ result_tile = predict_direct(tile)
487
+ # account for grid
488
+ s_src = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_src,axes_net)]
489
+ s_dst = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_dst,axes_net)]
490
+ # prob and dist have different channel dimensionality than image x
491
+ s_src[channel] = slice(None)
492
+ s_dst[channel] = slice(None)
493
+ s_src, s_dst = tuple(s_src), tuple(s_dst)
494
+ # print(s_src,s_dst)
495
+ for part, part_tile in zip(result, result_tile):
496
+ part[s_dst] = part_tile[s_src]
497
+ yield # yield None after each processed tile
498
+ else:
499
+ # predict_direct -> prob, dist, [prob_class if multi_class]
500
+ result = predict_direct(x)
501
+
502
+ result = [resizer.after(part, axes_net) for part in result]
503
+
504
+ # result = (prob, dist) for legacy or (prob, dist, prob_class) for multiclass
505
+
506
+ # prob
507
+ result[0] = np.take(result[0],0,axis=channel)
508
+ # dist
509
+ result[1] = np.maximum(1e-3, result[1]) # avoid small dist values to prevent problems with Qhull
510
+ result[1] = np.moveaxis(result[1],channel,-1)
511
+
512
+ if self._is_multiclass():
513
+ # prob_class
514
+ result[2] = np.moveaxis(result[2],channel,-1)
515
+
516
+ # last "yield" is the actual output that would have been "return"ed if this was a regular function
517
+ yield tuple(result)
518
+
519
+
520
+ @functools.wraps(_predict_generator)
521
+ def predict(self, *args, **kwargs):
522
+ # return last "yield"ed value of generator
523
+ r = None
524
+ for r in self._predict_generator(*args, **kwargs):
525
+ pass
526
+ return r
527
+
528
+
529
+ def _predict_sparse_generator(self, img, prob_thresh=None, axes=None, normalizer=None, n_tiles=None, show_tile_progress=True, b=2, **predict_kwargs):
530
+ """ Sparse version of model.predict()
531
+ Returns
532
+ -------
533
+ (prob, dist, [prob_class], points) flat list of probs, dists, (optional prob_class) and points
534
+ """
535
+ if prob_thresh is None: prob_thresh = self.thresholds.prob
536
+
537
+ x, axes, axes_net, axes_net_div_by, _permute_axes, resizer, n_tiles, grid, grid_dict, channel, predict_direct, tiling_setup = \
538
+ self._predict_setup(img, axes, normalizer, n_tiles, show_tile_progress, predict_kwargs)
539
+
540
+ def _prep(prob, dist):
541
+ prob = np.take(prob,0,axis=channel)
542
+ dist = np.moveaxis(dist,channel,-1)
543
+ dist = np.maximum(1e-3, dist)
544
+ return prob, dist
545
+
546
+ proba, dista, pointsa, prob_class = [],[],[], []
547
+
548
+ if np.prod(n_tiles) > 1:
549
+ tile_generator, output_shape, create_empty_output = tiling_setup()
550
+
551
+ sh = list(output_shape)
552
+ sh[channel] = 1;
553
+
554
+ proba, dista, pointsa, prob_classa = [], [], [], []
555
+
556
+ for tile, s_src, s_dst in tile_generator:
557
+
558
+ results_tile = predict_direct(tile)
559
+
560
+ # account for grid
561
+ s_src = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_src,axes_net)]
562
+ s_dst = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_dst,axes_net)]
563
+ s_src[channel] = slice(None)
564
+ s_dst[channel] = slice(None)
565
+ s_src, s_dst = tuple(s_src), tuple(s_dst)
566
+
567
+ prob_tile, dist_tile = results_tile[:2]
568
+ prob_tile, dist_tile = _prep(prob_tile[s_src], dist_tile[s_src])
569
+
570
+ bs = list((b if s.start==0 else -1, b if s.stop==_sh else -1) for s,_sh in zip(s_dst, sh))
571
+ bs.pop(channel)
572
+ inds = _ind_prob_thresh(prob_tile, prob_thresh, b=bs)
573
+ proba.extend(prob_tile[inds].copy())
574
+ dista.extend(dist_tile[inds].copy())
575
+ _points = np.stack(np.where(inds), axis=1)
576
+ offset = list(s.start for i,s in enumerate(s_dst))
577
+ offset.pop(channel)
578
+ _points = _points + np.array(offset).reshape((1,len(offset)))
579
+ _points = _points * np.array(self.config.grid).reshape((1,len(self.config.grid)))
580
+ pointsa.extend(_points)
581
+
582
+ if self._is_multiclass():
583
+ p = results_tile[2][s_src].copy()
584
+ p = np.moveaxis(p,channel,-1)
585
+ prob_classa.extend(p[inds])
586
+ yield # yield None after each processed tile
587
+
588
+ else:
589
+ # predict_direct -> prob, dist, [prob_class if multi_class]
590
+ results = predict_direct(x)
591
+ prob, dist = results[:2]
592
+ prob, dist = _prep(prob, dist)
593
+ inds = _ind_prob_thresh(prob, prob_thresh, b=b)
594
+ proba = prob[inds].copy()
595
+ dista = dist[inds].copy()
596
+ _points = np.stack(np.where(inds), axis=1)
597
+ pointsa = (_points * np.array(self.config.grid).reshape((1,len(self.config.grid))))
598
+
599
+ if self._is_multiclass():
600
+ p = np.moveaxis(results[2],channel,-1)
601
+ prob_classa = p[inds].copy()
602
+
603
+
604
+ proba = np.asarray(proba)
605
+ dista = np.asarray(dista).reshape((-1,self.config.n_rays))
606
+ pointsa = np.asarray(pointsa).reshape((-1,self.config.n_dim))
607
+
608
+ idx = resizer.filter_points(x.ndim, pointsa, axes_net)
609
+ proba = proba[idx]
610
+ dista = dista[idx]
611
+ pointsa = pointsa[idx]
612
+
613
+ # last "yield" is the actual output that would have been "return"ed if this was a regular function
614
+ if self._is_multiclass():
615
+ prob_classa = np.asarray(prob_classa).reshape((-1,self.config.n_classes+1))
616
+ prob_classa = prob_classa[idx]
617
+ yield proba, dista, prob_classa, pointsa
618
+ else:
619
+ prob_classa = None
620
+ yield proba, dista, pointsa
621
+
622
+
623
+ @functools.wraps(_predict_sparse_generator)
624
+ def predict_sparse(self, *args, **kwargs):
625
+ # return last "yield"ed value of generator
626
+ r = None
627
+ for r in self._predict_sparse_generator(*args, **kwargs):
628
+ pass
629
+ return r
630
+
631
+
632
+ def _predict_instances_generator(self, img, axes=None, normalizer=None,
633
+ sparse=True,
634
+ prob_thresh=None, nms_thresh=None,
635
+ scale=None,
636
+ n_tiles=None, show_tile_progress=True,
637
+ verbose=False,
638
+ return_labels=True,
639
+ predict_kwargs=None, nms_kwargs=None,
640
+ overlap_label=None, return_predict=False):
641
+ """Predict instance segmentation from input image.
642
+
643
+ Parameters
644
+ ----------
645
+ img : :class:`numpy.ndarray`
646
+ Input image
647
+ axes : str or None
648
+ Axes of the input ``img``.
649
+ ``None`` denotes that axes of img are the same as denoted in the config.
650
+ normalizer : :class:`csbdeep.data.Normalizer` or None
651
+ (Optional) normalization of input image before prediction.
652
+ Note that the default (``None``) assumes ``img`` to be already normalized.
653
+ sparse: bool
654
+ If true, aggregate probabilities/distances sparsely during tiled
655
+ prediction to save memory (recommended).
656
+ prob_thresh : float or None
657
+ Consider only object candidates from pixels with predicted object probability
658
+ above this threshold (also see `optimize_thresholds`).
659
+ nms_thresh : float or None
660
+ Perform non-maximum suppression that considers two objects to be the same
661
+ when their area/surface overlap exceeds this threshold (also see `optimize_thresholds`).
662
+ scale: None or float or iterable
663
+ Scale the input image internally by this factor and rescale the output accordingly.
664
+ All spatial axes (X,Y,Z) will be scaled if a scalar value is provided.
665
+ Alternatively, multiple scale values (compatible with input `axes`) can be used
666
+ for more fine-grained control (scale values for non-spatial axes must be 1).
667
+ n_tiles : iterable or None
668
+ Out of memory (OOM) errors can occur if the input image is too large.
669
+ To avoid this problem, the input image is broken up into (overlapping) tiles
670
+ that are processed independently and re-assembled.
671
+ This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
672
+ ``None`` denotes that no tiling should be used.
673
+ show_tile_progress: bool
674
+ Whether to show progress during tiled prediction.
675
+ verbose: bool
676
+ Whether to print some info messages.
677
+ return_labels: bool
678
+ Whether to create a label image, otherwise return None in its place.
679
+ predict_kwargs: dict
680
+ Keyword arguments for ``predict`` function of Keras model.
681
+ nms_kwargs: dict
682
+ Keyword arguments for non-maximum suppression.
683
+ overlap_label: scalar or None
684
+ if not None, label the regions where polygons overlap with that value
685
+ return_predict: bool
686
+ Also return the outputs of :func:`predict` (in a separate tuple)
687
+ If True, implies sparse = False
688
+
689
+ Returns
690
+ -------
691
+ (:class:`numpy.ndarray`, dict), (optional: return tuple of :func:`predict`)
692
+ Returns a tuple of the label instances image and also
693
+ a dictionary with the details (coordinates, etc.) of all remaining polygons/polyhedra.
694
+
695
+ """
696
+ if predict_kwargs is None:
697
+ predict_kwargs = {}
698
+ if nms_kwargs is None:
699
+ nms_kwargs = {}
700
+
701
+ if return_predict and sparse:
702
+ sparse = False
703
+ warnings.warn("Setting sparse to False because return_predict is True")
704
+
705
+ nms_kwargs.setdefault("verbose", verbose)
706
+
707
+ _axes = self._normalize_axes(img, axes)
708
+ _axes_net = self.config.axes
709
+ _permute_axes = self._make_permute_axes(_axes, _axes_net)
710
+ _shape_inst = tuple(s for s,a in zip(_permute_axes(img).shape, _axes_net) if a != 'C')
711
+
712
+ if scale is not None:
713
+ if isinstance(scale, numbers.Number):
714
+ scale = tuple(scale if a in 'XYZ' else 1 for a in _axes)
715
+ scale = tuple(scale)
716
+ len(scale) == len(_axes) or _raise(ValueError(f"scale {scale} must be of length {len(_axes)}, i.e. one value for each of the axes {_axes}"))
717
+ for s,a in zip(scale,_axes):
718
+ s > 0 or _raise(ValueError("scale values must be greater than 0"))
719
+ (s in (1,None) or a in 'XYZ') or warnings.warn(f"replacing scale value {s} for non-spatial axis {a} with 1")
720
+ scale = tuple(s if a in 'XYZ' else 1 for s,a in zip(scale,_axes))
721
+ verbose and print(f"scaling image by factors {scale} for axes {_axes}")
722
+ img = ndi.zoom(img, scale, order=1)
723
+
724
+ yield 'predict' # indicate that prediction is starting
725
+ res = None
726
+ if sparse:
727
+ for res in self._predict_sparse_generator(img, axes=axes, normalizer=normalizer, n_tiles=n_tiles,
728
+ prob_thresh=prob_thresh, show_tile_progress=show_tile_progress, **predict_kwargs):
729
+ if res is None:
730
+ yield 'tile' # yield 'tile' each time a tile has been processed
731
+ else:
732
+ for res in self._predict_generator(img, axes=axes, normalizer=normalizer, n_tiles=n_tiles,
733
+ show_tile_progress=show_tile_progress, **predict_kwargs):
734
+ if res is None:
735
+ yield 'tile' # yield 'tile' each time a tile has been processed
736
+ res = tuple(res) + (None,)
737
+
738
+ if self._is_multiclass():
739
+ prob, dist, prob_class, points = res
740
+ else:
741
+ prob, dist, points = res
742
+ prob_class = None
743
+
744
+ yield 'nms' # indicate that non-maximum suppression is starting
745
+ res_instances = self._instances_from_prediction(_shape_inst, prob, dist,
746
+ points=points,
747
+ prob_class=prob_class,
748
+ prob_thresh=prob_thresh,
749
+ nms_thresh=nms_thresh,
750
+ scale=(None if scale is None else dict(zip(_axes,scale))),
751
+ return_labels=return_labels,
752
+ overlap_label=overlap_label,
753
+ **nms_kwargs)
754
+
755
+ # last "yield" is the actual output that would have been "return"ed if this was a regular function
756
+ if return_predict:
757
+ yield res_instances, tuple(res[:-1])
758
+ else:
759
+ yield res_instances
760
+
761
+
762
+ @functools.wraps(_predict_instances_generator)
763
+ def predict_instances(self, *args, **kwargs):
764
+ # the reason why the actual computation happens as a generator function
765
+ # (in '_predict_instances_generator') is that the generator is called
766
+ # from the stardist napari plugin, which has its benefits regarding
767
+ # control flow and progress display. however, typical use cases should
768
+ # almost always use this function ('predict_instances'), and shouldn't
769
+ # even notice (thanks to @functools.wraps) that it wraps the generator
770
+ # function. note that similar reasoning applies to 'predict' and
771
+ # 'predict_sparse'.
772
+
773
+ # return last "yield"ed value of generator
774
+ r = None
775
+ for r in self._predict_instances_generator(*args, **kwargs):
776
+ pass
777
+ return r
778
+
779
+
780
+ # def _predict_instances_old(self, img, axes=None, normalizer=None,
781
+ # sparse = False,
782
+ # prob_thresh=None, nms_thresh=None,
783
+ # n_tiles=None, show_tile_progress=True,
784
+ # verbose = False,
785
+ # predict_kwargs=None, nms_kwargs=None, overlap_label=None):
786
+ # """
787
+ # old version, should be removed....
788
+ # """
789
+ # if predict_kwargs is None:
790
+ # predict_kwargs = {}
791
+ # if nms_kwargs is None:
792
+ # nms_kwargs = {}
793
+
794
+ # nms_kwargs.setdefault("verbose", verbose)
795
+
796
+ # _axes = self._normalize_axes(img, axes)
797
+ # _axes_net = self.config.axes
798
+ # _permute_axes = self._make_permute_axes(_axes, _axes_net)
799
+ # _shape_inst = tuple(s for s,a in zip(_permute_axes(img).shape, _axes_net) if a != 'C')
800
+
801
+
802
+ # res = self.predict(img, axes=axes, normalizer=normalizer,
803
+ # n_tiles=n_tiles,
804
+ # show_tile_progress=show_tile_progress,
805
+ # **predict_kwargs)
806
+
807
+ # res = tuple(res) + (None,)
808
+
809
+ # if self._is_multiclass():
810
+ # prob, dist, prob_class, points = res
811
+ # else:
812
+ # prob, dist, points = res
813
+ # prob_class = None
814
+
815
+
816
+ # return self._instances_from_prediction_old(_shape_inst, prob, dist,
817
+ # points = points,
818
+ # prob_class = prob_class,
819
+ # prob_thresh=prob_thresh,
820
+ # nms_thresh=nms_thresh,
821
+ # overlap_label=overlap_label,
822
+ # **nms_kwargs)
823
+
824
+
825
+ def predict_instances_big(self, img, axes, block_size, min_overlap, context=None,
826
+ labels_out=None, labels_out_dtype=np.int32, show_progress=True, **kwargs):
827
+ """Predict instance segmentation from very large input images.
828
+
829
+ Intended to be used when `predict_instances` cannot be used due to memory limitations.
830
+ This function will break the input image into blocks and process them individually
831
+ via `predict_instances` and assemble all the partial results. If used as intended, the result
832
+ should be the same as if `predict_instances` was used directly on the whole image.
833
+
834
+ **Important**: The crucial assumption is that all predicted object instances are smaller than
835
+ the provided `min_overlap`. Also, it must hold that: min_overlap + 2*context < block_size.
836
+
837
+ Example
838
+ -------
839
+ >>> img.shape
840
+ (20000, 20000)
841
+ >>> labels, polys = model.predict_instances_big(img, axes='YX', block_size=4096,
842
+ min_overlap=128, context=128, n_tiles=(4,4))
843
+
844
+ Parameters
845
+ ----------
846
+ img: :class:`numpy.ndarray` or similar
847
+ Input image
848
+ axes: str
849
+ Axes of the input ``img`` (such as 'YX', 'ZYX', 'YXC', etc.)
850
+ block_size: int or iterable of int
851
+ Process input image in blocks of the provided shape.
852
+ (If a scalar value is given, it is used for all spatial image dimensions.)
853
+ min_overlap: int or iterable of int
854
+ Amount of guaranteed overlap between blocks.
855
+ (If a scalar value is given, it is used for all spatial image dimensions.)
856
+ context: int or iterable of int, or None
857
+ Amount of image context on all sides of a block, which is discarded.
858
+ If None, uses an automatic estimate that should work in many cases.
859
+ (If a scalar value is given, it is used for all spatial image dimensions.)
860
+ labels_out: :class:`numpy.ndarray` or similar, or None, or False
861
+ numpy array or similar (must be of correct shape) to which the label image is written.
862
+ If None, will allocate a numpy array of the correct shape and data type ``labels_out_dtype``.
863
+ If False, will not write the label image (useful if only the dictionary is needed).
864
+ labels_out_dtype: str or dtype
865
+ Data type of returned label image if ``labels_out=None`` (has no effect otherwise).
866
+ show_progress: bool
867
+ Show progress bar for block processing.
868
+ kwargs: dict
869
+ Keyword arguments for ``predict_instances``.
870
+
871
+ Returns
872
+ -------
873
+ (:class:`numpy.ndarray` or False, dict)
874
+ Returns the label image and a dictionary with the details (coordinates, etc.) of the polygons/polyhedra.
875
+
876
+ """
877
+ from ..big import _grid_divisible, BlockND, OBJECT_KEYS#, repaint_labels
878
+ from ..matching import relabel_sequential
879
+
880
+ n = img.ndim
881
+ axes = axes_check_and_normalize(axes, length=n)
882
+ grid = self._axes_div_by(axes)
883
+ axes_out = self._axes_out.replace('C','')
884
+ shape_dict = dict(zip(axes,img.shape))
885
+ shape_out = tuple(shape_dict[a] for a in axes_out)
886
+
887
+ if context is None:
888
+ context = self._axes_tile_overlap(axes)
889
+
890
+ if np.isscalar(block_size): block_size = n*[block_size]
891
+ if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
892
+ if np.isscalar(context): context = n*[context]
893
+ block_size, min_overlap, context = list(block_size), list(min_overlap), list(context)
894
+ assert n == len(block_size) == len(min_overlap) == len(context)
895
+
896
+ if 'C' in axes:
897
+ # single block for channel axis
898
+ i = axes_dict(axes)['C']
899
+ # if (block_size[i], min_overlap[i], context[i]) != (None, None, None):
900
+ # print("Ignoring values of 'block_size', 'min_overlap', and 'context' for channel axis " +
901
+ # "(set to 'None' to avoid this warning).", file=sys.stderr, flush=True)
902
+ block_size[i] = img.shape[i]
903
+ min_overlap[i] = context[i] = 0
904
+
905
+ block_size = tuple(_grid_divisible(g, v, name='block_size', verbose=False) for v,g,a in zip(block_size, grid,axes))
906
+ min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v,g,a in zip(min_overlap,grid,axes))
907
+ context = tuple(_grid_divisible(g, v, name='context', verbose=False) for v,g,a in zip(context, grid,axes))
908
+
909
+ # print(f"input: shape {img.shape} with axes {axes}")
910
+ print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)
911
+
912
+ for a,c,o in zip(axes,context,self._axes_tile_overlap(axes)):
913
+ if c < o:
914
+ print(f"{a}: context of {c} is small, recommended to use at least {o}", flush=True)
915
+
916
+ # create block cover
917
+ blocks = BlockND.cover(img.shape, axes, block_size, min_overlap, context, grid)
918
+
919
+ if np.isscalar(labels_out) and bool(labels_out) is False:
920
+ labels_out = None
921
+ else:
922
+ if labels_out is None:
923
+ labels_out = np.zeros(shape_out, dtype=labels_out_dtype)
924
+ else:
925
+ labels_out.shape == shape_out or _raise(ValueError(f"'labels_out' must have shape {shape_out} (axes {axes_out})."))
926
+
927
+ polys_all = {}
928
+ # problem_ids = []
929
+ label_offset = 1
930
+
931
+ kwargs_override = dict(axes=axes, overlap_label=None, return_labels=True, return_predict=False)
932
+ if show_progress:
933
+ kwargs_override['show_tile_progress'] = False # disable progress for predict_instances
934
+ for k,v in kwargs_override.items():
935
+ if k in kwargs: print(f"changing '{k}' from {kwargs[k]} to {v}", flush=True)
936
+ kwargs[k] = v
937
+
938
+ blocks = tqdm(blocks, disable=(not show_progress))
939
+ # actual computation
940
+ for block in blocks:
941
+ labels, polys = self.predict_instances(block.read(img, axes=axes), **kwargs)
942
+ labels = block.crop_context(labels, axes=axes_out)
943
+ labels, polys = block.filter_objects(labels, polys, axes=axes_out)
944
+ # TODO: relabel_sequential is not very memory-efficient (will allocate memory proportional to label_offset)
945
+ # this should not change the order of labels
946
+ labels = relabel_sequential(labels, label_offset)[0]
947
+
948
+ # labels, fwd_map, _ = relabel_sequential(labels, label_offset)
949
+ # if len(incomplete) > 0:
950
+ # problem_ids.extend([fwd_map[i] for i in incomplete])
951
+ # if show_progress:
952
+ # blocks.set_postfix_str(f"found {len(problem_ids)} problematic {'object' if len(problem_ids)==1 else 'objects'}")
953
+ if labels_out is not None:
954
+ block.write(labels_out, labels, axes=axes_out)
955
+
956
+ for k,v in polys.items():
957
+ polys_all.setdefault(k,[]).append(v)
958
+
959
+ label_offset += len(polys['prob'])
960
+ del labels
961
+
962
+ polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k,v in polys_all.items()}
963
+
964
+ # if labels_out is not None and len(problem_ids) > 0:
965
+ # # if show_progress:
966
+ # # blocks.write('')
967
+ # # print(f"Found {len(problem_ids)} objects that violate the 'min_overlap' assumption.", file=sys.stderr, flush=True)
968
+ # repaint_labels(labels_out, problem_ids, polys_all, show_progress=False)
969
+
970
+ return labels_out, polys_all#, tuple(problem_ids)
971
+
972
+
973
+ def optimize_thresholds(self, X_val, Y_val, nms_threshs=[0.3,0.4,0.5], iou_threshs=[0.3,0.5,0.7], predict_kwargs=None, optimize_kwargs=None, save_to_json=True):
974
+ """Optimize two thresholds (probability, NMS overlap) necessary for predicting object instances.
975
+
976
+ Note that the default thresholds yield good results in many cases, but optimizing
977
+ the thresholds for a particular dataset can further improve performance.
978
+
979
+ The optimized thresholds are automatically used for all further predictions
980
+ and also written to the model directory.
981
+
982
+ See ``utils.optimize_threshold`` for details and possible choices for ``optimize_kwargs``.
983
+
984
+ Parameters
985
+ ----------
986
+ X_val : list of ndarray
987
+ (Validation) input images (must be normalized) to use for threshold tuning.
988
+ Y_val : list of ndarray
989
+ (Validation) label images to use for threshold tuning.
990
+ nms_threshs : list of float
991
+ List of overlap thresholds to be considered for NMS.
992
+ For each value in this list, optimization is run to find a corresponding prob_thresh value.
993
+ iou_threshs : list of float
994
+ List of intersection over union (IOU) thresholds for which
995
+ the (average) matching performance is considered to tune the thresholds.
996
+ predict_kwargs: dict
997
+ Keyword arguments for ``predict`` function of this class.
998
+ (If not provided, will guess value for `n_tiles` to prevent out of memory errors.)
999
+ optimize_kwargs: dict
1000
+ Keyword arguments for ``utils.optimize_threshold`` function.
1001
+
1002
+ """
1003
+ if predict_kwargs is None:
1004
+ predict_kwargs = {}
1005
+ if optimize_kwargs is None:
1006
+ optimize_kwargs = {}
1007
+
1008
+ def _predict_kwargs(x):
1009
+ if 'n_tiles' in predict_kwargs:
1010
+ return predict_kwargs
1011
+ else:
1012
+ return {**predict_kwargs, 'n_tiles': self._guess_n_tiles(x), 'show_tile_progress': False}
1013
+
1014
+ # only take first two elements of predict in case multi class is activated
1015
+ Yhat_val = [self.predict(x, **_predict_kwargs(x))[:2] for x in X_val]
1016
+
1017
+ opt_prob_thresh, opt_measure, opt_nms_thresh = None, -np.inf, None
1018
+ for _opt_nms_thresh in nms_threshs:
1019
+ _opt_prob_thresh, _opt_measure = optimize_threshold(Y_val, Yhat_val, model=self, nms_thresh=_opt_nms_thresh, iou_threshs=iou_threshs, **optimize_kwargs)
1020
+ if _opt_measure > opt_measure:
1021
+ opt_prob_thresh, opt_measure, opt_nms_thresh = _opt_prob_thresh, _opt_measure, _opt_nms_thresh
1022
+ opt_threshs = dict(prob=opt_prob_thresh, nms=opt_nms_thresh)
1023
+
1024
+ self.thresholds = opt_threshs
1025
+ print(end='', file=sys.stderr, flush=True)
1026
+ print("Using optimized values: prob_thresh={prob:g}, nms_thresh={nms:g}.".format(prob=self.thresholds.prob, nms=self.thresholds.nms))
1027
+ if save_to_json and self.basedir is not None:
1028
+ print("Saving to 'thresholds.json'.")
1029
+ save_json(opt_threshs, str(self.logdir / 'thresholds.json'))
1030
+ return opt_threshs
1031
+
1032
+
1033
+ def _guess_n_tiles(self, img):
1034
+ axes = self._normalize_axes(img, axes=None)
1035
+ shape = list(img.shape)
1036
+ if 'C' in axes:
1037
+ del shape[axes_dict(axes)['C']]
1038
+ b = self.config.train_batch_size**(1.0/self.config.n_dim)
1039
+ n_tiles = [int(np.ceil(s/(p*b))) for s,p in zip(shape,self.config.train_patch_size)]
1040
+ if 'C' in axes:
1041
+ n_tiles.insert(axes_dict(axes)['C'],1)
1042
+ return tuple(n_tiles)
1043
+
1044
+
1045
+ def _normalize_axes(self, img, axes):
1046
+ if axes is None:
1047
+ axes = self.config.axes
1048
+ assert 'C' in axes
1049
+ if img.ndim == len(axes)-1 and self.config.n_channel_in == 1:
1050
+ # img has no dedicated channel axis, but 'C' always part of config axes
1051
+ axes = axes.replace('C','')
1052
+ return axes_check_and_normalize(axes, img.ndim)
1053
+
1054
+
1055
+ def _compute_receptive_field(self, img_size=None):
1056
+ # TODO: good enough?
1057
+ from scipy.ndimage import zoom
1058
+ if img_size is None:
1059
+ img_size = tuple(g*(128 if self.config.n_dim==2 else 64) for g in self.config.grid)
1060
+ if np.isscalar(img_size):
1061
+ img_size = (img_size,) * self.config.n_dim
1062
+ img_size = tuple(img_size)
1063
+ # print(img_size)
1064
+ assert all(_is_power_of_2(s) for s in img_size)
1065
+ mid = tuple(s//2 for s in img_size)
1066
+ x = np.zeros((1,)+img_size+(self.config.n_channel_in,), dtype=np.float32)
1067
+ z = np.zeros_like(x)
1068
+ x[(0,)+mid+(slice(None),)] = 1
1069
+ y = self.keras_model.predict(x)[0][0,...,0]
1070
+ y0 = self.keras_model.predict(z)[0][0,...,0]
1071
+ grid = tuple((np.array(x.shape[1:-1])/np.array(y.shape)).astype(int))
1072
+ assert grid == self.config.grid
1073
+ y = zoom(y, grid,order=0)
1074
+ y0 = zoom(y0,grid,order=0)
1075
+ ind = np.where(np.abs(y-y0)>0)
1076
+ return [(m-np.min(i), np.max(i)-m) for (m,i) in zip(mid,ind)]
1077
+
1078
+
1079
+ def _axes_tile_overlap(self, query_axes):
1080
+ query_axes = axes_check_and_normalize(query_axes)
1081
+ try:
1082
+ self._tile_overlap
1083
+ except AttributeError:
1084
+ self._tile_overlap = self._compute_receptive_field()
1085
+ overlap = dict(zip(
1086
+ self.config.axes.replace('C',''),
1087
+ tuple(max(rf) for rf in self._tile_overlap)
1088
+ ))
1089
+ return tuple(overlap.get(a,0) for a in query_axes)
1090
+
1091
+
1092
+ def export_TF(self, fname=None, single_output=True, upsample_grid=True):
1093
+ """Export model to TensorFlow's SavedModel format that can be used e.g. in the Fiji plugin
1094
+
1095
+ Parameters
1096
+ ----------
1097
+ fname : str
1098
+ Path of the zip file to store the model
1099
+ If None, the default path "<modeldir>/TF_SavedModel.zip" is used
1100
+ single_output: bool
1101
+ If set, concatenates the two model outputs into a single output (note: this is currently mandatory for further use in Fiji)
1102
+ upsample_grid: bool
1103
+ If set, upsamples the output to the input shape (note: this is currently mandatory for further use in Fiji)
1104
+ """
1105
+ Concatenate, UpSampling2D, UpSampling3D, Conv2DTranspose, Conv3DTranspose = keras_import('layers', 'Concatenate', 'UpSampling2D', 'UpSampling3D', 'Conv2DTranspose', 'Conv3DTranspose')
1106
+ Model = keras_import('models', 'Model')
1107
+
1108
+ if self.basedir is None and fname is None:
1109
+ raise ValueError("Need explicit 'fname', since model directory not available (basedir=None).")
1110
+
1111
+ if self._is_multiclass():
1112
+ warnings.warn("multi-class mode not supported yet, removing classification output from exported model")
1113
+
1114
+ grid = self.config.grid
1115
+ prob = self.keras_model.outputs[0]
1116
+ dist = self.keras_model.outputs[1]
1117
+ assert self.config.n_dim in (2,3)
1118
+
1119
+ if upsample_grid and any(g>1 for g in grid):
1120
+ # CSBDeep Fiji plugin needs same size input/output
1121
+ # -> we need to upsample the outputs if grid > (1,1)
1122
+ # note: upsampling prob with a transposed convolution creates sparse
1123
+ # prob output with less candidates than with standard upsampling
1124
+ conv_transpose = Conv2DTranspose if self.config.n_dim==2 else Conv3DTranspose
1125
+ upsampling = UpSampling2D if self.config.n_dim==2 else UpSampling3D
1126
+ prob = conv_transpose(1, (1,)*self.config.n_dim,
1127
+ strides=grid, padding='same',
1128
+ kernel_initializer='ones', use_bias=False)(prob)
1129
+ dist = upsampling(grid)(dist)
1130
+
1131
+ inputs = self.keras_model.inputs[0]
1132
+ outputs = Concatenate()([prob,dist]) if single_output else [prob,dist]
1133
+ csbdeep_model = Model(inputs, outputs)
1134
+
1135
+ fname = (self.logdir / 'TF_SavedModel.zip') if fname is None else Path(fname)
1136
+ export_SavedModel(csbdeep_model, str(fname))
1137
+ return csbdeep_model
1138
+
1139
+
1140
+
1141
+ class StarDistPadAndCropResizer(Resizer):
1142
+
1143
+ # TODO: check correctness
1144
+ def __init__(self, grid, mode='reflect', **kwargs):
1145
+ assert isinstance(grid, dict)
1146
+ self.mode = mode
1147
+ self.grid = grid
1148
+ self.kwargs = kwargs
1149
+
1150
+
1151
+ def before(self, x, axes, axes_div_by):
1152
+ assert all(a%g==0 for g,a in zip((self.grid.get(a,1) for a in axes), axes_div_by))
1153
+ axes = axes_check_and_normalize(axes,x.ndim)
1154
+ def _split(v):
1155
+ return 0, v # only pad at the end
1156
+ self.pad = {
1157
+ a : _split((div_n-s%div_n)%div_n)
1158
+ for a, div_n, s in zip(axes, axes_div_by, x.shape)
1159
+ }
1160
+ x_pad = np.pad(x, tuple(self.pad[a] for a in axes), mode=self.mode, **self.kwargs)
1161
+ self.padded_shape = dict(zip(axes,x_pad.shape))
1162
+ if 'C' in self.padded_shape: del self.padded_shape['C']
1163
+ return x_pad
1164
+
1165
+
1166
+ def after(self, x, axes):
1167
+ # axes can include 'C', which may not have been present in before()
1168
+ axes = axes_check_and_normalize(axes,x.ndim)
1169
+ assert all(s_pad == s * g for s,s_pad,g in zip(x.shape,
1170
+ (self.padded_shape.get(a,_s) for a,_s in zip(axes,x.shape)),
1171
+ (self.grid.get(a,1) for a in axes)))
1172
+ # print(self.padded_shape)
1173
+ # print(self.pad)
1174
+ # print(self.grid)
1175
+ crop = tuple (
1176
+ slice(0, -(math.floor(p[1]/g)) if p[1]>=g else None)
1177
+ for p,g in zip((self.pad.get(a,(0,0)) for a in axes),(self.grid.get(a,1) for a in axes))
1178
+ )
1179
+ # print(crop)
1180
+ return x[crop]
1181
+
1182
+
1183
+ def filter_points(self, ndim, points, axes):
1184
+ """ returns indices of points inside crop region """
1185
+ assert points.ndim==2
1186
+ axes = axes_check_and_normalize(axes,ndim)
1187
+
1188
+ bounds = np.array(tuple(self.padded_shape[a]-self.pad[a][1] for a in axes if a.lower() in ('z','y','x')))
1189
+ idx = np.where(np.all(points< bounds, 1))
1190
+ return idx
1191
+
1192
+
1193
+
1194
+ def _tf_version_at_least(version_string="1.0.0"):
1195
+ from packaging import version
1196
+ return version.parse(tf.__version__) >= version.parse(version_string)
stardist_pkg/models/model2d.py ADDED
@@ -0,0 +1,570 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+
3
+ import numpy as np
4
+ import warnings
5
+ import math
6
+ from tqdm import tqdm
7
+
8
+ from csbdeep.models import BaseConfig
9
+ from csbdeep.internals.blocks import unet_block
10
+ from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict
11
+ from csbdeep.utils.tf import keras_import, IS_TF_1, CARETensorBoard, CARETensorBoardImage
12
+ from skimage.segmentation import clear_border
13
+ from skimage.measure import regionprops
14
+ from scipy.ndimage import zoom
15
+ from distutils.version import LooseVersion
16
+
17
+ keras = keras_import()
18
+ K = keras_import('backend')
19
+ Input, Conv2D, MaxPooling2D = keras_import('layers', 'Input', 'Conv2D', 'MaxPooling2D')
20
+ Model = keras_import('models', 'Model')
21
+
22
+ from .base import StarDistBase, StarDistDataBase, _tf_version_at_least
23
+ from ..sample_patches import sample_patches
24
+ from ..utils import edt_prob, _normalize_grid, mask_to_categorical
25
+ from ..geometry import star_dist, dist_to_coord, polygons_to_label
26
+ from ..nms import non_maximum_suppression, non_maximum_suppression_sparse
27
+
28
+
29
+ class StarDistData2D(StarDistDataBase):
30
+
31
+ def __init__(self, X, Y, batch_size, n_rays, length,
32
+ n_classes=None, classes=None,
33
+ patch_size=(256,256), b=32, grid=(1,1), shape_completion=False, augmenter=None, foreground_prob=0, **kwargs):
34
+
35
+ super().__init__(X=X, Y=Y, n_rays=n_rays, grid=grid,
36
+ n_classes=n_classes, classes=classes,
37
+ batch_size=batch_size, patch_size=patch_size, length=length,
38
+ augmenter=augmenter, foreground_prob=foreground_prob, **kwargs)
39
+
40
+ self.shape_completion = bool(shape_completion)
41
+ if self.shape_completion and b > 0:
42
+ self.b = slice(b,-b),slice(b,-b)
43
+ else:
44
+ self.b = slice(None),slice(None)
45
+
46
+ self.sd_mode = 'opencl' if self.use_gpu else 'cpp'
47
+
48
+
49
+ def __getitem__(self, i):
50
+ idx = self.batch(i)
51
+ arrays = [sample_patches((self.Y[k],) + self.channels_as_tuple(self.X[k]),
52
+ patch_size=self.patch_size, n_samples=1,
53
+ valid_inds=self.get_valid_inds(k)) for k in idx]
54
+
55
+ if self.n_channel is None:
56
+ X, Y = list(zip(*[(x[0][self.b],y[0]) for y,x in arrays]))
57
+ else:
58
+ X, Y = list(zip(*[(np.stack([_x[0] for _x in x],axis=-1)[self.b], y[0]) for y,*x in arrays]))
59
+
60
+ X, Y = tuple(zip(*tuple(self.augmenter(_x, _y) for _x, _y in zip(X,Y))))
61
+
62
+
63
+ prob = np.stack([edt_prob(lbl[self.b][self.ss_grid[1:3]]) for lbl in Y])
64
+ # prob = np.stack([edt_prob(lbl[self.b]) for lbl in Y])
65
+ # prob = prob[self.ss_grid]
66
+
67
+ if self.shape_completion:
68
+ Y_cleared = [clear_border(lbl) for lbl in Y]
69
+ _dist = np.stack([star_dist(lbl,self.n_rays,mode=self.sd_mode)[self.b+(slice(None),)] for lbl in Y_cleared])
70
+ dist = _dist[self.ss_grid]
71
+ dist_mask = np.stack([edt_prob(lbl[self.b][self.ss_grid[1:3]]) for lbl in Y_cleared])
72
+ else:
73
+ # directly subsample with grid
74
+ dist = np.stack([star_dist(lbl,self.n_rays,mode=self.sd_mode, grid=self.grid) for lbl in Y])
75
+ dist_mask = prob
76
+
77
+ X = np.stack(X)
78
+ if X.ndim == 3: # input image has no channel axis
79
+ X = np.expand_dims(X,-1)
80
+ prob = np.expand_dims(prob,-1)
81
+ dist_mask = np.expand_dims(dist_mask,-1)
82
+
83
+ # subsample wth given grid
84
+ # dist_mask = dist_mask[self.ss_grid]
85
+ # prob = prob[self.ss_grid]
86
+
87
+ # append dist_mask to dist as additional channel
88
+ # dist_and_mask = np.concatenate([dist,dist_mask],axis=-1)
89
+ # faster than concatenate
90
+ dist_and_mask = np.empty(dist.shape[:-1]+(self.n_rays+1,), np.float32)
91
+ dist_and_mask[...,:-1] = dist
92
+ dist_and_mask[...,-1:] = dist_mask
93
+
94
+
95
+ if self.n_classes is None:
96
+ return [X], [prob,dist_and_mask]
97
+ else:
98
+ prob_class = np.stack(tuple((mask_to_categorical(y, self.n_classes, self.classes[k]) for y,k in zip(Y, idx))))
99
+
100
+ # TODO: investigate downsampling via simple indexing vs. using 'zoom'
101
+ # prob_class = prob_class[self.ss_grid]
102
+ # 'zoom' might lead to better registered maps (especially if upscaled later)
103
+ prob_class = zoom(prob_class, (1,)+tuple(1/g for g in self.grid)+(1,), order=0)
104
+
105
+ return [X], [prob,dist_and_mask, prob_class]
106
+
107
+
108
+
109
+ class Config2D(BaseConfig):
110
+ """Configuration for a :class:`StarDist2D` model.
111
+
112
+ Parameters
113
+ ----------
114
+ axes : str or None
115
+ Axes of the input images.
116
+ n_rays : int
117
+ Number of radial directions for the star-convex polygon.
118
+ Recommended to use a power of 2 (default: 32).
119
+ n_channel_in : int
120
+ Number of channels of given input image (default: 1).
121
+ grid : (int,int)
122
+ Subsampling factors (must be powers of 2) for each of the axes.
123
+ Model will predict on a subsampled grid for increased efficiency and larger field of view.
124
+ n_classes : None or int
125
+ Number of object classes to use for multi-class predection (use None to disable)
126
+ backbone : str
127
+ Name of the neural network architecture to be used as backbone.
128
+ kwargs : dict
129
+ Overwrite (or add) configuration attributes (see below).
130
+
131
+
132
+ Attributes
133
+ ----------
134
+ unet_n_depth : int
135
+ Number of U-Net resolution levels (down/up-sampling layers).
136
+ unet_kernel_size : (int,int)
137
+ Convolution kernel size for all (U-Net) convolution layers.
138
+ unet_n_filter_base : int
139
+ Number of convolution kernels (feature channels) for first U-Net layer.
140
+ Doubled after each down-sampling layer.
141
+ unet_pool : (int,int)
142
+ Maxpooling size for all (U-Net) convolution layers.
143
+ net_conv_after_unet : int
144
+ Number of filters of the extra convolution layer after U-Net (0 to disable).
145
+ unet_* : *
146
+ Additional parameters for U-net backbone.
147
+ train_shape_completion : bool
148
+ Train model to predict complete shapes for partially visible objects at image boundary.
149
+ train_completion_crop : int
150
+ If 'train_shape_completion' is set to True, specify number of pixels to crop at boundary of training patches.
151
+ Should be chosen based on (largest) object sizes.
152
+ train_patch_size : (int,int)
153
+ Size of patches to be cropped from provided training images.
154
+ train_background_reg : float
155
+ Regularizer to encourage distance predictions on background regions to be 0.
156
+ train_foreground_only : float
157
+ Fraction (0..1) of patches that will only be sampled from regions that contain foreground pixels.
158
+ train_sample_cache : bool
159
+ Activate caching of valid patch regions for all training images (disable to save memory for large datasets)
160
+ train_dist_loss : str
161
+ Training loss for star-convex polygon distances ('mse' or 'mae').
162
+ train_loss_weights : tuple of float
163
+ Weights for losses relating to (probability, distance)
164
+ train_epochs : int
165
+ Number of training epochs.
166
+ train_steps_per_epoch : int
167
+ Number of parameter update steps per epoch.
168
+ train_learning_rate : float
169
+ Learning rate for training.
170
+ train_batch_size : int
171
+ Batch size for training.
172
+ train_n_val_patches : int
173
+ Number of patches to be extracted from validation images (``None`` = one patch per image).
174
+ train_tensorboard : bool
175
+ Enable TensorBoard for monitoring training progress.
176
+ train_reduce_lr : dict
177
+ Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable.
178
+ use_gpu : bool
179
+ Indicate that the data generator should use OpenCL to do computations on the GPU.
180
+
181
+ .. _ReduceLROnPlateau: https://keras.io/api/callbacks/reduce_lr_on_plateau/
182
+ """
183
+
184
+ def __init__(self, axes='YX', n_rays=32, n_channel_in=1, grid=(1,1), n_classes=None, backbone='unet', **kwargs):
185
+ """See class docstring."""
186
+
187
+ super().__init__(axes=axes, n_channel_in=n_channel_in, n_channel_out=1+n_rays)
188
+
189
+ # directly set by parameters
190
+ self.n_rays = int(n_rays)
191
+ self.grid = _normalize_grid(grid,2)
192
+ self.backbone = str(backbone).lower()
193
+ self.n_classes = None if n_classes is None else int(n_classes)
194
+
195
+ # default config (can be overwritten by kwargs below)
196
+ if self.backbone == 'unet':
197
+ self.unet_n_depth = 3
198
+ self.unet_kernel_size = 3,3
199
+ self.unet_n_filter_base = 32
200
+ self.unet_n_conv_per_depth = 2
201
+ self.unet_pool = 2,2
202
+ self.unet_activation = 'relu'
203
+ self.unet_last_activation = 'relu'
204
+ self.unet_batch_norm = False
205
+ self.unet_dropout = 0.0
206
+ self.unet_prefix = ''
207
+ self.net_conv_after_unet = 128
208
+ else:
209
+ # TODO: resnet backbone for 2D model?
210
+ raise ValueError("backbone '%s' not supported." % self.backbone)
211
+
212
+ # net_mask_shape not needed but kept for legacy reasons
213
+ if backend_channels_last():
214
+ self.net_input_shape = None,None,self.n_channel_in
215
+ self.net_mask_shape = None,None,1
216
+ else:
217
+ self.net_input_shape = self.n_channel_in,None,None
218
+ self.net_mask_shape = 1,None,None
219
+
220
+ self.train_shape_completion = False
221
+ self.train_completion_crop = 32
222
+ self.train_patch_size = 256,256
223
+ self.train_background_reg = 1e-4
224
+ self.train_foreground_only = 0.9
225
+ self.train_sample_cache = True
226
+
227
+ self.train_dist_loss = 'mae'
228
+ self.train_loss_weights = (1,0.2) if self.n_classes is None else (1,0.2,1)
229
+ self.train_class_weights = (1,1) if self.n_classes is None else (1,)*(self.n_classes+1)
230
+ self.train_epochs = 400
231
+ self.train_steps_per_epoch = 100
232
+ self.train_learning_rate = 0.0003
233
+ self.train_batch_size = 4
234
+ self.train_n_val_patches = None
235
+ self.train_tensorboard = True
236
+ # the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
237
+ min_delta_key = 'epsilon' if LooseVersion(keras.__version__)<=LooseVersion('2.1.5') else 'min_delta'
238
+ self.train_reduce_lr = {'factor': 0.5, 'patience': 40, min_delta_key: 0}
239
+
240
+ self.use_gpu = False
241
+
242
+ # remove derived attributes that shouldn't be overwritten
243
+ for k in ('n_dim', 'n_channel_out'):
244
+ try: del kwargs[k]
245
+ except KeyError: pass
246
+
247
+ self.update_parameters(False, **kwargs)
248
+
249
+ # FIXME: put into is_valid()
250
+ if not len(self.train_loss_weights) == (2 if self.n_classes is None else 3):
251
+ raise ValueError(f"train_loss_weights {self.train_loss_weights} not compatible with n_classes ({self.n_classes}): must be 3 weights if n_classes is not None, otherwise 2")
252
+
253
+ if not len(self.train_class_weights) == (2 if self.n_classes is None else self.n_classes+1):
254
+ raise ValueError(f"train_class_weights {self.train_class_weights} not compatible with n_classes ({self.n_classes}): must be 'n_classes + 1' weights if n_classes is not None, otherwise 2")
255
+
256
+
257
+
258
+ class StarDist2D(StarDistBase):
259
+ """StarDist2D model.
260
+
261
+ Parameters
262
+ ----------
263
+ config : :class:`Config` or None
264
+ Will be saved to disk as JSON (``config.json``).
265
+ If set to ``None``, will be loaded from disk (must exist).
266
+ name : str or None
267
+ Model name. Uses a timestamp if set to ``None`` (default).
268
+ basedir : str
269
+ Directory that contains (or will contain) a folder with the given model name.
270
+
271
+ Raises
272
+ ------
273
+ FileNotFoundError
274
+ If ``config=None`` and config cannot be loaded from disk.
275
+ ValueError
276
+ Illegal arguments, including invalid configuration.
277
+
278
+ Attributes
279
+ ----------
280
+ config : :class:`Config`
281
+ Configuration, as provided during instantiation.
282
+ keras_model : `Keras model <https://keras.io/getting-started/functional-api-guide/>`_
283
+ Keras neural network model.
284
+ name : str
285
+ Model name.
286
+ logdir : :class:`pathlib.Path`
287
+ Path to model folder (which stores configuration, weights, etc.)
288
+ """
289
+
290
+ def __init__(self, config=Config2D(), name=None, basedir='.'):
291
+ """See class docstring."""
292
+ super().__init__(config, name=name, basedir=basedir)
293
+
294
+
295
+ def _build(self):
296
+ self.config.backbone == 'unet' or _raise(NotImplementedError())
297
+ unet_kwargs = {k[len('unet_'):]:v for (k,v) in vars(self.config).items() if k.startswith('unet_')}
298
+
299
+ input_img = Input(self.config.net_input_shape, name='input')
300
+
301
+ # maxpool input image to grid size
302
+ pooled = np.array([1,1])
303
+ pooled_img = input_img
304
+ while tuple(pooled) != tuple(self.config.grid):
305
+ pool = 1 + (np.asarray(self.config.grid) > pooled)
306
+ pooled *= pool
307
+ for _ in range(self.config.unet_n_conv_per_depth):
308
+ pooled_img = Conv2D(self.config.unet_n_filter_base, self.config.unet_kernel_size,
309
+ padding='same', activation=self.config.unet_activation)(pooled_img)
310
+ pooled_img = MaxPooling2D(pool)(pooled_img)
311
+
312
+ unet_base = unet_block(**unet_kwargs)(pooled_img)
313
+
314
+ if self.config.net_conv_after_unet > 0:
315
+ unet = Conv2D(self.config.net_conv_after_unet, self.config.unet_kernel_size,
316
+ name='features', padding='same', activation=self.config.unet_activation)(unet_base)
317
+ else:
318
+ unet = unet_base
319
+
320
+ output_prob = Conv2D( 1, (1,1), name='prob', padding='same', activation='sigmoid')(unet)
321
+ output_dist = Conv2D(self.config.n_rays, (1,1), name='dist', padding='same', activation='linear')(unet)
322
+
323
+ # attach extra classification head when self.n_classes is given
324
+ if self._is_multiclass():
325
+ if self.config.net_conv_after_unet > 0:
326
+ unet_class = Conv2D(self.config.net_conv_after_unet, self.config.unet_kernel_size,
327
+ name='features_class', padding='same', activation=self.config.unet_activation)(unet_base)
328
+ else:
329
+ unet_class = unet_base
330
+
331
+ output_prob_class = Conv2D(self.config.n_classes+1, (1,1), name='prob_class', padding='same', activation='softmax')(unet_class)
332
+ return Model([input_img], [output_prob,output_dist,output_prob_class])
333
+ else:
334
+ return Model([input_img], [output_prob,output_dist])
335
+
336
+
337
+ def train(self, X, Y, validation_data, classes='auto', augmenter=None, seed=None, epochs=None, steps_per_epoch=None, workers=1):
338
+ """Train the neural network with the given data.
339
+
340
+ Parameters
341
+ ----------
342
+ X : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
343
+ Input images
344
+ Y : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
345
+ Label masks
346
+ classes (optional): 'auto' or iterable of same length as X
347
+ label id -> class id mapping for each label mask of Y if multiclass prediction is activated (n_classes > 0)
348
+ list of dicts with label id -> class id (1,...,n_classes)
349
+ 'auto' -> all objects will be assigned to the first non-background class,
350
+ or will be ignored if config.n_classes is None
351
+ validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`) or triple (if multiclass)
352
+ Tuple (triple if multiclass) of X,Y,[classes] validation data.
353
+ augmenter : None or callable
354
+ Function with expected signature ``xt, yt = augmenter(x, y)``
355
+ that takes in a single pair of input/label image (x,y) and returns
356
+ the transformed images (xt, yt) for the purpose of data augmentation
357
+ during training. Not applied to validation images.
358
+ Example:
359
+ def simple_augmenter(x,y):
360
+ x = x + 0.05*np.random.normal(0,1,x.shape)
361
+ return x,y
362
+ seed : int
363
+ Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
364
+ epochs : int
365
+ Optional argument to use instead of the value from ``config``.
366
+ steps_per_epoch : int
367
+ Optional argument to use instead of the value from ``config``.
368
+
369
+ Returns
370
+ -------
371
+ ``History`` object
372
+ See `Keras training history <https://keras.io/models/model/#fit>`_.
373
+
374
+ """
375
+ if seed is not None:
376
+ # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
377
+ np.random.seed(seed)
378
+ if epochs is None:
379
+ epochs = self.config.train_epochs
380
+ if steps_per_epoch is None:
381
+ steps_per_epoch = self.config.train_steps_per_epoch
382
+
383
+ classes = self._parse_classes_arg(classes, len(X))
384
+
385
+ if not self._is_multiclass() and classes is not None:
386
+ warnings.warn("Ignoring given classes as n_classes is set to None")
387
+
388
+ isinstance(validation_data,(list,tuple)) or _raise(ValueError())
389
+ if self._is_multiclass() and len(validation_data) == 2:
390
+ validation_data = tuple(validation_data) + ('auto',)
391
+ ((len(validation_data) == (3 if self._is_multiclass() else 2))
392
+ or _raise(ValueError(f'len(validation_data) = {len(validation_data)}, but should be {3 if self._is_multiclass() else 2}')))
393
+
394
+ patch_size = self.config.train_patch_size
395
+ axes = self.config.axes.replace('C','')
396
+ b = self.config.train_completion_crop if self.config.train_shape_completion else 0
397
+ div_by = self._axes_div_by(axes)
398
+ [(p-2*b) % d == 0 or _raise(ValueError(
399
+ "'train_patch_size' - 2*'train_completion_crop' must be divisible by {d} along axis '{a}'".format(a=a,d=d) if self.config.train_shape_completion else
400
+ "'train_patch_size' must be divisible by {d} along axis '{a}'".format(a=a,d=d)
401
+ )) for p,d,a in zip(patch_size,div_by,axes)]
402
+
403
+ if not self._model_prepared:
404
+ self.prepare_for_training()
405
+
406
+ data_kwargs = dict (
407
+ n_rays = self.config.n_rays,
408
+ patch_size = self.config.train_patch_size,
409
+ grid = self.config.grid,
410
+ shape_completion = self.config.train_shape_completion,
411
+ b = self.config.train_completion_crop,
412
+ use_gpu = self.config.use_gpu,
413
+ foreground_prob = self.config.train_foreground_only,
414
+ n_classes = self.config.n_classes,
415
+ sample_ind_cache = self.config.train_sample_cache,
416
+ )
417
+
418
+ # generate validation data and store in numpy arrays
419
+ n_data_val = len(validation_data[0])
420
+ classes_val = self._parse_classes_arg(validation_data[2], n_data_val) if self._is_multiclass() else None
421
+ n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
422
+ _data_val = StarDistData2D(validation_data[0],validation_data[1], classes=classes_val, batch_size=n_take, length=1, **data_kwargs)
423
+ data_val = _data_val[0]
424
+
425
+ # expose data generator as member for general diagnostics
426
+ self.data_train = StarDistData2D(X, Y, classes=classes, batch_size=self.config.train_batch_size,
427
+ augmenter=augmenter, length=epochs*steps_per_epoch, **data_kwargs)
428
+
429
+ if self.config.train_tensorboard:
430
+ # show dist for three rays
431
+ _n = min(3, self.config.n_rays)
432
+ channel = axes_dict(self.config.axes)['C']
433
+ output_slices = [[slice(None)]*4,[slice(None)]*4]
434
+ output_slices[1][1+channel] = slice(0,(self.config.n_rays//_n)*_n, self.config.n_rays//_n)
435
+ if self._is_multiclass():
436
+ _n = min(3, self.config.n_classes)
437
+ output_slices += [[slice(None)]*4]
438
+ output_slices[2][1+channel] = slice(1,1+(self.config.n_classes//_n)*_n, self.config.n_classes//_n)
439
+
440
+ if IS_TF_1:
441
+ for cb in self.callbacks:
442
+ if isinstance(cb,CARETensorBoard):
443
+ cb.output_slices = output_slices
444
+ # target image for dist includes dist_mask and thus has more channels than dist output
445
+ cb.output_target_shapes = [None,[None]*4,None]
446
+ cb.output_target_shapes[1][1+channel] = data_val[1][1].shape[1+channel]
447
+ elif self.basedir is not None and not any(isinstance(cb,CARETensorBoardImage) for cb in self.callbacks):
448
+ self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=data_val, log_dir=str(self.logdir/'logs'/'images'),
449
+ n_images=3, prob_out=False, output_slices=output_slices))
450
+
451
+ fit = self.keras_model.fit_generator if IS_TF_1 else self.keras_model.fit
452
+ history = fit(iter(self.data_train), validation_data=data_val,
453
+ epochs=epochs, steps_per_epoch=steps_per_epoch,
454
+ workers=workers, use_multiprocessing=workers>1,
455
+ callbacks=self.callbacks, verbose=1,
456
+ # set validation batchsize to training batchsize (only works for tf >= 2.2)
457
+ **(dict(validation_batch_size = self.config.train_batch_size) if _tf_version_at_least("2.2.0") else {}))
458
+ self._training_finished()
459
+
460
+ return history
461
+
462
+
463
+ # def _instances_from_prediction_old(self, img_shape, prob, dist,points = None, prob_class = None, prob_thresh=None, nms_thresh=None, overlap_label = None, **nms_kwargs):
464
+ # from stardist.geometry.geom2d import _polygons_to_label_old, _dist_to_coord_old
465
+ # from stardist.nms import _non_maximum_suppression_old
466
+
467
+ # if prob_thresh is None: prob_thresh = self.thresholds.prob
468
+ # if nms_thresh is None: nms_thresh = self.thresholds.nms
469
+ # if overlap_label is not None: raise NotImplementedError("overlap_label not supported for 2D yet!")
470
+
471
+ # coord = _dist_to_coord_old(dist, grid=self.config.grid)
472
+ # inds = _non_maximum_suppression_old(coord, prob, grid=self.config.grid,
473
+ # prob_thresh=prob_thresh, nms_thresh=nms_thresh, **nms_kwargs)
474
+ # labels = _polygons_to_label_old(coord, prob, inds, shape=img_shape)
475
+ # # sort 'inds' such that ids in 'labels' map to entries in polygon dictionary entries
476
+ # inds = inds[np.argsort(prob[inds[:,0],inds[:,1]])]
477
+ # # adjust for grid
478
+ # points = inds*np.array(self.config.grid)
479
+
480
+ # res_dict = dict(coord=coord[inds[:,0],inds[:,1]], points=points, prob=prob[inds[:,0],inds[:,1]])
481
+
482
+ # if prob_class is not None:
483
+ # prob_class = np.asarray(prob_class)
484
+ # res_dict.update(dict(class_prob = prob_class))
485
+
486
+ # return labels, res_dict
487
+
488
+
489
+ def _instances_from_prediction(self, img_shape, prob, dist, points=None, prob_class=None, prob_thresh=None, nms_thresh=None, overlap_label=None, return_labels=True, scale=None, **nms_kwargs):
490
+ """
491
+ if points is None -> dense prediction
492
+ if points is not None -> sparse prediction
493
+
494
+ if prob_class is None -> single class prediction
495
+ if prob_class is not None -> multi class prediction
496
+ """
497
+ if prob_thresh is None: prob_thresh = self.thresholds.prob
498
+ if nms_thresh is None: nms_thresh = self.thresholds.nms
499
+ if overlap_label is not None: raise NotImplementedError("overlap_label not supported for 2D yet!")
500
+
501
+ # sparse prediction
502
+ if points is not None:
503
+ points, probi, disti, indsi = non_maximum_suppression_sparse(dist, prob, points, nms_thresh=nms_thresh, **nms_kwargs)
504
+ if prob_class is not None:
505
+ prob_class = prob_class[indsi]
506
+
507
+ # dense prediction
508
+ else:
509
+ points, probi, disti = non_maximum_suppression(dist, prob, grid=self.config.grid,
510
+ prob_thresh=prob_thresh, nms_thresh=nms_thresh, **nms_kwargs)
511
+ if prob_class is not None:
512
+ inds = tuple(p//g for p,g in zip(points.T, self.config.grid))
513
+ prob_class = prob_class[inds]
514
+
515
+ if scale is not None:
516
+ # need to undo the scaling given by the scale dict, e.g. scale = dict(X=0.5,Y=0.5):
517
+ # 1. re-scale points (origins of polygons)
518
+ # 2. re-scale coordinates (computed from distances) of (zero-origin) polygons
519
+ if not (isinstance(scale,dict) and 'X' in scale and 'Y' in scale):
520
+ raise ValueError("scale must be a dictionary with entries for 'X' and 'Y'")
521
+ rescale = (1/scale['Y'],1/scale['X'])
522
+ points = points * np.array(rescale).reshape(1,2)
523
+ else:
524
+ rescale = (1,1)
525
+
526
+ if return_labels:
527
+ labels = polygons_to_label(disti, points, prob=probi, shape=img_shape, scale_dist=rescale)
528
+ else:
529
+ labels = None
530
+
531
+ coord = dist_to_coord(disti, points, scale_dist=rescale)
532
+ res_dict = dict(coord=coord, points=points, prob=probi)
533
+
534
+ # multi class prediction
535
+ if prob_class is not None:
536
+ prob_class = np.asarray(prob_class)
537
+ class_id = np.argmax(prob_class, axis=-1)
538
+ res_dict.update(dict(class_prob=prob_class, class_id=class_id))
539
+
540
+ return labels, res_dict
541
+
542
+
543
+ def _axes_div_by(self, query_axes):
544
+ self.config.backbone == 'unet' or _raise(NotImplementedError())
545
+ query_axes = axes_check_and_normalize(query_axes)
546
+ assert len(self.config.unet_pool) == len(self.config.grid)
547
+ div_by = dict(zip(
548
+ self.config.axes.replace('C',''),
549
+ tuple(p**self.config.unet_n_depth * g for p,g in zip(self.config.unet_pool,self.config.grid))
550
+ ))
551
+ return tuple(div_by.get(a,1) for a in query_axes)
552
+
553
+
554
+ # def _axes_tile_overlap(self, query_axes):
555
+ # self.config.backbone == 'unet' or _raise(NotImplementedError())
556
+ # query_axes = axes_check_and_normalize(query_axes)
557
+ # assert len(self.config.unet_pool) == len(self.config.grid) == len(self.config.unet_kernel_size)
558
+ # # TODO: compute this properly when any value of grid > 1
559
+ # # all(g==1 for g in self.config.grid) or warnings.warn('FIXME')
560
+ # overlap = dict(zip(
561
+ # self.config.axes.replace('C',''),
562
+ # tuple(tile_overlap(self.config.unet_n_depth + int(np.log2(g)), k, p)
563
+ # for p,k,g in zip(self.config.unet_pool,self.config.unet_kernel_size,self.config.grid))
564
+ # ))
565
+ # return tuple(overlap.get(a,0) for a in query_axes)
566
+
567
+
568
+ @property
569
+ def _config_class(self):
570
+ return Config2D
stardist_pkg/nms.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+ import numpy as np
3
+ from time import time
4
+ from .utils import _normalize_grid
5
+
6
+ def _ind_prob_thresh(prob, prob_thresh, b=2):
7
+ if b is not None and np.isscalar(b):
8
+ b = ((b,b),)*prob.ndim
9
+
10
+ ind_thresh = prob > prob_thresh
11
+ if b is not None:
12
+ _ind_thresh = np.zeros_like(ind_thresh)
13
+ ss = tuple(slice(_bs[0] if _bs[0]>0 else None,
14
+ -_bs[1] if _bs[1]>0 else None) for _bs in b)
15
+ _ind_thresh[ss] = True
16
+ ind_thresh &= _ind_thresh
17
+ return ind_thresh
18
+
19
+
20
+ def _non_maximum_suppression_old(coord, prob, grid=(1,1), b=2, nms_thresh=0.5, prob_thresh=0.5, verbose=False, max_bbox_search=True):
21
+ """2D coordinates of the polys that survive from a given prediction (prob, coord)
22
+
23
+ prob.shape = (Ny,Nx)
24
+ coord.shape = (Ny,Nx,2,n_rays)
25
+
26
+ b: don't use pixel closer than b pixels to the image boundary
27
+
28
+ returns retained points
29
+ """
30
+ from .lib.stardist2d import c_non_max_suppression_inds_old
31
+
32
+ # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary
33
+
34
+ assert prob.ndim == 2
35
+ assert coord.ndim == 4
36
+ grid = _normalize_grid(grid,2)
37
+
38
+ # mask = prob > prob_thresh
39
+ # if b is not None and b > 0:
40
+ # _mask = np.zeros_like(mask)
41
+ # _mask[b:-b,b:-b] = True
42
+ # mask &= _mask
43
+
44
+ mask = _ind_prob_thresh(prob, prob_thresh, b)
45
+
46
+ polygons = coord[mask]
47
+ scores = prob[mask]
48
+
49
+ # sort scores descendingly
50
+ ind = np.argsort(scores)[::-1]
51
+ survivors = np.zeros(len(ind), bool)
52
+ polygons = polygons[ind]
53
+ scores = scores[ind]
54
+
55
+ if max_bbox_search:
56
+ # map pixel indices to ids of sorted polygons (-1 => polygon at that pixel not a candidate)
57
+ mapping = -np.ones(mask.shape,np.int32)
58
+ mapping.flat[ np.flatnonzero(mask)[ind] ] = range(len(ind))
59
+ else:
60
+ mapping = np.empty((0,0),np.int32)
61
+
62
+ if verbose:
63
+ t = time()
64
+
65
+ survivors[ind] = c_non_max_suppression_inds_old(np.ascontiguousarray(polygons.astype(np.int32)),
66
+ mapping, np.float32(nms_thresh), np.int32(max_bbox_search),
67
+ np.int32(grid[0]), np.int32(grid[1]),np.int32(verbose))
68
+
69
+ if verbose:
70
+ print("keeping %s/%s polygons" % (np.count_nonzero(survivors), len(polygons)))
71
+ print("NMS took %.4f s" % (time() - t))
72
+
73
+ points = np.stack([ii[survivors] for ii in np.nonzero(mask)],axis=-1)
74
+ return points
75
+
76
+
77
+ def non_maximum_suppression(dist, prob, grid=(1,1), b=2, nms_thresh=0.5, prob_thresh=0.5,
78
+ use_bbox=True, use_kdtree=True, verbose=False,cut=False):
79
+ """Non-Maximum-Supression of 2D polygons
80
+
81
+ Retains only polygons whose overlap is smaller than nms_thresh
82
+
83
+ dist.shape = (Ny,Nx, n_rays)
84
+ prob.shape = (Ny,Nx)
85
+
86
+ returns the retained points, probabilities, and distances:
87
+
88
+ points, prob, dist = non_maximum_suppression(dist, prob, ....
89
+
90
+ """
91
+
92
+ # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary
93
+
94
+ assert prob.ndim == 2 and dist.ndim == 3 and prob.shape == dist.shape[:2]
95
+ dist = np.asarray(dist)
96
+ prob = np.asarray(prob)
97
+ n_rays = dist.shape[-1]
98
+
99
+ grid = _normalize_grid(grid,2)
100
+
101
+ # mask = prob > prob_thresh
102
+ # if b is not None and b > 0:
103
+ # _mask = np.zeros_like(mask)
104
+ # _mask[b:-b,b:-b] = True
105
+ # mask &= _mask
106
+
107
+ mask = _ind_prob_thresh(prob, prob_thresh, b)
108
+ points = np.stack(np.where(mask), axis=1)
109
+
110
+ dist = dist[mask]
111
+ scores = prob[mask]
112
+
113
+ # sort scores descendingly
114
+ ind = np.argsort(scores)[::-1]
115
+ if cut is True and ind.shape[0] > 20000:
116
+ #if cut is True and :
117
+ ind = ind[:round(ind.shape[0]*0.5)]
118
+ dist = dist[ind]
119
+ scores = scores[ind]
120
+ points = points[ind]
121
+
122
+ points = (points * np.array(grid).reshape((1,2)))
123
+
124
+ if verbose:
125
+ t = time()
126
+
127
+ inds = non_maximum_suppression_inds(dist, points.astype(np.int32, copy=False), scores=scores,
128
+ use_bbox=use_bbox, use_kdtree=use_kdtree,
129
+ thresh=nms_thresh, verbose=verbose)
130
+
131
+ if verbose:
132
+ print("keeping %s/%s polygons" % (np.count_nonzero(inds), len(inds)))
133
+ print("NMS took %.4f s" % (time() - t))
134
+
135
+ return points[inds], scores[inds], dist[inds]
136
+
137
+
138
+ def non_maximum_suppression_sparse(dist, prob, points, b=2, nms_thresh=0.5,
139
+ use_bbox=True, use_kdtree = True, verbose=False):
140
+ """Non-Maximum-Supression of 2D polygons from a list of dists, probs (scores), and points
141
+
142
+ Retains only polyhedra whose overlap is smaller than nms_thresh
143
+
144
+ dist.shape = (n_polys, n_rays)
145
+ prob.shape = (n_polys,)
146
+ points.shape = (n_polys,2)
147
+
148
+ returns the retained instances
149
+
150
+ (pointsi, probi, disti, indsi)
151
+
152
+ with
153
+ pointsi = points[indsi] ...
154
+
155
+ """
156
+
157
+ # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary
158
+
159
+ dist = np.asarray(dist)
160
+ prob = np.asarray(prob)
161
+ points = np.asarray(points)
162
+ n_rays = dist.shape[-1]
163
+
164
+ assert dist.ndim == 2 and prob.ndim == 1 and points.ndim == 2 and \
165
+ points.shape[-1]==2 and len(prob) == len(dist) == len(points)
166
+
167
+ verbose and print("predicting instances with nms_thresh = {nms_thresh}".format(nms_thresh=nms_thresh), flush=True)
168
+
169
+ inds_original = np.arange(len(prob))
170
+ _sorted = np.argsort(prob)[::-1]
171
+ probi = prob[_sorted]
172
+ disti = dist[_sorted]
173
+ pointsi = points[_sorted]
174
+ inds_original = inds_original[_sorted]
175
+
176
+ if verbose:
177
+ print("non-maximum suppression...")
178
+ t = time()
179
+
180
+ inds = non_maximum_suppression_inds(disti, pointsi, scores=probi, thresh=nms_thresh, use_kdtree = use_kdtree, verbose=verbose)
181
+
182
+ if verbose:
183
+ print("keeping %s/%s polyhedra" % (np.count_nonzero(inds), len(inds)))
184
+ print("NMS took %.4f s" % (time() - t))
185
+
186
+ return pointsi[inds], probi[inds], disti[inds], inds_original[inds]
187
+
188
+
189
+ def non_maximum_suppression_inds(dist, points, scores, thresh=0.5, use_bbox=True, use_kdtree = True, verbose=1):
190
+ """
191
+ Applies non maximum supression to ray-convex polygons given by dists and points
192
+ sorted by scores and IoU threshold
193
+
194
+ P1 will suppress P2, if IoU(P1,P2) > thresh
195
+
196
+ with IoU(P1,P2) = Ainter(P1,P2) / min(A(P1),A(P2))
197
+
198
+ i.e. the smaller thresh, the more polygons will be supressed
199
+
200
+ dist.shape = (n_poly, n_rays)
201
+ point.shape = (n_poly, 2)
202
+ score.shape = (n_poly,)
203
+
204
+ returns indices of selected polygons
205
+ """
206
+
207
+ from stardist.lib.stardist2d import c_non_max_suppression_inds
208
+
209
+ assert dist.ndim == 2
210
+ assert points.ndim == 2
211
+
212
+ n_poly = dist.shape[0]
213
+
214
+ if scores is None:
215
+ scores = np.ones(n_poly)
216
+
217
+ assert len(scores) == n_poly
218
+ assert points.shape[0] == n_poly
219
+
220
+ def _prep(x, dtype):
221
+ return np.ascontiguousarray(x.astype(dtype, copy=False))
222
+
223
+ inds = c_non_max_suppression_inds(_prep(dist, np.float32),
224
+ _prep(points, np.float32),
225
+ int(use_kdtree),
226
+ int(use_bbox),
227
+ int(verbose),
228
+ np.float32(thresh))
229
+
230
+ return inds
231
+
232
+
233
+ #########
234
+
235
+
236
+ def non_maximum_suppression_3d(dist, prob, rays, grid=(1,1,1), b=2, nms_thresh=0.5, prob_thresh=0.5, use_bbox=True, use_kdtree=True, verbose=False):
237
+ """Non-Maximum-Supression of 3D polyhedra
238
+
239
+ Retains only polyhedra whose overlap is smaller than nms_thresh
240
+
241
+ dist.shape = (Nz,Ny,Nx, n_rays)
242
+ prob.shape = (Nz,Ny,Nx)
243
+
244
+ returns the retained points, probabilities, and distances:
245
+
246
+ points, prob, dist = non_maximum_suppression_3d(dist, prob, ....
247
+ """
248
+
249
+ # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary
250
+
251
+ dist = np.asarray(dist)
252
+ prob = np.asarray(prob)
253
+
254
+ assert prob.ndim == 3 and dist.ndim == 4 and dist.shape[-1] == len(rays) and prob.shape == dist.shape[:3]
255
+
256
+ grid = _normalize_grid(grid,3)
257
+
258
+ verbose and print("predicting instances with prob_thresh = {prob_thresh} and nms_thresh = {nms_thresh}".format(prob_thresh=prob_thresh, nms_thresh=nms_thresh), flush=True)
259
+
260
+ # ind_thresh = prob > prob_thresh
261
+ # if b is not None and b > 0:
262
+ # _ind_thresh = np.zeros_like(ind_thresh)
263
+ # _ind_thresh[b:-b,b:-b,b:-b] = True
264
+ # ind_thresh &= _ind_thresh
265
+
266
+ ind_thresh = _ind_prob_thresh(prob, prob_thresh, b)
267
+ points = np.stack(np.where(ind_thresh), axis=1)
268
+ verbose and print("found %s candidates"%len(points))
269
+ probi = prob[ind_thresh]
270
+ disti = dist[ind_thresh]
271
+
272
+ _sorted = np.argsort(probi)[::-1]
273
+ probi = probi[_sorted]
274
+ disti = disti[_sorted]
275
+ points = points[_sorted]
276
+
277
+ verbose and print("non-maximum suppression...")
278
+ points = (points * np.array(grid).reshape((1,3)))
279
+
280
+ inds = non_maximum_suppression_3d_inds(disti, points, rays=rays, scores=probi, thresh=nms_thresh,
281
+ use_bbox=use_bbox, use_kdtree = use_kdtree,
282
+ verbose=verbose)
283
+
284
+ verbose and print("keeping %s/%s polyhedra" % (np.count_nonzero(inds), len(inds)))
285
+ return points[inds], probi[inds], disti[inds]
286
+
287
+
288
+ def non_maximum_suppression_3d_sparse(dist, prob, points, rays, b=2, nms_thresh=0.5, use_kdtree = True, verbose=False):
289
+ """Non-Maximum-Supression of 3D polyhedra from a list of dists, probs and points
290
+
291
+ Retains only polyhedra whose overlap is smaller than nms_thresh
292
+ dist.shape = (n_polys, n_rays)
293
+ prob.shape = (n_polys,)
294
+ points.shape = (n_polys,3)
295
+
296
+ returns the retained instances
297
+
298
+ (pointsi, probi, disti, indsi)
299
+
300
+ with
301
+ pointsi = points[indsi] ...
302
+ """
303
+
304
+ # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary
305
+
306
+ dist = np.asarray(dist)
307
+ prob = np.asarray(prob)
308
+ points = np.asarray(points)
309
+
310
+ assert dist.ndim == 2 and prob.ndim == 1 and points.ndim == 2 and \
311
+ dist.shape[-1] == len(rays) and points.shape[-1]==3 and len(prob) == len(dist) == len(points)
312
+
313
+ verbose and print("predicting instances with nms_thresh = {nms_thresh}".format(nms_thresh=nms_thresh), flush=True)
314
+
315
+ inds_original = np.arange(len(prob))
316
+ _sorted = np.argsort(prob)[::-1]
317
+ probi = prob[_sorted]
318
+ disti = dist[_sorted]
319
+ pointsi = points[_sorted]
320
+ inds_original = inds_original[_sorted]
321
+
322
+ verbose and print("non-maximum suppression...")
323
+
324
+ inds = non_maximum_suppression_3d_inds(disti, pointsi, rays=rays, scores=probi, thresh=nms_thresh, use_kdtree = use_kdtree, verbose=verbose)
325
+
326
+ verbose and print("keeping %s/%s polyhedra" % (np.count_nonzero(inds), len(inds)))
327
+ return pointsi[inds], probi[inds], disti[inds], inds_original[inds]
328
+
329
+
330
+ def non_maximum_suppression_3d_inds(dist, points, rays, scores, thresh=0.5, use_bbox=True, use_kdtree = True, verbose=1):
331
+ """
332
+ Applies non maximum supression to ray-convex polyhedra given by dists and rays
333
+ sorted by scores and IoU threshold
334
+
335
+ P1 will suppress P2, if IoU(P1,P2) > thresh
336
+
337
+ with IoU(P1,P2) = Ainter(P1,P2) / min(A(P1),A(P2))
338
+
339
+ i.e. the smaller thresh, the more polygons will be supressed
340
+
341
+ dist.shape = (n_poly, n_rays)
342
+ point.shape = (n_poly, 3)
343
+ score.shape = (n_poly,)
344
+
345
+ returns indices of selected polygons
346
+ """
347
+ from .lib.stardist3d import c_non_max_suppression_inds
348
+
349
+ assert dist.ndim == 2
350
+ assert points.ndim == 2
351
+ assert dist.shape[1] == len(rays)
352
+
353
+ n_poly = dist.shape[0]
354
+
355
+ if scores is None:
356
+ scores = np.ones(n_poly)
357
+
358
+ assert len(scores) == n_poly
359
+ assert points.shape[0] == n_poly
360
+
361
+ # sort scores descendingly
362
+ ind = np.argsort(scores)[::-1]
363
+ survivors = np.ones(n_poly, bool)
364
+ dist = dist[ind]
365
+ points = points[ind]
366
+ scores = scores[ind]
367
+
368
+ def _prep(x, dtype):
369
+ return np.ascontiguousarray(x.astype(dtype, copy=False))
370
+
371
+ if verbose:
372
+ t = time()
373
+
374
+ survivors[ind] = c_non_max_suppression_inds(_prep(dist, np.float32),
375
+ _prep(points, np.float32),
376
+ _prep(rays.vertices, np.float32),
377
+ _prep(rays.faces, np.int32),
378
+ _prep(scores, np.float32),
379
+ int(use_bbox),
380
+ int(use_kdtree),
381
+ int(verbose),
382
+ np.float32(thresh))
383
+
384
+ if verbose:
385
+ print("NMS took %.4f s" % (time() - t))
386
+
387
+ return survivors
stardist_pkg/rays3d.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ray factory
3
+
4
+ classes that provide vertex and triangle information for rays on spheres
5
+
6
+ Example:
7
+
8
+ rays = Rays_Tetra(n_level = 4)
9
+
10
+ print(rays.vertices)
11
+ print(rays.faces)
12
+
13
+ """
14
+ from __future__ import print_function, unicode_literals, absolute_import, division
15
+ import numpy as np
16
+ from scipy.spatial import ConvexHull
17
+ import copy
18
+ import warnings
19
+
20
+ class Rays_Base(object):
21
+ def __init__(self, **kwargs):
22
+ self.kwargs = kwargs
23
+ self._vertices, self._faces = self.setup_vertices_faces()
24
+ self._vertices = np.asarray(self._vertices, np.float32)
25
+ self._faces = np.asarray(self._faces, int)
26
+ self._faces = np.asanyarray(self._faces)
27
+
28
+ def setup_vertices_faces(self):
29
+ """has to return
30
+
31
+ verts , faces
32
+
33
+ verts = ( (z_1,y_1,x_1), ... )
34
+ faces ( (0,1,2), (2,3,4), ... )
35
+
36
+ """
37
+ raise NotImplementedError()
38
+
39
+ @property
40
+ def vertices(self):
41
+ """read-only property"""
42
+ return self._vertices.copy()
43
+
44
+ @property
45
+ def faces(self):
46
+ """read-only property"""
47
+ return self._faces.copy()
48
+
49
+ def __getitem__(self, i):
50
+ return self.vertices[i]
51
+
52
+ def __len__(self):
53
+ return len(self._vertices)
54
+
55
+ def __repr__(self):
56
+ def _conv(x):
57
+ if isinstance(x,(tuple, list, np.ndarray)):
58
+ return "_".join(_conv(_x) for _x in x)
59
+ if isinstance(x,float):
60
+ return "%.2f"%x
61
+ return str(x)
62
+ return "%s_%s" % (self.__class__.__name__, "_".join("%s_%s" % (k, _conv(v)) for k, v in sorted(self.kwargs.items())))
63
+
64
+ def to_json(self):
65
+ return {
66
+ "name": self.__class__.__name__,
67
+ "kwargs": self.kwargs
68
+ }
69
+
70
+ def dist_loss_weights(self, anisotropy = (1,1,1)):
71
+ """returns the anisotropy corrected weights for each ray"""
72
+ anisotropy = np.array(anisotropy)
73
+ assert anisotropy.shape == (3,)
74
+ return np.linalg.norm(self.vertices*anisotropy, axis = -1)
75
+
76
+ def volume(self, dist=None):
77
+ """volume of the starconvex polyhedron spanned by dist (if None, uses dist=1)
78
+ dist can be a nD array, but the last dimension has to be of length n_rays
79
+ """
80
+ if dist is None: dist = np.ones_like(self.vertices)
81
+
82
+ dist = np.asarray(dist)
83
+
84
+ if not dist.shape[-1]==len(self.vertices):
85
+ raise ValueError("last dimension of dist should have length len(rays.vertices)")
86
+ # all the shuffling below is to allow dist to be an arbitrary sized array (with last dim n_rays)
87
+ # self.vertices -> (n_rays,3)
88
+ # dist -> (m,n,..., n_rays)
89
+
90
+ # dist -> (m,n,..., n_rays, 3)
91
+ dist = np.repeat(np.expand_dims(dist,-1), 3, axis = -1)
92
+ # verts -> (m,n,..., n_rays, 3)
93
+ verts = np.broadcast_to(self.vertices, dist.shape)
94
+
95
+ # dist, verts -> (n_rays, m,n, ..., 3)
96
+ dist = np.moveaxis(dist,-2,0)
97
+ verts = np.moveaxis(verts,-2,0)
98
+
99
+ # vs -> (n_faces, 3, m, n, ..., 3)
100
+ vs = (dist*verts)[self.faces]
101
+ # vs -> (n_faces, m, n, ..., 3, 3)
102
+ vs = np.moveaxis(vs, 1,-2)
103
+ # vs -> (n_faces * m * n, 3, 3)
104
+ vs = vs.reshape((len(self.faces)*int(np.prod(dist.shape[1:-1])),3,3))
105
+ d = np.linalg.det(list(vs)).reshape((len(self.faces),)+dist.shape[1:-1])
106
+
107
+ return -1./6*np.sum(d, axis = 0)
108
+
109
+ def surface(self, dist=None):
110
+ """surface area of the starconvex polyhedron spanned by dist (if None, uses dist=1)"""
111
+ dist = np.asarray(dist)
112
+
113
+ if not dist.shape[-1]==len(self.vertices):
114
+ raise ValueError("last dimension of dist should have length len(rays.vertices)")
115
+
116
+ # self.vertices -> (n_rays,3)
117
+ # dist -> (m,n,..., n_rays)
118
+
119
+ # all the shuffling below is to allow dist to be an arbitrary sized array (with last dim n_rays)
120
+
121
+ # dist -> (m,n,..., n_rays, 3)
122
+ dist = np.repeat(np.expand_dims(dist,-1), 3, axis = -1)
123
+ # verts -> (m,n,..., n_rays, 3)
124
+ verts = np.broadcast_to(self.vertices, dist.shape)
125
+
126
+ # dist, verts -> (n_rays, m,n, ..., 3)
127
+ dist = np.moveaxis(dist,-2,0)
128
+ verts = np.moveaxis(verts,-2,0)
129
+
130
+ # vs -> (n_faces, 3, m, n, ..., 3)
131
+ vs = (dist*verts)[self.faces]
132
+ # vs -> (n_faces, m, n, ..., 3, 3)
133
+ vs = np.moveaxis(vs, 1,-2)
134
+ # vs -> (n_faces * m * n, 3, 3)
135
+ vs = vs.reshape((len(self.faces)*int(np.prod(dist.shape[1:-1])),3,3))
136
+
137
+ pa = vs[...,1,:]-vs[...,0,:]
138
+ pb = vs[...,2,:]-vs[...,0,:]
139
+
140
+ d = .5*np.linalg.norm(np.cross(list(pa), list(pb)), axis = -1)
141
+ d = d.reshape((len(self.faces),)+dist.shape[1:-1])
142
+ return np.sum(d, axis = 0)
143
+
144
+
145
+ def copy(self, scale=(1,1,1)):
146
+ """ returns a copy whose vertices are scaled by given factor"""
147
+ scale = np.asarray(scale)
148
+ assert scale.shape == (3,)
149
+ res = copy.deepcopy(self)
150
+ res._vertices *= scale[np.newaxis]
151
+ return res
152
+
153
+
154
+
155
+
156
+ def rays_from_json(d):
157
+ return eval(d["name"])(**d["kwargs"])
158
+
159
+
160
+ ################################################################
161
+
162
+ class Rays_Explicit(Rays_Base):
163
+ def __init__(self, vertices0, faces0):
164
+ self.vertices0, self.faces0 = vertices0, faces0
165
+ super().__init__(vertices0=list(vertices0), faces0=list(faces0))
166
+
167
+ def setup_vertices_faces(self):
168
+ return self.vertices0, self.faces0
169
+
170
+
171
+ class Rays_Cartesian(Rays_Base):
172
+ def __init__(self, n_rays_x=11, n_rays_z=5):
173
+ super().__init__(n_rays_x=n_rays_x, n_rays_z=n_rays_z)
174
+
175
+ def setup_vertices_faces(self):
176
+ """has to return list of ( (z_1,y_1,x_1), ... ) _"""
177
+ n_rays_x, n_rays_z = self.kwargs["n_rays_x"], self.kwargs["n_rays_z"]
178
+ dphi = np.float32(2. * np.pi / n_rays_x)
179
+ dtheta = np.float32(np.pi / n_rays_z)
180
+
181
+ verts = []
182
+ for mz in range(n_rays_z):
183
+ for mx in range(n_rays_x):
184
+ phi = mx * dphi
185
+ theta = mz * dtheta
186
+ if mz == 0:
187
+ theta = 1e-12
188
+ if mz == n_rays_z - 1:
189
+ theta = np.pi - 1e-12
190
+ dx = np.cos(phi) * np.sin(theta)
191
+ dy = np.sin(phi) * np.sin(theta)
192
+ dz = np.cos(theta)
193
+ if mz == 0 or mz == n_rays_z - 1:
194
+ dx += 1e-12
195
+ dy += 1e-12
196
+ verts.append([dz, dy, dx])
197
+
198
+ verts = np.array(verts)
199
+
200
+ def _ind(mz, mx):
201
+ return mz * n_rays_x + mx
202
+
203
+ faces = []
204
+
205
+ for mz in range(n_rays_z - 1):
206
+ for mx in range(n_rays_x):
207
+ faces.append([_ind(mz, mx), _ind(mz + 1, (mx + 1) % n_rays_x), _ind(mz, (mx + 1) % n_rays_x)])
208
+ faces.append([_ind(mz, mx), _ind(mz + 1, mx), _ind(mz + 1, (mx + 1) % n_rays_x)])
209
+
210
+ faces = np.array(faces)
211
+
212
+ return verts, faces
213
+
214
+
215
+ class Rays_SubDivide(Rays_Base):
216
+ """
217
+ Subdivision polyehdra
218
+
219
+ n_level = 1 -> base polyhedra
220
+ n_level = 2 -> 1x subdivision
221
+ n_level = 3 -> 2x subdivision
222
+ ...
223
+ """
224
+
225
+ def __init__(self, n_level=4):
226
+ super().__init__(n_level=n_level)
227
+
228
+ def base_polyhedron(self):
229
+ raise NotImplementedError()
230
+
231
+ def setup_vertices_faces(self):
232
+ n_level = self.kwargs["n_level"]
233
+ verts0, faces0 = self.base_polyhedron()
234
+ return self._recursive_split(verts0, faces0, n_level)
235
+
236
+ def _recursive_split(self, verts, faces, n_level):
237
+ if n_level <= 1:
238
+ return verts, faces
239
+ else:
240
+ verts, faces = Rays_SubDivide.split(verts, faces)
241
+ return self._recursive_split(verts, faces, n_level - 1)
242
+
243
+ @classmethod
244
+ def split(self, verts0, faces0):
245
+ """split a level"""
246
+
247
+ split_edges = dict()
248
+ verts = list(verts0[:])
249
+ faces = []
250
+
251
+ def _add(a, b):
252
+ """ returns index of middle point and adds vertex if not already added"""
253
+ edge = tuple(sorted((a, b)))
254
+ if not edge in split_edges:
255
+ v = .5 * (verts[a] + verts[b])
256
+ v *= 1. / np.linalg.norm(v)
257
+ verts.append(v)
258
+ split_edges[edge] = len(verts) - 1
259
+ return split_edges[edge]
260
+
261
+ for v1, v2, v3 in faces0:
262
+ ind1 = _add(v1, v2)
263
+ ind2 = _add(v2, v3)
264
+ ind3 = _add(v3, v1)
265
+ faces.append([v1, ind1, ind3])
266
+ faces.append([v2, ind2, ind1])
267
+ faces.append([v3, ind3, ind2])
268
+ faces.append([ind1, ind2, ind3])
269
+
270
+ return verts, faces
271
+
272
+
273
+ class Rays_Tetra(Rays_SubDivide):
274
+ """
275
+ Subdivision of a tetrahedron
276
+
277
+ n_level = 1 -> normal tetrahedron (4 vertices)
278
+ n_level = 2 -> 1x subdivision (10 vertices)
279
+ n_level = 3 -> 2x subdivision (34 vertices)
280
+ ...
281
+ """
282
+
283
+ def base_polyhedron(self):
284
+ verts = np.array([
285
+ [np.sqrt(8. / 9), 0., -1. / 3],
286
+ [-np.sqrt(2. / 9), np.sqrt(2. / 3), -1. / 3],
287
+ [-np.sqrt(2. / 9), -np.sqrt(2. / 3), -1. / 3],
288
+ [0., 0., 1.]
289
+ ])
290
+ faces = [[0, 1, 2],
291
+ [0, 3, 1],
292
+ [0, 2, 3],
293
+ [1, 3, 2]]
294
+
295
+ return verts, faces
296
+
297
+
298
+ class Rays_Octo(Rays_SubDivide):
299
+ """
300
+ Subdivision of a tetrahedron
301
+
302
+ n_level = 1 -> normal Octahedron (6 vertices)
303
+ n_level = 2 -> 1x subdivision (18 vertices)
304
+ n_level = 3 -> 2x subdivision (66 vertices)
305
+
306
+ """
307
+
308
+ def base_polyhedron(self):
309
+ verts = np.array([
310
+ [0, 0, 1],
311
+ [0, 1, 0],
312
+ [0, 0, -1],
313
+ [0, -1, 0],
314
+ [1, 0, 0],
315
+ [-1, 0, 0]])
316
+
317
+ faces = [[0, 1, 4],
318
+ [0, 5, 1],
319
+ [1, 2, 4],
320
+ [1, 5, 2],
321
+ [2, 3, 4],
322
+ [2, 5, 3],
323
+ [3, 0, 4],
324
+ [3, 5, 0],
325
+ ]
326
+
327
+ return verts, faces
328
+
329
+
330
+ def reorder_faces(verts, faces):
331
+ """reorder faces such that their orientation points outward"""
332
+ def _single(face):
333
+ return face[::-1] if np.linalg.det(verts[face])>0 else face
334
+ return tuple(map(_single, faces))
335
+
336
+
337
+ class Rays_GoldenSpiral(Rays_Base):
338
+ def __init__(self, n=70, anisotropy = None):
339
+ if n<4:
340
+ raise ValueError("At least 4 points have to be given!")
341
+ super().__init__(n=n, anisotropy = anisotropy if anisotropy is None else tuple(anisotropy))
342
+
343
+ def setup_vertices_faces(self):
344
+ n = self.kwargs["n"]
345
+ anisotropy = self.kwargs["anisotropy"]
346
+ if anisotropy is None:
347
+ anisotropy = np.ones(3)
348
+ else:
349
+ anisotropy = np.array(anisotropy)
350
+
351
+ # the smaller golden angle = 2pi * 0.3819...
352
+ g = (3. - np.sqrt(5.)) * np.pi
353
+ phi = g * np.arange(n)
354
+ # z = np.linspace(-1, 1, n + 2)[1:-1]
355
+ # rho = np.sqrt(1. - z ** 2)
356
+ # verts = np.stack([rho*np.cos(phi), rho*np.sin(phi),z]).T
357
+ #
358
+ z = np.linspace(-1, 1, n)
359
+ rho = np.sqrt(1. - z ** 2)
360
+ verts = np.stack([z, rho * np.sin(phi), rho * np.cos(phi)]).T
361
+
362
+ # warnings.warn("ray definition has changed! Old results are invalid!")
363
+
364
+ # correct for anisotropy
365
+ verts = verts/anisotropy
366
+ #verts /= np.linalg.norm(verts, axis=-1, keepdims=True)
367
+
368
+ hull = ConvexHull(verts)
369
+ faces = reorder_faces(verts,hull.simplices)
370
+
371
+ verts /= np.linalg.norm(verts, axis=-1, keepdims=True)
372
+
373
+ return verts, faces
stardist_pkg/sample_patches.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """provides a faster sampling function"""
2
+
3
+ import numpy as np
4
+ from csbdeep.utils import _raise, choice
5
+
6
+
7
+ def sample_patches(datas, patch_size, n_samples, valid_inds=None, verbose=False):
8
+ """optimized version of csbdeep.data.sample_patches_from_multiple_stacks
9
+ """
10
+
11
+ len(patch_size)==datas[0].ndim or _raise(ValueError())
12
+
13
+ if not all(( a.shape == datas[0].shape for a in datas )):
14
+ raise ValueError("all input shapes must be the same: %s" % (" / ".join(str(a.shape) for a in datas)))
15
+
16
+ if not all(( 0 < s <= d for s,d in zip(patch_size,datas[0].shape) )):
17
+ raise ValueError("patch_size %s negative or larger than data shape %s along some dimensions" % (str(patch_size), str(datas[0].shape)))
18
+
19
+ if valid_inds is None:
20
+ valid_inds = tuple(_s.ravel() for _s in np.meshgrid(*tuple(np.arange(p//2,s-p//2+1) for s,p in zip(datas[0].shape, patch_size))))
21
+
22
+ n_valid = len(valid_inds[0])
23
+
24
+ if n_valid == 0:
25
+ raise ValueError("no regions to sample from!")
26
+
27
+ idx = choice(range(n_valid), n_samples, replace=(n_valid < n_samples))
28
+ rand_inds = [v[idx] for v in valid_inds]
29
+ res = [np.stack([data[tuple(slice(_r-(_p//2),_r+_p-(_p//2)) for _r,_p in zip(r,patch_size))] for r in zip(*rand_inds)]) for data in datas]
30
+
31
+ return res
32
+
33
+
34
+ def get_valid_inds(img, patch_size, patch_filter=None):
35
+ """
36
+ Returns all indices of an image that
37
+ - can be used as center points for sampling patches of a given patch_size, and
38
+ - are part of the boolean mask given by the function patch_filter (if provided)
39
+
40
+ img: np.ndarray
41
+ patch_size: tuple of ints
42
+ the width of patches per img dimension,
43
+ patch_filter: None or callable
44
+ a function with signature patch_filter(img, patch_size) returning a boolean mask
45
+ """
46
+
47
+ len(patch_size)==img.ndim or _raise(ValueError())
48
+
49
+ if not all(( 0 < s <= d for s,d in zip(patch_size,img.shape))):
50
+ raise ValueError("patch_size %s negative or larger than image shape %s along some dimensions" % (str(patch_size), str(img.shape)))
51
+
52
+ if patch_filter is None:
53
+ # only cut border indices (which is faster)
54
+ patch_mask = np.ones(img.shape,dtype=bool)
55
+ valid_inds = tuple(np.arange(p // 2, s - p + p // 2 + 1).astype(np.uint32) for p, s in zip(patch_size, img.shape))
56
+ valid_inds = tuple(s.ravel() for s in np.meshgrid(*valid_inds, indexing='ij'))
57
+ else:
58
+ patch_mask = patch_filter(img, patch_size)
59
+
60
+ # get the valid indices
61
+ border_slices = tuple([slice(p // 2, s - p + p // 2 + 1) for p, s in zip(patch_size, img.shape)])
62
+ valid_inds = np.where(patch_mask[border_slices])
63
+ valid_inds = tuple((v + s.start).astype(np.uint32) for s, v in zip(border_slices, valid_inds))
64
+
65
+ return valid_inds
stardist_pkg/utils.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import print_function, unicode_literals, absolute_import, division
2
+
3
+ import numpy as np
4
+ import warnings
5
+ import os
6
+ import datetime
7
+ from tqdm import tqdm
8
+ from collections import defaultdict
9
+ from zipfile import ZipFile, ZIP_DEFLATED
10
+ from scipy.ndimage.morphology import distance_transform_edt, binary_fill_holes
11
+ from scipy.ndimage.measurements import find_objects
12
+ from scipy.optimize import minimize_scalar
13
+ from skimage.measure import regionprops
14
+ from csbdeep.utils import _raise
15
+ from csbdeep.utils.six import Path
16
+ from collections.abc import Iterable
17
+
18
+ from .matching import matching_dataset, _check_label_array
19
+
20
+
21
+ try:
22
+ from edt import edt
23
+ _edt_available = True
24
+ try: _edt_parallel_max = len(os.sched_getaffinity(0))
25
+ except: _edt_parallel_max = 128
26
+ _edt_parallel_default = 4
27
+ _edt_parallel = os.environ.get('STARDIST_EDT_NUM_THREADS', _edt_parallel_default)
28
+ try:
29
+ _edt_parallel = min(_edt_parallel_max, int(_edt_parallel))
30
+ except ValueError as e:
31
+ warnings.warn(f"Invalid value ({_edt_parallel}) for STARDIST_EDT_NUM_THREADS. Using default value ({_edt_parallel_default}) instead.")
32
+ _edt_parallel = _edt_parallel_default
33
+ del _edt_parallel_default, _edt_parallel_max
34
+ except ImportError:
35
+ _edt_available = False
36
+ # warnings.warn("Could not find package edt... \nConsider installing it with \n pip install edt\nto improve training data generation performance.")
37
+ pass
38
+
39
+
40
+ def gputools_available():
41
+ try:
42
+ import gputools
43
+ except:
44
+ return False
45
+ return True
46
+
47
+
48
+ def path_absolute(path_relative):
49
+ """ Get absolute path to resource"""
50
+ base_path = os.path.abspath(os.path.dirname(__file__))
51
+ return os.path.join(base_path, path_relative)
52
+
53
+
54
+ def _is_power_of_2(i):
55
+ assert i > 0
56
+ e = np.log2(i)
57
+ return e == int(e)
58
+
59
+
60
+ def _normalize_grid(grid,n):
61
+ try:
62
+ grid = tuple(grid)
63
+ (len(grid) == n and
64
+ all(map(np.isscalar,grid)) and
65
+ all(map(_is_power_of_2,grid))) or _raise(TypeError())
66
+ return tuple(int(g) for g in grid)
67
+ except (TypeError, AssertionError):
68
+ raise ValueError("grid = {grid} must be a list/tuple of length {n} with values that are power of 2".format(grid=grid, n=n))
69
+
70
+
71
+ def edt_prob(lbl_img, anisotropy=None):
72
+ if _edt_available:
73
+ return _edt_prob_edt(lbl_img, anisotropy=anisotropy)
74
+ else:
75
+ # warnings.warn("Could not find package edt... \nConsider installing it with \n pip install edt\nto improve training data generation performance.")
76
+ return _edt_prob_scipy(lbl_img, anisotropy=anisotropy)
77
+
78
+ def _edt_prob_edt(lbl_img, anisotropy=None):
79
+ """Perform EDT on each labeled object and normalize.
80
+ Internally uses https://github.com/seung-lab/euclidean-distance-transform-3d
81
+ that can handle multiple labels at once
82
+ """
83
+ lbl_img = np.ascontiguousarray(lbl_img)
84
+ constant_img = lbl_img.min() == lbl_img.max() and lbl_img.flat[0] > 0
85
+ if constant_img:
86
+ warnings.warn("EDT of constant label image is ill-defined. (Assuming background around it.)")
87
+ # we just need to compute the edt once but then normalize it for each object
88
+ prob = edt(lbl_img, anisotropy=anisotropy, black_border=constant_img, parallel=_edt_parallel)
89
+ objects = find_objects(lbl_img)
90
+ for i,sl in enumerate(objects,1):
91
+ # i: object label id, sl: slices of object in lbl_img
92
+ if sl is None: continue
93
+ _mask = lbl_img[sl]==i
94
+ # normalize it
95
+ prob[sl][_mask] /= np.max(prob[sl][_mask]+1e-10)
96
+ return prob
97
+
98
+ def _edt_prob_scipy(lbl_img, anisotropy=None):
99
+ """Perform EDT on each labeled object and normalize."""
100
+ def grow(sl,interior):
101
+ return tuple(slice(s.start-int(w[0]),s.stop+int(w[1])) for s,w in zip(sl,interior))
102
+ def shrink(interior):
103
+ return tuple(slice(int(w[0]),(-1 if w[1] else None)) for w in interior)
104
+ constant_img = lbl_img.min() == lbl_img.max() and lbl_img.flat[0] > 0
105
+ if constant_img:
106
+ lbl_img = np.pad(lbl_img, ((1,1),)*lbl_img.ndim, mode='constant')
107
+ warnings.warn("EDT of constant label image is ill-defined. (Assuming background around it.)")
108
+ objects = find_objects(lbl_img)
109
+ prob = np.zeros(lbl_img.shape,np.float32)
110
+ for i,sl in enumerate(objects,1):
111
+ # i: object label id, sl: slices of object in lbl_img
112
+ if sl is None: continue
113
+ interior = [(s.start>0,s.stop<sz) for s,sz in zip(sl,lbl_img.shape)]
114
+ # 1. grow object slice by 1 for all interior object bounding boxes
115
+ # 2. perform (correct) EDT for object with label id i
116
+ # 3. extract EDT for object of original slice and normalize
117
+ # 4. store edt for object only for pixels of given label id i
118
+ shrink_slice = shrink(interior)
119
+ grown_mask = lbl_img[grow(sl,interior)]==i
120
+ mask = grown_mask[shrink_slice]
121
+ edt = distance_transform_edt(grown_mask, sampling=anisotropy)[shrink_slice][mask]
122
+ prob[sl][mask] = edt/(np.max(edt)+1e-10)
123
+ if constant_img:
124
+ prob = prob[(slice(1,-1),)*lbl_img.ndim].copy()
125
+ return prob
126
+
127
+
128
+ def _fill_label_holes(lbl_img, **kwargs):
129
+ lbl_img_filled = np.zeros_like(lbl_img)
130
+ for l in (set(np.unique(lbl_img)) - set([0])):
131
+ mask = lbl_img==l
132
+ mask_filled = binary_fill_holes(mask,**kwargs)
133
+ lbl_img_filled[mask_filled] = l
134
+ return lbl_img_filled
135
+
136
+
137
+ def fill_label_holes(lbl_img, **kwargs):
138
+ """Fill small holes in label image."""
139
+ # TODO: refactor 'fill_label_holes' and 'edt_prob' to share code
140
+ def grow(sl,interior):
141
+ return tuple(slice(s.start-int(w[0]),s.stop+int(w[1])) for s,w in zip(sl,interior))
142
+ def shrink(interior):
143
+ return tuple(slice(int(w[0]),(-1 if w[1] else None)) for w in interior)
144
+ objects = find_objects(lbl_img)
145
+ lbl_img_filled = np.zeros_like(lbl_img)
146
+ for i,sl in enumerate(objects,1):
147
+ if sl is None: continue
148
+ interior = [(s.start>0,s.stop<sz) for s,sz in zip(sl,lbl_img.shape)]
149
+ shrink_slice = shrink(interior)
150
+ grown_mask = lbl_img[grow(sl,interior)]==i
151
+ mask_filled = binary_fill_holes(grown_mask,**kwargs)[shrink_slice]
152
+ lbl_img_filled[sl][mask_filled] = i
153
+ return lbl_img_filled
154
+
155
+
156
+ def sample_points(n_samples, mask, prob=None, b=2):
157
+ """sample points to draw some of the associated polygons"""
158
+ if b is not None and b > 0:
159
+ # ignore image boundary, since predictions may not be reliable
160
+ mask_b = np.zeros_like(mask)
161
+ mask_b[b:-b,b:-b] = True
162
+ else:
163
+ mask_b = True
164
+
165
+ points = np.nonzero(mask & mask_b)
166
+
167
+ if prob is not None:
168
+ # weighted sampling via prob
169
+ w = prob[points[0],points[1]].astype(np.float64)
170
+ w /= np.sum(w)
171
+ ind = np.random.choice(len(points[0]), n_samples, replace=True, p=w)
172
+ else:
173
+ ind = np.random.choice(len(points[0]), n_samples, replace=True)
174
+
175
+ points = points[0][ind], points[1][ind]
176
+ points = np.stack(points,axis=-1)
177
+ return points
178
+
179
+
180
+ def calculate_extents(lbl, func=np.median):
181
+ """ Aggregate bounding box sizes of objects in label images. """
182
+ if (isinstance(lbl,np.ndarray) and lbl.ndim==4) or (not isinstance(lbl,np.ndarray) and isinstance(lbl,Iterable)):
183
+ return func(np.stack([calculate_extents(_lbl,func) for _lbl in lbl], axis=0), axis=0)
184
+
185
+ n = lbl.ndim
186
+ n in (2,3) or _raise(ValueError("label image should be 2- or 3-dimensional (or pass a list of these)"))
187
+
188
+ regs = regionprops(lbl)
189
+ if len(regs) == 0:
190
+ return np.zeros(n)
191
+ else:
192
+ extents = np.array([np.array(r.bbox[n:])-np.array(r.bbox[:n]) for r in regs])
193
+ return func(extents, axis=0)
194
+
195
+
196
+ def polyroi_bytearray(x,y,pos=None,subpixel=True):
197
+ """ Byte array of polygon roi with provided x and y coordinates
198
+ See https://github.com/imagej/imagej1/blob/master/ij/io/RoiDecoder.java
199
+ """
200
+ import struct
201
+ def _int16(x):
202
+ return int(x).to_bytes(2, byteorder='big', signed=True)
203
+ def _uint16(x):
204
+ return int(x).to_bytes(2, byteorder='big', signed=False)
205
+ def _int32(x):
206
+ return int(x).to_bytes(4, byteorder='big', signed=True)
207
+ def _float(x):
208
+ return struct.pack(">f", x)
209
+
210
+ subpixel = bool(subpixel)
211
+ # add offset since pixel center is at (0.5,0.5) in ImageJ
212
+ x_raw = np.asarray(x).ravel() + 0.5
213
+ y_raw = np.asarray(y).ravel() + 0.5
214
+ x = np.round(x_raw)
215
+ y = np.round(y_raw)
216
+ assert len(x) == len(y)
217
+ top, left, bottom, right = y.min(), x.min(), y.max(), x.max() # bbox
218
+
219
+ n_coords = len(x)
220
+ bytes_header = 64
221
+ bytes_total = bytes_header + n_coords*2*2 + subpixel*n_coords*2*4
222
+ B = [0] * bytes_total
223
+ B[ 0: 4] = map(ord,'Iout') # magic start
224
+ B[ 4: 6] = _int16(227) # version
225
+ B[ 6: 8] = _int16(0) # roi type (0 = polygon)
226
+ B[ 8:10] = _int16(top) # bbox top
227
+ B[10:12] = _int16(left) # bbox left
228
+ B[12:14] = _int16(bottom) # bbox bottom
229
+ B[14:16] = _int16(right) # bbox right
230
+ B[16:18] = _uint16(n_coords) # number of coordinates
231
+ if subpixel:
232
+ B[50:52] = _int16(128) # subpixel resolution (option flag)
233
+ if pos is not None:
234
+ B[56:60] = _int32(pos) # position (C, Z, or T)
235
+
236
+ for i,(_x,_y) in enumerate(zip(x,y)):
237
+ xs = bytes_header + 2*i
238
+ ys = xs + 2*n_coords
239
+ B[xs:xs+2] = _int16(_x - left)
240
+ B[ys:ys+2] = _int16(_y - top)
241
+
242
+ if subpixel:
243
+ base1 = bytes_header + n_coords*2*2
244
+ base2 = base1 + n_coords*4
245
+ for i,(_x,_y) in enumerate(zip(x_raw,y_raw)):
246
+ xs = base1 + 4*i
247
+ ys = base2 + 4*i
248
+ B[xs:xs+4] = _float(_x)
249
+ B[ys:ys+4] = _float(_y)
250
+
251
+ return bytearray(B)
252
+
253
+
254
+ def export_imagej_rois(fname, polygons, set_position=True, subpixel=True, compression=ZIP_DEFLATED):
255
+ """ polygons assumed to be a list of arrays with shape (id,2,c) """
256
+
257
+ if isinstance(polygons,np.ndarray):
258
+ polygons = (polygons,)
259
+
260
+ fname = Path(fname)
261
+ if fname.suffix == '.zip':
262
+ fname = fname.with_suffix('')
263
+
264
+ with ZipFile(str(fname)+'.zip', mode='w', compression=compression) as roizip:
265
+ for pos,polygroup in enumerate(polygons,start=1):
266
+ for i,poly in enumerate(polygroup,start=1):
267
+ roi = polyroi_bytearray(poly[1],poly[0], pos=(pos if set_position else None), subpixel=subpixel)
268
+ roizip.writestr('{pos:03d}_{i:03d}.roi'.format(pos=pos,i=i), roi)
269
+
270
+
271
+ def optimize_threshold(Y, Yhat, model, nms_thresh, measure='accuracy', iou_threshs=[0.3,0.5,0.7], bracket=None, tol=1e-2, maxiter=20, verbose=1):
272
+ """ Tune prob_thresh for provided (fixed) nms_thresh to maximize matching score (for given measure and averaged over iou_threshs). """
273
+ np.isscalar(nms_thresh) or _raise(ValueError("nms_thresh must be a scalar"))
274
+ iou_threshs = [iou_threshs] if np.isscalar(iou_threshs) else iou_threshs
275
+ values = dict()
276
+
277
+ if bracket is None:
278
+ max_prob = max([np.max(prob) for prob, dist in Yhat])
279
+ bracket = max_prob/2, max_prob
280
+ # print("bracket =", bracket)
281
+
282
+ with tqdm(total=maxiter, disable=(verbose!=1), desc="NMS threshold = %g" % nms_thresh) as progress:
283
+
284
+ def fn(thr):
285
+ prob_thresh = np.clip(thr, *bracket)
286
+ value = values.get(prob_thresh)
287
+ if value is None:
288
+ Y_instances = [model._instances_from_prediction(y.shape, *prob_dist, prob_thresh=prob_thresh, nms_thresh=nms_thresh)[0] for y,prob_dist in zip(Y,Yhat)]
289
+ stats = matching_dataset(Y, Y_instances, thresh=iou_threshs, show_progress=False, parallel=True)
290
+ values[prob_thresh] = value = np.mean([s._asdict()[measure] for s in stats])
291
+ if verbose > 1:
292
+ print("{now} thresh: {prob_thresh:f} {measure}: {value:f}".format(
293
+ now = datetime.datetime.now().strftime('%H:%M:%S'),
294
+ prob_thresh = prob_thresh,
295
+ measure = measure,
296
+ value = value,
297
+ ), flush=True)
298
+ else:
299
+ progress.update()
300
+ progress.set_postfix_str("{prob_thresh:.3f} -> {value:.3f}".format(prob_thresh=prob_thresh, value=value))
301
+ progress.refresh()
302
+ return -value
303
+
304
+ opt = minimize_scalar(fn, method='golden', bracket=bracket, tol=tol, options={'maxiter': maxiter})
305
+
306
+ verbose > 1 and print('\n',opt, flush=True)
307
+ return opt.x, -opt.fun
308
+
309
+
310
+ def _invert_dict(d):
311
+ """ return v-> [k_1,k_2,k_3....] for k,v in d"""
312
+ res = defaultdict(list)
313
+ for k,v in d.items():
314
+ res[v].append(k)
315
+ return res
316
+
317
+
318
+ def mask_to_categorical(y, n_classes, classes, return_cls_dict=False):
319
+ """generates a multi-channel categorical class map
320
+
321
+ Parameters
322
+ ----------
323
+ y : n-dimensional ndarray
324
+ integer label array
325
+ n_classes : int
326
+ Number of different classes (without background)
327
+ classes: dict, integer, or None
328
+ the label to class assignment
329
+ can be
330
+ - dict {label -> class_id}
331
+ the value of class_id can be
332
+ 0 -> background class
333
+ 1...n_classes -> the respective object class (1 ... n_classes)
334
+ None -> ignore object (prob is set to -1 for the pixels of the object, except for background class)
335
+ - single integer value or None -> broadcast value to all labels
336
+
337
+ Returns
338
+ -------
339
+ probability map of shape y.shape+(n_classes+1,) (first channel is background)
340
+
341
+ """
342
+
343
+ _check_label_array(y, 'y')
344
+ if not (np.issubdtype(type(n_classes), np.integer) and n_classes>=1):
345
+ raise ValueError(f"n_classes is '{n_classes}' but should be a positive integer")
346
+
347
+ y_labels = np.unique(y[y>0]).tolist()
348
+
349
+ # build dict class_id -> labels (inverse of classes)
350
+ if np.issubdtype(type(classes), np.integer) or classes is None:
351
+ classes = dict((k,classes) for k in y_labels)
352
+ elif isinstance(classes, dict):
353
+ pass
354
+ else:
355
+ raise ValueError("classes should be dict, single scalar, or None!")
356
+
357
+ if not set(y_labels).issubset(set(classes.keys())):
358
+ raise ValueError(f"all gt labels should be present in class dict provided \ngt_labels found\n{set(y_labels)}\nclass dict labels provided\n{set(classes.keys())}")
359
+
360
+ cls_dict = _invert_dict(classes)
361
+
362
+ # prob map
363
+ y_mask = np.zeros(y.shape+(n_classes+1,), np.float32)
364
+
365
+ for cls, labels in cls_dict.items():
366
+ if cls is None:
367
+ # prob == -1 will be used in the loss to ignore object
368
+ y_mask[np.isin(y, labels)] = -1
369
+ elif np.issubdtype(type(cls), np.integer) and 0 <= cls <= n_classes:
370
+ y_mask[...,cls] = np.isin(y, labels)
371
+ else:
372
+ raise ValueError(f"Wrong class id '{cls}' (for n_classes={n_classes})")
373
+
374
+ # set 0/1 background prob (unaffected by None values for class ids)
375
+ y_mask[...,0] = (y==0)
376
+
377
+ if return_cls_dict:
378
+ return y_mask, cls_dict
379
+ else:
380
+ return y_mask
381
+
382
+
383
+ def _is_floatarray(x):
384
+ return isinstance(x.dtype.type(0),np.floating)
385
+
386
+
387
+ def abspath(root, relpath):
388
+ from pathlib import Path
389
+ root = Path(root)
390
+ if root.is_dir():
391
+ path = root/relpath
392
+ else:
393
+ path = root.parent/relpath
394
+ return str(path.absolute())
stardist_pkg/version.py ADDED
@@ -0,0 +1 @@
 
 
1
+ __version__ = '0.8.3'
utils_modify.py ADDED
@@ -0,0 +1,743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import warnings
13
+ from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
14
+ import numpy as np
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from stardist_pkg.big import _grid_divisible, BlockND, OBJECT_KEYS#, repaint_labels
18
+ from stardist_pkg.matching import relabel_sequential
19
+ from stardist_pkg import dist_to_coord, non_maximum_suppression, polygons_to_label
20
+ #from stardist_pkg import dist_to_coord, polygons_to_label
21
+ from stardist_pkg import star_dist,edt_prob
22
+ from monai.data.meta_tensor import MetaTensor
23
+ from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
24
+ from monai.transforms import Resize
25
+ from monai.utils import (
26
+ BlendMode,
27
+ PytorchPadMode,
28
+ convert_data_type,
29
+ convert_to_dst_type,
30
+ ensure_tuple,
31
+ fall_back_tuple,
32
+ look_up_option,
33
+ optional_import,
34
+ )
35
+ import cv2
36
+ from scipy import ndimage
37
+ from scipy.ndimage.filters import gaussian_filter
38
+ from scipy.ndimage.interpolation import affine_transform, map_coordinates
39
+ from skimage import morphology as morph
40
+ from scipy.ndimage import filters, measurements
41
+ from scipy.ndimage.morphology import (
42
+ binary_dilation,
43
+ binary_fill_holes,
44
+ distance_transform_cdt,
45
+ distance_transform_edt,
46
+ )
47
+
48
+ from skimage.segmentation import watershed
49
+ tqdm, _ = optional_import("tqdm", name="tqdm")
50
+
51
+ __all__ = ["sliding_window_inference"]
52
+
53
+
54
+ ####
55
+ def normalize(mask, dtype=np.uint8):
56
+ return (255 * mask / np.amax(mask)).astype(dtype)
57
+
58
+ def fix_mirror_padding(ann):
59
+ """Deal with duplicated instances due to mirroring in interpolation
60
+ during shape augmentation (scale, rotation etc.).
61
+
62
+ """
63
+ current_max_id = np.amax(ann)
64
+ inst_list = list(np.unique(ann))
65
+ if 0 in inst_list:
66
+ inst_list.remove(0) # 0 is background
67
+ for inst_id in inst_list:
68
+ inst_map = np.array(ann == inst_id, np.uint8)
69
+ remapped_ids = measurements.label(inst_map)[0]
70
+ remapped_ids[remapped_ids > 1] += current_max_id
71
+ ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
72
+ current_max_id = np.amax(ann)
73
+ return ann
74
+
75
+ ####
76
+ def get_bounding_box(img):
77
+ """Get bounding box coordinate information."""
78
+ rows = np.any(img, axis=1)
79
+ cols = np.any(img, axis=0)
80
+ rmin, rmax = np.where(rows)[0][[0, -1]]
81
+ cmin, cmax = np.where(cols)[0][[0, -1]]
82
+ # due to python indexing, need to add 1 to max
83
+ # else accessing will be 1px in the box, not out
84
+ rmax += 1
85
+ cmax += 1
86
+ return [rmin, rmax, cmin, cmax]
87
+
88
+
89
+ ####
90
+ def cropping_center(x, crop_shape, batch=False):
91
+ """Crop an input image at the centre.
92
+
93
+ Args:
94
+ x: input array
95
+ crop_shape: dimensions of cropped array
96
+
97
+ Returns:
98
+ x: cropped array
99
+
100
+ """
101
+ orig_shape = x.shape
102
+ if not batch:
103
+ h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
104
+ w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
105
+ x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
106
+ else:
107
+ h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
108
+ w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
109
+ x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
110
+ return x
111
+
112
+ def gen_instance_hv_map(ann, crop_shape):
113
+ """Input annotation must be of original shape.
114
+
115
+ The map is calculated only for instances within the crop portion
116
+ but based on the original shape in original image.
117
+
118
+ Perform following operation:
119
+ Obtain the horizontal and vertical distance maps for each
120
+ nuclear instance.
121
+
122
+ """
123
+ orig_ann = ann.copy() # instance ID map
124
+ fixed_ann = fix_mirror_padding(orig_ann)
125
+ # re-cropping with fixed instance id map
126
+ crop_ann = cropping_center(fixed_ann, crop_shape)
127
+ # TODO: deal with 1 label warning
128
+ crop_ann = morph.remove_small_objects(crop_ann, min_size=30)
129
+
130
+ x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
131
+ y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
132
+
133
+ inst_list = list(np.unique(crop_ann))
134
+ if 0 in inst_list:
135
+ inst_list.remove(0) # 0 is background
136
+ for inst_id in inst_list:
137
+ inst_map = np.array(fixed_ann == inst_id, np.uint8)
138
+ inst_box = get_bounding_box(inst_map) # rmin, rmax, cmin, cmax
139
+
140
+ # expand the box by 2px
141
+ # Because we first pad the ann at line 207, the bboxes
142
+ # will remain valid after expansion
143
+ inst_box[0] -= 2
144
+ inst_box[2] -= 2
145
+ inst_box[1] += 2
146
+ inst_box[3] += 2
147
+
148
+ # fix inst_box
149
+ inst_box[0] = max(inst_box[0], 0)
150
+ inst_box[2] = max(inst_box[2], 0)
151
+ # inst_box[1] = min(inst_box[1], fixed_ann.shape[0])
152
+ # inst_box[3] = min(inst_box[3], fixed_ann.shape[1])
153
+
154
+ inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
155
+
156
+ if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
157
+ print(f'inst_map.shape < 2: {inst_map.shape}, {inst_box}, {get_bounding_box(np.array(fixed_ann == inst_id, np.uint8))}')
158
+ continue
159
+
160
+ # instance center of mass, rounded to nearest pixel
161
+ inst_com = list(measurements.center_of_mass(inst_map))
162
+ if np.isnan(measurements.center_of_mass(inst_map)).any():
163
+ print(inst_id, fixed_ann.shape, np.array(fixed_ann == inst_id, np.uint8).shape)
164
+ print(get_bounding_box(np.array(fixed_ann == inst_id, np.uint8)))
165
+ print(inst_map)
166
+ print(inst_list)
167
+ print(inst_box)
168
+ print(np.count_nonzero(np.array(fixed_ann == inst_id, np.uint8)))
169
+
170
+ inst_com[0] = int(inst_com[0] + 0.5)
171
+ inst_com[1] = int(inst_com[1] + 0.5)
172
+
173
+ inst_x_range = np.arange(1, inst_map.shape[1] + 1)
174
+ inst_y_range = np.arange(1, inst_map.shape[0] + 1)
175
+ # shifting center of pixels grid to instance center of mass
176
+ inst_x_range -= inst_com[1]
177
+ inst_y_range -= inst_com[0]
178
+
179
+ inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)
180
+
181
+ # remove coord outside of instance
182
+ inst_x[inst_map == 0] = 0
183
+ inst_y[inst_map == 0] = 0
184
+ inst_x = inst_x.astype("float32")
185
+ inst_y = inst_y.astype("float32")
186
+
187
+ # normalize min into -1 scale
188
+ if np.min(inst_x) < 0:
189
+ inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
190
+ if np.min(inst_y) < 0:
191
+ inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
192
+ # normalize max into +1 scale
193
+ if np.max(inst_x) > 0:
194
+ inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
195
+ if np.max(inst_y) > 0:
196
+ inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])
197
+
198
+ ####
199
+ x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
200
+ x_map_box[inst_map > 0] = inst_x[inst_map > 0]
201
+
202
+ y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
203
+ y_map_box[inst_map > 0] = inst_y[inst_map > 0]
204
+
205
+ hv_map = np.dstack([x_map, y_map])
206
+ return hv_map
207
+
208
+ def remove_small_objects(pred, min_size=64, connectivity=1):
209
+ """Remove connected components smaller than the specified size.
210
+
211
+ This function is taken from skimage.morphology.remove_small_objects, but the warning
212
+ is removed when a single label is provided.
213
+
214
+ Args:
215
+ pred: input labelled array
216
+ min_size: minimum size of instance in output array
217
+ connectivity: The connectivity defining the neighborhood of a pixel.
218
+
219
+ Returns:
220
+ out: output array with instances removed under min_size
221
+
222
+ """
223
+ out = pred
224
+
225
+ if min_size == 0: # shortcut for efficiency
226
+ return out
227
+
228
+ if out.dtype == bool:
229
+ selem = ndimage.generate_binary_structure(pred.ndim, connectivity)
230
+ ccs = np.zeros_like(pred, dtype=np.int32)
231
+ ndimage.label(pred, selem, output=ccs)
232
+ else:
233
+ ccs = out
234
+
235
+ try:
236
+ component_sizes = np.bincount(ccs.ravel())
237
+ except ValueError:
238
+ raise ValueError(
239
+ "Negative value labels are not supported. Try "
240
+ "relabeling the input with `scipy.ndimage.label` or "
241
+ "`skimage.morphology.label`."
242
+ )
243
+
244
+ too_small = component_sizes < min_size
245
+ too_small_mask = too_small[ccs]
246
+ out[too_small_mask] = 0
247
+
248
+ return out
249
+
250
+ ####
251
+ def gen_targets(ann, crop_shape, **kwargs):
252
+ """Generate the targets for the network."""
253
+ hv_map = gen_instance_hv_map(ann, crop_shape)
254
+ np_map = ann.copy()
255
+ np_map[np_map > 0] = 1
256
+
257
+ hv_map = cropping_center(hv_map, crop_shape)
258
+ np_map = cropping_center(np_map, crop_shape)
259
+
260
+ target_dict = {
261
+ "hv_map": hv_map,
262
+ "np_map": np_map,
263
+ }
264
+
265
+ return target_dict
266
+ def __proc_np_hv(pred, np_thres=0.5, ksize=21, overall_thres=0.4, obj_size_thres=10):
267
+ """Process Nuclei Prediction with XY Coordinate Map.
268
+
269
+ Args:
270
+ pred: prediction output, assuming
271
+ channel 0 contain probability map of nuclei
272
+ channel 1 containing the regressed X-map
273
+ channel 2 containing the regressed Y-map
274
+
275
+ """
276
+ pred = np.array(pred, dtype=np.float32)
277
+
278
+ blb_raw = pred[..., 0]
279
+ h_dir_raw = pred[..., 1]
280
+ v_dir_raw = pred[..., 2]
281
+
282
+ # processing
283
+ blb = np.array(blb_raw >= np_thres, dtype=np.int32)
284
+
285
+ blb = measurements.label(blb)[0]
286
+ blb = remove_small_objects(blb, min_size=10)
287
+ blb[blb > 0] = 1 # background is 0 already
288
+
289
+ h_dir = cv2.normalize(
290
+ h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
291
+ )
292
+ v_dir = cv2.normalize(
293
+ v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
294
+ )
295
+
296
+ sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
297
+ sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)
298
+
299
+ sobelh = 1 - (
300
+ cv2.normalize(
301
+ sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
302
+ )
303
+ )
304
+ sobelv = 1 - (
305
+ cv2.normalize(
306
+ sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
307
+ )
308
+ )
309
+
310
+ overall = np.maximum(sobelh, sobelv)
311
+ overall = overall - (1 - blb)
312
+ overall[overall < 0] = 0
313
+
314
+ dist = (1.0 - overall) * blb
315
+ ## nuclei values form mountains so inverse to get basins
316
+ dist = -cv2.GaussianBlur(dist, (3, 3), 0)
317
+
318
+ overall = np.array(overall >= overall_thres, dtype=np.int32)
319
+
320
+ marker = blb - overall
321
+ marker[marker < 0] = 0
322
+ marker = binary_fill_holes(marker).astype("uint8")
323
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
324
+ marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
325
+ marker = measurements.label(marker)[0]
326
+ marker = remove_small_objects(marker, min_size=obj_size_thres)
327
+
328
+ proced_pred = watershed(dist, markers=marker, mask=blb)
329
+
330
+ return proced_pred
331
+
332
+ ####
333
+ def colorize(ch, vmin, vmax):
334
+ """Will clamp value value outside the provided range to vmax and vmin."""
335
+ cmap = plt.get_cmap("jet")
336
+ ch = np.squeeze(ch.astype("float32"))
337
+ vmin = vmin if vmin is not None else ch.min()
338
+ vmax = vmax if vmax is not None else ch.max()
339
+ ch[ch > vmax] = vmax # clamp value
340
+ ch[ch < vmin] = vmin
341
+ ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
342
+ # take RGB from RGBA heat map
343
+ ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
344
+ return ch_cmap
345
+
346
+
347
+ ####
348
+ def random_colors(N, bright=True):
349
+ """Generate random colors.
350
+
351
+ To get visually distinct colors, generate them in HSV space then
352
+ convert to RGB.
353
+ """
354
+ brightness = 1.0 if bright else 0.7
355
+ hsv = [(i / N, 1, brightness) for i in range(N)]
356
+ colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
357
+ random.shuffle(colors)
358
+ return colors
359
+
360
+
361
+ ####
362
+ def visualize_instances_map(
363
+ input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
364
+ ):
365
+ """Overlays segmentation results on image as contours.
366
+
367
+ Args:
368
+ input_image: input image
369
+ inst_map: instance mask with unique value for every object
370
+ type_map: type mask with unique value for every class
371
+ type_colour: a dict of {type : colour} , `type` is from 0-N
372
+ and `colour` is a tuple of (R, G, B)
373
+ line_thickness: line thickness of contours
374
+
375
+ Returns:
376
+ overlay: output image with segmentation overlay as contours
377
+ """
378
+ overlay = np.copy((input_image).astype(np.uint8))
379
+
380
+ inst_list = list(np.unique(inst_map)) # get list of instances
381
+ inst_list.remove(0) # remove background
382
+
383
+ inst_rng_colors = random_colors(len(inst_list))
384
+ inst_rng_colors = np.array(inst_rng_colors) * 255
385
+ inst_rng_colors = inst_rng_colors.astype(np.uint8)
386
+
387
+ for inst_idx, inst_id in enumerate(inst_list):
388
+ inst_map_mask = np.array(inst_map == inst_id, np.uint8) # get single object
389
+ y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
390
+ y1 = y1 - 2 if y1 - 2 >= 0 else y1
391
+ x1 = x1 - 2 if x1 - 2 >= 0 else x1
392
+ x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
393
+ y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
394
+ inst_map_crop = inst_map_mask[y1:y2, x1:x2]
395
+ contours_crop = cv2.findContours(
396
+ inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
397
+ )
398
+ # only has 1 instance per map, no need to check #contour detected by opencv
399
+ contours_crop = np.squeeze(
400
+ contours_crop[0][0].astype("int32")
401
+ ) # * opencv protocol format may break
402
+ contours_crop += np.asarray([[x1, y1]]) # index correction
403
+ if type_map is not None:
404
+ type_map_crop = type_map[y1:y2, x1:x2]
405
+ type_id = np.unique(type_map_crop).max() # non-zero
406
+ inst_colour = type_colour[type_id]
407
+ else:
408
+ inst_colour = (inst_rng_colors[inst_idx]).tolist()
409
+ cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
410
+ return overlay
411
+
412
+
413
+ def sliding_window_inference_large(inputs,block_size,min_overlap,context,roi_size,sw_batch_size,predictor,device):
414
+
415
+ h,w = inputs.shape[0],inputs.shape[1]
416
+ if h < 5000 or w < 5000:
417
+ test_tensor = torch.from_numpy(np.expand_dims(inputs, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device)
418
+ output_dist,output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
419
+ prob = output_prob[0][0].cpu().numpy()
420
+ dist = output_dist[0].cpu().numpy()
421
+ dist = np.transpose(dist,(1,2,0))
422
+ dist = np.maximum(1e-3, dist)
423
+ if h*w < 1500*1500:
424
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.55, nms_thresh=0.4,cut=True)
425
+ else:
426
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4)
427
+
428
+
429
+ labels_out = polygons_to_label(disti, points, prob=probi,shape=prob.shape)
430
+ else:
431
+ n = inputs.ndim
432
+ axes = 'YXC'
433
+ grid = (1,1,1)
434
+ if np.isscalar(block_size): block_size = n*[block_size]
435
+ if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
436
+ if np.isscalar(context): context = n*[context]
437
+ shape_out = (inputs.shape[0],inputs.shape[1])
438
+ labels_out = np.zeros(shape_out, dtype=np.uint64)
439
+ #print(inputs.dtype)
440
+ block_size[2] = inputs.shape[2]
441
+ min_overlap[2] = context[2] = 0
442
+ block_size = tuple(_grid_divisible(g, v, name='block_size', verbose=False) for v,g,a in zip(block_size, grid,axes))
443
+ min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v,g,a in zip(min_overlap,grid,axes))
444
+ context = tuple(_grid_divisible(g, v, name='context', verbose=False) for v,g,a in zip(context, grid,axes))
445
+ print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)
446
+ blocks = BlockND.cover(inputs.shape, axes, block_size, min_overlap, context)
447
+ label_offset = 1
448
+ blocks = tqdm(blocks)
449
+ for block in blocks:
450
+ image = block.read(inputs, axes=axes)
451
+ test_tensor = torch.from_numpy(np.expand_dims(image, 0)).permute(0,3,1,2).type(torch.FloatTensor).to(device)
452
+ output_dist,output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
453
+ prob = output_prob[0][0].cpu().numpy()
454
+ dist = output_dist[0].cpu().numpy()
455
+ dist = np.transpose(dist,(1,2,0))
456
+ dist = np.maximum(1e-3, dist)
457
+ points, probi, disti = non_maximum_suppression(dist,prob,prob_thresh=0.5, nms_thresh=0.4)
458
+
459
+ coord = dist_to_coord(disti,points)
460
+ polys = dict(coord=coord, points=points, prob=probi)
461
+ labels = polygons_to_label(disti, points, prob=probi,shape=prob.shape)
462
+ labels = block.crop_context(labels, axes='YX')
463
+ labels, polys = block.filter_objects(labels, polys, axes='YX')
464
+ labels = relabel_sequential(labels, label_offset)[0]
465
+ if labels_out is not None:
466
+ block.write(labels_out, labels, axes='YX')
467
+ #for k,v in polys.items():
468
+ #polys_all.setdefault(k,[]).append(v)
469
+ label_offset += len(polys['prob'])
470
+ del labels
471
+ #polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k,v in polys_all.items()}
472
+ return labels_out
473
+ def sliding_window_inference(
474
+ inputs: torch.Tensor,
475
+ roi_size: Union[Sequence[int], int],
476
+ sw_batch_size: int,
477
+ predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
478
+ overlap: float = 0.25,
479
+ mode: Union[BlendMode, str] = BlendMode.CONSTANT,
480
+ sigma_scale: Union[Sequence[float], float] = 0.125,
481
+ padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
482
+ cval: float = 0.0,
483
+ sw_device: Union[torch.device, str, None] = None,
484
+ device: Union[torch.device, str, None] = None,
485
+ progress: bool = False,
486
+ roi_weight_map: Union[torch.Tensor, None] = None,
487
+ *args: Any,
488
+ **kwargs: Any,
489
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
490
+ """
491
+ Sliding window inference on `inputs` with `predictor`.
492
+
493
+ The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
494
+ Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
495
+ e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
496
+ could be ([128,64,256], [64,32,128]).
497
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
498
+ an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
499
+ so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
500
+
501
+ When roi_size is larger than the inputs' spatial size, the input image are padded during inference.
502
+ To maintain the same spatial sizes, the output image will be cropped to the original input size.
503
+
504
+ Args:
505
+ inputs: input image to be processed (assuming NCHW[D])
506
+ roi_size: the spatial window size for inferences.
507
+ When its components have None or non-positives, the corresponding inputs dimension will be used.
508
+ if the components of the `roi_size` are non-positive values, the transform will use the
509
+ corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
510
+ to `(32, 64)` if the second spatial dimension size of img is `64`.
511
+ sw_batch_size: the batch size to run window slices.
512
+ predictor: given input tensor ``patch_data`` in shape NCHW[D],
513
+ The outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
514
+ with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
515
+ where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
516
+ N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
517
+ the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
518
+ In this case, the parameter `overlap` and `roi_size` need to be carefully chosen
519
+ to ensure the scaled output ROI sizes are still integers.
520
+ If the `predictor`'s input and output spatial sizes are different,
521
+ we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
522
+ overlap: Amount of overlap between scans.
523
+ mode: {``"constant"``, ``"gaussian"``}
524
+ How to blend output of overlapping windows. Defaults to ``"constant"``.
525
+
526
+ - ``"constant``": gives equal weight to all predictions.
527
+ - ``"gaussian``": gives less weight to predictions on edges of windows.
528
+
529
+ sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
530
+ Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
531
+ When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
532
+ spatial dimensions.
533
+ padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
534
+ Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
535
+ See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
536
+ cval: fill value for 'constant' padding mode. Default: 0
537
+ sw_device: device for the window data.
538
+ By default the device (and accordingly the memory) of the `inputs` is used.
539
+ Normally `sw_device` should be consistent with the device where `predictor` is defined.
540
+ device: device for the stitched output prediction.
541
+ By default the device (and accordingly the memory) of the `inputs` is used. If for example
542
+ set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
543
+ `inputs` and `roi_size`. Output is on the `device`.
544
+ progress: whether to print a `tqdm` progress bar.
545
+ roi_weight_map: pre-computed (non-negative) weight map for each ROI.
546
+ If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
547
+ args: optional args to be passed to ``predictor``.
548
+ kwargs: optional keyword args to be passed to ``predictor``.
549
+
550
+ Note:
551
+ - input must be channel-first and have a batch dim, supports N-D sliding window.
552
+
553
+ """
554
+ compute_dtype = inputs.dtype
555
+ num_spatial_dims = len(inputs.shape) - 2
556
+ if overlap < 0 or overlap >= 1:
557
+ raise ValueError("overlap must be >= 0 and < 1.")
558
+
559
+ # determine image spatial size and batch size
560
+ # Note: all input images must have the same image size and batch size
561
+ batch_size, _, *image_size_ = inputs.shape
562
+
563
+ if device is None:
564
+ device = inputs.device
565
+ if sw_device is None:
566
+ sw_device = inputs.device
567
+
568
+ roi_size = fall_back_tuple(roi_size, image_size_)
569
+ # in case that image size is smaller than roi size
570
+ image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
571
+ pad_size = []
572
+ for k in range(len(inputs.shape) - 1, 1, -1):
573
+ diff = max(roi_size[k - 2] - inputs.shape[k], 0)
574
+ half = diff // 2
575
+ pad_size.extend([half, diff - half])
576
+ inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
577
+ #print('inputs',inputs.shape)
578
+ scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
579
+
580
+ # Store all slices in list
581
+ slices = dense_patch_slices(image_size, roi_size, scan_interval)
582
+ num_win = len(slices) # number of windows per image
583
+ total_slices = num_win * batch_size # total number of windows
584
+
585
+ # Create window-level importance map
586
+ valid_patch_size = get_valid_patch_size(image_size, roi_size)
587
+ if valid_patch_size == roi_size and (roi_weight_map is not None):
588
+ importance_map = roi_weight_map
589
+ else:
590
+ try:
591
+ importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
592
+ except BaseException as e:
593
+ raise RuntimeError(
594
+ "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
595
+ ) from e
596
+ importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
597
+ # handle non-positive weights
598
+ min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
599
+ importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
600
+
601
+ # Perform predictions
602
+ dict_key, output_image_list, count_map_list = None, [], []
603
+ _initialized_ss = -1
604
+ is_tensor_output = True # whether the predictor's output is a tensor (instead of dict/tuple)
605
+
606
+ # for each patch
607
+ for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
608
+ slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
609
+ unravel_slice = [
610
+ [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
611
+ for idx in slice_range
612
+ ]
613
+ window_data = torch.cat(
614
+ [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
615
+ ).to(sw_device)
616
+ seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation
617
+ #print('seg_prob_out',seg_prob_out[0].shape)
618
+ # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
619
+ seg_prob_tuple: Tuple[torch.Tensor, ...]
620
+ if isinstance(seg_prob_out, torch.Tensor):
621
+ seg_prob_tuple = (seg_prob_out,)
622
+ elif isinstance(seg_prob_out, Mapping):
623
+ if dict_key is None:
624
+ dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys
625
+ seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
626
+ is_tensor_output = False
627
+ else:
628
+ seg_prob_tuple = ensure_tuple(seg_prob_out)
629
+ is_tensor_output = False
630
+
631
+ # for each output in multi-output list
632
+ for ss, seg_prob in enumerate(seg_prob_tuple):
633
+ seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN
634
+
635
+ # compute zoom scale: out_roi_size/in_roi_size
636
+ zoom_scale = []
637
+ for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
638
+ zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
639
+ ):
640
+ _scale = out_w_i / float(in_w_i)
641
+ if not (img_s_i * _scale).is_integer():
642
+ warnings.warn(
643
+ f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
644
+ f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
645
+ )
646
+ zoom_scale.append(_scale)
647
+
648
+ if _initialized_ss < ss: # init. the ss-th buffer at the first iteration
649
+ # construct multi-resolution outputs
650
+ output_classes = seg_prob.shape[1]
651
+ output_shape = [batch_size, output_classes] + [
652
+ int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
653
+ ]
654
+ # allocate memory to store the full output and the count for overlapping parts
655
+ output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
656
+ count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
657
+ _initialized_ss += 1
658
+
659
+ # resizing the importance_map
660
+ resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)
661
+
662
+ # store the result in the proper location of the full output. Apply weights from importance map.
663
+ for idx, original_idx in zip(slice_range, unravel_slice):
664
+ # zoom roi
665
+ original_idx_zoom = list(original_idx) # 4D for 2D image, 5D for 3D image
666
+ for axis in range(2, len(original_idx_zoom)):
667
+ zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
668
+ zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
669
+ if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
670
+ warnings.warn(
671
+ f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
672
+ f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
673
+ f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
674
+ f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
675
+ f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
676
+ "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
677
+ )
678
+ original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
679
+ importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
680
+ # store results and weights
681
+ output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
682
+ count_map_list[ss][original_idx_zoom] += (
683
+ importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
684
+ )
685
+
686
+ # account for any overlapping sections
687
+ for ss in range(len(output_image_list)):
688
+ output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)
689
+
690
+ # remove padding if image_size smaller than roi_size
691
+ for ss, output_i in enumerate(output_image_list):
692
+ if torch.isnan(output_i).any() or torch.isinf(output_i).any():
693
+ warnings.warn("Sliding window inference results contain NaN or Inf.")
694
+
695
+ zoom_scale = [
696
+ seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
697
+ ]
698
+
699
+ final_slicing: List[slice] = []
700
+ for sp in range(num_spatial_dims):
701
+ slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
702
+ slice_dim = slice(
703
+ int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
704
+ int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
705
+ )
706
+ final_slicing.insert(0, slice_dim)
707
+ while len(final_slicing) < len(output_i.shape):
708
+ final_slicing.insert(0, slice(None))
709
+ output_image_list[ss] = output_i[final_slicing]
710
+
711
+ if dict_key is not None: # if output of predictor is a dict
712
+ final_output = dict(zip(dict_key, output_image_list))
713
+ else:
714
+ final_output = tuple(output_image_list) # type: ignore
715
+ final_output = final_output[0] if is_tensor_output else final_output
716
+
717
+ if isinstance(inputs, MetaTensor):
718
+ final_output = convert_to_dst_type(final_output, inputs, device=device)[0] # type: ignore
719
+ return final_output
720
+
721
+
722
+ def _get_scan_interval(
723
+ image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
724
+ ) -> Tuple[int, ...]:
725
+ """
726
+ Compute scan interval according to the image size, roi size and overlap.
727
+ Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
728
+ use 1 instead to make sure sliding window works.
729
+
730
+ """
731
+ if len(image_size) != num_spatial_dims:
732
+ raise ValueError("image coord different from spatial dims.")
733
+ if len(roi_size) != num_spatial_dims:
734
+ raise ValueError("roi coord different from spatial dims.")
735
+
736
+ scan_interval = []
737
+ for i in range(num_spatial_dims):
738
+ if roi_size[i] == image_size[i]:
739
+ scan_interval.append(int(roi_size[i]))
740
+ else:
741
+ interval = int(roi_size[i] * (1 - overlap))
742
+ scan_interval.append(interval if interval > 0 else 1)
743
+ return tuple(scan_interval)