Spaces: Runtime error

Upload 40 files

- README.md +105 -9
- classifiers.py +261 -0
- config.json +25 -0
- model.pt +3 -0
- models/__init__.py +10 -0
- models/convnext.py +220 -0
- models/flexible_unet.py +312 -0
- models/flexible_unet_convnext.py +447 -0
- overlay.py +116 -0
- pytorch_model.bin +3 -0
- requirements.txt +37 -0
- sribd_cellseg_models.py +100 -0
- stardist_pkg/__init__.py +26 -0
- stardist_pkg/__pycache__/__init__.cpython-37.pyc +0 -0
- stardist_pkg/__pycache__/big.cpython-37.pyc +0 -0
- stardist_pkg/__pycache__/bioimageio_utils.cpython-37.pyc +0 -0
- stardist_pkg/__pycache__/matching.cpython-37.pyc +0 -0
- stardist_pkg/__pycache__/nms.cpython-37.pyc +0 -0
- stardist_pkg/__pycache__/sample_patches.cpython-37.pyc +0 -0
- stardist_pkg/__pycache__/utils.cpython-37.pyc +0 -0
- stardist_pkg/__pycache__/version.cpython-37.pyc +0 -0
- stardist_pkg/big.py +601 -0
- stardist_pkg/bioimageio_utils.py +472 -0
- stardist_pkg/geometry/__init__.py +9 -0
- stardist_pkg/geometry/__pycache__/__init__.cpython-37.pyc +0 -0
- stardist_pkg/geometry/__pycache__/geom2d.cpython-37.pyc +0 -0
- stardist_pkg/geometry/__pycache__/geom3d.cpython-37.pyc +0 -0
- stardist_pkg/geometry/geom2d.py +212 -0
- stardist_pkg/kernels/stardist2d.cl +51 -0
- stardist_pkg/kernels/stardist3d.cl +63 -0
- stardist_pkg/matching.py +483 -0
- stardist_pkg/models/__init__.py +27 -0
- stardist_pkg/models/base.py +1196 -0
- stardist_pkg/models/model2d.py +570 -0
- stardist_pkg/nms.py +387 -0
- stardist_pkg/rays3d.py +373 -0
- stardist_pkg/sample_patches.py +65 -0
- stardist_pkg/utils.py +394 -0
- stardist_pkg/version.py +1 -0
- utils_modify.py +743 -0
README.md
CHANGED
@@ -1,13 +1,109 @@

The Space front matter is replaced with model-card metadata:

```diff
 ---
-title: Lewislou Cell Seg Sribd
-emoji: ⚡
-colorFrom: blue
-colorTo: blue
-sdk: gradio
-sdk_version: 3.38.0
-app_file: app.py
-pinned: false
 license: apache-2.0
+language:
+- en
+metrics:
+- f1
+tags:
+- cell segmentation
+- stardist
+- hover-net
+library_name: transformers
+pipeline_tag: image-segmentation
+datasets:
+- Lewislou/cell_samples
 ---
```

The remainder of the README is new:

# Model Card for cell-seg-sribd

<!-- Provide a quick summary of what the model is/does. -->

This repository provides the solution of team Sribd-med for the NeurIPS-CellSeg Challenge. The details of our method are described in our paper [Multi-stream Cell Segmentation with Low-level Cues for Multi-modality Images]. Parts of the code are adapted from the baseline code of the NeurIPS-CellSeg-Baseline repository.

You can reproduce our method step by step as follows:

### How to Get Started with the Model

Install the requirements with `python -m pip install -r requirements.txt`.

## Training Details

### Training Data

The competition training and tuning data can be downloaded from https://neurips22-cellseg.grand-challenge.org/dataset/. In addition, three public datasets are available from the following links:

- Cellpose: https://www.cellpose.org/dataset
- Omnipose: http://www.cellpose.org/dataset_omnipose
- Sartorius: https://www.kaggle.com/competitions/sartorius-cell-instance-segmentation/overview

## Environments and Requirements

Install requirements by

```shell
python -m pip install -r requirements.txt
```

### How to use

Here is how to use this model:

```python
from skimage import io, segmentation, morphology, measure, exposure
from sribd_cellseg_models import MultiStreamCellSegModel, ModelConfig
import numpy as np
import tifffile as tif
import requests
import torch
from PIL import Image
from overlay import visualize_instances_map
import cv2

img_name = 'test_images/cell_00551.tiff'

def normalize_channel(img, lower=1, upper=99):
    non_zero_vals = img[np.nonzero(img)]
    percentiles = np.percentile(non_zero_vals, [lower, upper])
    if percentiles[1] - percentiles[0] > 0.001:
        img_norm = exposure.rescale_intensity(img, in_range=(percentiles[0], percentiles[1]), out_range='uint8')
    else:
        img_norm = img
    return img_norm.astype(np.uint8)

# read the image
if img_name.endswith('.tif') or img_name.endswith('.tiff'):
    img_data = tif.imread(img_name)
else:
    img_data = io.imread(img_name)

# normalize image data to three uint8 channels
if len(img_data.shape) == 2:
    img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
elif len(img_data.shape) == 3 and img_data.shape[-1] > 3:
    img_data = img_data[:, :, :3]
pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
for i in range(3):
    img_channel_i = img_data[:, :, i]
    if len(img_channel_i[np.nonzero(img_channel_i)]) > 0:
        pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=1, upper=99)

my_model = MultiStreamCellSegModel.from_pretrained("Lewislou/cellseg_sribd")
checkpoints = torch.load('model.pt')
my_model.__init__(ModelConfig())
my_model.load_checkpoints(checkpoints)
with torch.no_grad():
    output = my_model(pre_img_data)  # predicted instance label map
overlay = visualize_instances_map(pre_img_data, output)
cv2.imwrite('prediction.png', cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
```

## Citation

If any part of this code is used, please acknowledge it appropriately and cite the paper:

```bibtex
@misc{lou2022multistream,
  title={Multi-stream Cell Segmentation with Low-level Cues for Multi-modality Images},
  author={Wei Lou and Xinyi Yu and Chenyu Liu and Xiang Wan and Guanbin Li and Siqi Liu and Haofeng Li},
  year={2022},
  url={https://openreview.net/forum?id=G24BybwKe9}
}
```
classifiers.py
ADDED
@@ -0,0 +1,261 @@

```python
from functools import partial
from typing import Any, Callable, List, Optional, Type, Union

import torch
import torch.nn as nn
from torch import Tensor


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        # _log_api_usage_once(self)
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def resnet18(weights=None):
    # weights: path to a state-dict checkpoint
    model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=4)
    if weights is not None:
        model.load_state_dict(torch.load(weights))
    return model


def resnet10():
    return ResNet(BasicBlock, [1, 1, 1, 1], num_classes=4)
```
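`classifiers.py` is a trimmed torchvision-style ResNet with a fixed 4-class head (`num_classes=4`, matching `num_classes` in config.json). A minimal smoke-test sketch; reading the four classes as the modality routes of the multi-stream model is an assumption, since the routing logic lives elsewhere:

```python
import torch
from classifiers import resnet18

clf = resnet18()  # ResNet-18 backbone with a 4-way classification head
clf.eval()
x = torch.randn(1, 3, 224, 224)  # one normalized RGB image
with torch.no_grad():
    logits = clf(x)  # shape (1, 4)
predicted_class = int(logits.argmax(dim=1))  # assumed: modality/stream index
```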
config.json
ADDED
@@ -0,0 +1,25 @@

```json
{
  "architectures": [
    "MultiStreamCellSegModel"
  ],
  "block_size": 2048,
  "context": 128,
  "device": "cpu",
  "input_channels": 3,
  "ksize": 15,
  "min_overlap": 128,
  "model_type": "cell_sribd",
  "n_rays": 32,
  "np_thres": 0.6,
  "num_classes": 4,
  "obj_size_thres": 100,
  "overall_thres": 0.4,
  "overlap": 0.5,
  "roi_size": [
    512,
    512
  ],
  "sw_batch_size": 4,
  "torch_dtype": "float32",
  "transformers_version": "4.27.1"
}
```
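`roi_size`, `sw_batch_size`, and `overlap` read like sliding-window inference settings, and `np_thres`/`overall_thres`/`obj_size_thres` like post-processing thresholds. A hedged sketch of how the window keys could be consumed via MONAI's `sliding_window_inference`; the `predictor` below is a placeholder, not the repository's actual wiring:

```python
import json
import torch
from monai.inferers import sliding_window_inference

with open("config.json") as f:
    cfg = json.load(f)

def predictor(patch: torch.Tensor) -> torch.Tensor:
    # placeholder; the real pipeline runs MultiStreamCellSegModel on each patch
    return patch

image = torch.randn(1, cfg["input_channels"], 1024, 1024)
out = sliding_window_inference(
    inputs=image,
    roi_size=cfg["roi_size"],            # [512, 512] patches
    sw_batch_size=cfg["sw_batch_size"],  # 4 patches per forward pass
    predictor=predictor,
    overlap=cfg["overlap"],              # 50% patch overlap
)
```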
model.pt
ADDED
@@ -0,0 +1,3 @@

```
version https://git-lfs.github.com/spec/v1
oid sha256:460f2c3a9168220ef03404df983b923c7f59d6db873536cd44d9e4b7e4354f6c
size 135
```
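The three lines above are a Git LFS pointer (version/oid/size), not the checkpoint itself; a clone without `git lfs pull` leaves exactly this stub on disk. One way to fetch the real file, using the repo id from the README's `from_pretrained` call:

```python
from huggingface_hub import hf_hub_download

# downloads the LFS-backed checkpoint and returns its local cache path
ckpt_path = hf_hub_download(repo_id="Lewislou/cellseg_sribd", filename="model.pt")
```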
models/__init__.py
ADDED
@@ -0,0 +1,10 @@

```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Mar 20 14:23:55 2022

@author: jma
"""

#from .unetr2d import UNETR2D
#from .swin_unetr import SwinUNETR
```
models/convnext.py
ADDED
@@ -0,0 +1,220 @@

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model
from monai.networks.layers.factories import Act, Conv, Pad, Pool
from monai.networks.layers.utils import get_norm_layer
from monai.utils.module import look_up_option
from typing import List, NamedTuple, Optional, Tuple, Type, Union


class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x


class ConvNeXt(nn.Module):
    r""" ConvNeXt
    A PyTorch impl of : `A ConvNet for the 2020s` -
    https://arxiv.org/pdf/2201.03545.pdf

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(self, in_chans=3, num_classes=21841,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1., out_indices=[0, 1, 2, 3],
                 ):
        super().__init__()
        # conv_type: Type[Union[nn.Conv1d, nn.Conv2d, nn.Conv3d]] = Conv["conv", 2]
        # self._conv_stem = conv_type(self.in_channels, self.in_channels, kernel_size=3, stride=stride, bias=False)
        # self._conv_stem_padding = _make_same_padder(self._conv_stem, current_image_size)

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.out_indices = out_indices

        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(4):
            layer = norm_layer(dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        outs = []

        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i in self.out_indices:
                norm_layer = getattr(self, f'norm{i}')
                x_out = norm_layer(x)
                outs.append(x_out)

        return tuple(outs)

    def forward(self, x):
        x = self.forward_features(x)
        return x


class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


model_urls = {
    "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
    "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
    "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
    "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
    "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
    "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
    "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
    "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
    "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
}


@register_model
def convnext_tiny(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_small(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
    if pretrained:
        url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"], strict=False)
    return model


@register_model
def convnext_base(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
    if pretrained:
        url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"], strict=False)
    return model


@register_model
def convnext_large(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
    if pretrained:
        url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
    model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
    if pretrained:
        assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
        url = model_urls['convnext_xlarge_22k']
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
    return model
```
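This ConvNeXt variant is a feature-pyramid encoder: `forward` returns the four stage outputs after per-stage LayerNorm instead of classification logits. A quick smoke-test sketch (timm and monai must be importable, since the module imports them at the top):

```python
import torch
from models.convnext import convnext_small

encoder = convnext_small(pretrained=False)  # stage dims (96, 192, 384, 768)
encoder.eval()
with torch.no_grad():
    feats = encoder(torch.randn(1, 3, 256, 256))
for f in feats:
    print(tuple(f.shape))
# expected strides 4/8/16/32:
# (1, 96, 64, 64), (1, 192, 32, 32), (1, 384, 16, 16), (1, 768, 8, 8)
```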
models/flexible_unet.py
ADDED
@@ -0,0 +1,312 @@

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Sequence, Tuple, Union

import torch
from torch import nn

from monai.networks.blocks import UpSample
from monai.networks.layers.factories import Conv
from monai.networks.layers.utils import get_act_layer
from monai.networks.nets import EfficientNetBNFeatures
from monai.networks.nets.basic_unet import UpCat
from monai.utils import InterpolateMode

__all__ = ["FlexibleUNet"]

encoder_feature_channel = {
    "efficientnet-b0": (16, 24, 40, 112, 320),
    "efficientnet-b1": (16, 24, 40, 112, 320),
    "efficientnet-b2": (16, 24, 48, 120, 352),
    "efficientnet-b3": (24, 32, 48, 136, 384),
    "efficientnet-b4": (24, 32, 56, 160, 448),
    "efficientnet-b5": (24, 40, 64, 176, 512),
    "efficientnet-b6": (32, 40, 72, 200, 576),
    "efficientnet-b7": (32, 48, 80, 224, 640),
    "efficientnet-b8": (32, 56, 88, 248, 704),
    "efficientnet-l2": (72, 104, 176, 480, 1376),
}


def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple:
    """
    Get the encoder output channels for a given backbone name.

    Args:
        backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b7].
        in_channels: channel of input tensor, default to 3.

    Returns:
        A tuple of output feature map channels.
    """
    encoder_channel_tuple = encoder_feature_channel[backbone]
    encoder_channel_list = [in_channels] + list(encoder_channel_tuple)
    encoder_channel = tuple(encoder_channel_list)
    return encoder_channel


class UNetDecoder(nn.Module):
    """
    UNet Decoder.
    This class refers to `segmentation_models.pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        encoder_channels: number of output channels for all feature maps in encoder.
            `len(encoder_channels)` should be no less than 2.
        decoder_channels: number of output channels for all feature maps in decoder.
            `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.
        act: activation type and arguments.
        norm: feature normalization type and arguments.
        dropout: dropout ratio.
        bias: whether to have a bias term in convolution blocks in this decoder.
        upsample: upsampling mode, available options are
            ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
        pre_conv: a conv block applied before upsampling.
            Only used in the "nontrainable" or "pixelshuffle" mode.
        interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
            Only used in the "nontrainable" mode.
        align_corners: set the align_corners parameter for upsample. Defaults to True.
            Only used in the "nontrainable" mode.
        is_pad: whether to pad upsampling features to fit the encoder spatial dims.

    """

    def __init__(
        self,
        spatial_dims: int,
        encoder_channels: Sequence[int],
        decoder_channels: Sequence[int],
        act: Union[str, tuple],
        norm: Union[str, tuple],
        dropout: Union[float, tuple],
        bias: bool,
        upsample: str,
        pre_conv: Optional[str],
        interp_mode: str,
        align_corners: Optional[bool],
        is_pad: bool,
    ):
        super().__init__()
        if len(encoder_channels) < 2:
            raise ValueError("the length of `encoder_channels` should be no less than 2.")
        if len(decoder_channels) != len(encoder_channels) - 1:
            raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.")

        in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
        halves = [True] * (len(skip_channels) - 1)
        halves.append(False)
        blocks = []
        for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
            blocks.append(
                UpCat(
                    spatial_dims=spatial_dims,
                    in_chns=in_chn,
                    cat_chns=skip_chn,
                    out_chns=out_chn,
                    act=act,
                    norm=norm,
                    dropout=dropout,
                    bias=bias,
                    upsample=upsample,
                    pre_conv=pre_conv,
                    interp_mode=interp_mode,
                    align_corners=align_corners,
                    halves=halve,
                    is_pad=is_pad,
                )
            )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, features: List[torch.Tensor], skip_connect: int = 4):
        skips = features[:-1][::-1]
        features = features[1:][::-1]

        x = features[0]
        for i, block in enumerate(self.blocks):
            if i < skip_connect:
                skip = skips[i]
            else:
                skip = None
            x = block(x, skip)

        return x


class SegmentationHead(nn.Sequential):
    """
    Segmentation head.
    This class refers to `segmentation_models.pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels for the block.
        out_channels: number of output channels for the block.
        kernel_size: kernel size for the conv layer.
        act: activation type and arguments.
        scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        act: Optional[Union[Tuple, str]] = None,
        scale_factor: float = 1.0,
    ):
        conv_layer = Conv[Conv.CONV, spatial_dims](
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        up_layer: nn.Module = nn.Identity()
        if scale_factor > 1.0:
            up_layer = UpSample(
                spatial_dims=spatial_dims,
                scale_factor=scale_factor,
                mode="nontrainable",
                pre_conv=None,
                interp_mode=InterpolateMode.LINEAR,
            )
        if act is not None:
            act_layer = get_act_layer(act)
        else:
            act_layer = nn.Identity()
        super().__init__(conv_layer, up_layer, act_layer)


class FlexibleUNet(nn.Module):
    """
    A flexible implementation of UNet-like encoder-decoder architecture.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        backbone: str,
        pretrained: bool = False,
        decoder_channels: Tuple = (256, 128, 64, 32, 16),
        spatial_dims: int = 2,
        norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
        act: Union[str, tuple] = ("relu", {"inplace": True}),
        dropout: Union[float, tuple] = 0.0,
        decoder_bias: bool = False,
        upsample: str = "nontrainable",
        interp_mode: str = "nearest",
        is_pad: bool = True,
    ) -> None:
        """
        A flexible implementation of UNet, in which the backbone/encoder can be replaced
        with any efficient network. Currently the input must have 2 or 3 spatial dimensions,
        and the spatial size of each dimension must be a multiple of 32 if the `is_pad`
        parameter is False.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            backbone: name of backbones to initialize, only support efficientnet right now,
                can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2].
            pretrained: whether to initialize pretrained ImageNet weights, only available
                for spatial_dims=2 and batch norm is used, default to False.
            decoder_channels: number of output channels for all feature maps in decoder.
                `len(decoder_channels)` should equal to `len(encoder_channels) - 1`, default
                to (256, 128, 64, 32, 16).
            spatial_dims: number of spatial dimensions, default to 2.
            norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
                "momentum": 0.1}).
            act: activation type and arguments, default to ("relu", {"inplace": True}).
            dropout: dropout ratio, default to 0.0.
            decoder_bias: whether to have a bias term in decoder's convolution blocks.
            upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``,
                ``"nontrainable"``.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
                If this parameter is set to True, the spatial dims of the network input can be of
                arbitrary size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
        """
        super().__init__()

        if backbone not in encoder_feature_channel:
            raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")

        if spatial_dims not in (2, 3):
            raise ValueError("spatial_dims can only be 2 or 3.")

        adv_prop = "ap" in backbone

        self.backbone = backbone
        self.spatial_dims = spatial_dims
        model_name = backbone
        encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
        self.encoder = EfficientNetBNFeatures(
            model_name=model_name,
            pretrained=pretrained,
            in_channels=in_channels,
            spatial_dims=spatial_dims,
            norm=norm,
            adv_prop=adv_prop,
        )
        self.decoder = UNetDecoder(
            spatial_dims=spatial_dims,
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            act=act,
            norm=norm,
            dropout=dropout,
            bias=decoder_bias,
            upsample=upsample,
            interp_mode=interp_mode,
            pre_conv=None,
            align_corners=None,
            is_pad=is_pad,
        )
        self.dist_head = SegmentationHead(
            spatial_dims=spatial_dims,
            in_channels=decoder_channels[-1],
            out_channels=32,
            kernel_size=1,
            act='relu',
        )
        self.prob_head = SegmentationHead(
            spatial_dims=spatial_dims,
            in_channels=decoder_channels[-1],
            out_channels=1,
            kernel_size=1,
            act='sigmoid',
        )

    def forward(self, inputs: torch.Tensor):
        """
        Do a typical encoder-decoder-header inference.

        Args:
            inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
                N is defined by `dimensions`.

        Returns:
            A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.

        """
        x = inputs
        enc_out = self.encoder(x)
        decoder_out = self.decoder(enc_out)
        dist = self.dist_head(decoder_out)
        prob = self.prob_head(decoder_out)
        return dist, prob
```
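This file replaces MONAI's single segmentation head with two fixed heads: a 32-channel ReLU "dist" head and a 1-channel sigmoid "prob" head (the `out_channels` argument is accepted but not used by either head). A minimal inference sketch, assuming MONAI is installed:

```python
import torch
from models.flexible_unet import FlexibleUNet

net = FlexibleUNet(in_channels=3, out_channels=2, backbone="efficientnet-b0")
net.eval()
with torch.no_grad():
    dist, prob = net(torch.randn(1, 3, 256, 256))
print(dist.shape)  # 32-channel distance-style map
print(prob.shape)  # 1-channel foreground probability map
```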
models/flexible_unet_convnext.py
ADDED
@@ -0,0 +1,447 @@

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Sequence, Tuple, Union

import torch
from torch import nn
from . import convnext
from monai.networks.blocks import UpSample
from monai.networks.layers.factories import Conv
from monai.networks.layers.utils import get_act_layer
from monai.networks.nets import EfficientNetBNFeatures
from monai.networks.nets.basic_unet import UpCat
from monai.utils import InterpolateMode

__all__ = ["FlexibleUNet"]

encoder_feature_channel = {
    "efficientnet-b0": (16, 24, 40, 112, 320),
    "efficientnet-b1": (16, 24, 40, 112, 320),
    "efficientnet-b2": (16, 24, 48, 120, 352),
    "efficientnet-b3": (24, 32, 48, 136, 384),
    "efficientnet-b4": (24, 32, 56, 160, 448),
    "efficientnet-b5": (24, 40, 64, 176, 512),
    "efficientnet-b6": (32, 40, 72, 200, 576),
    "efficientnet-b7": (32, 48, 80, 224, 640),
    "efficientnet-b8": (32, 56, 88, 248, 704),
    "efficientnet-l2": (72, 104, 176, 480, 1376),
    "convnext_small": (96, 192, 384, 768),
    "convnext_base": (128, 256, 512, 1024),
    "van_b2": (64, 128, 320, 512),
    "van_b1": (64, 128, 320, 512),
}


def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple:
    """
    Get the encoder output channels for a given backbone name.

    Args:
        backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b7].
        in_channels: channel of input tensor, default to 3.

    Returns:
        A tuple of output feature map channels.
    """
    encoder_channel_tuple = encoder_feature_channel[backbone]
    encoder_channel_list = [in_channels] + list(encoder_channel_tuple)
    encoder_channel = tuple(encoder_channel_list)
    return encoder_channel


class UNetDecoder(nn.Module):
    """
    UNet Decoder.
    This class refers to `segmentation_models.pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        encoder_channels: number of output channels for all feature maps in encoder.
            `len(encoder_channels)` should be no less than 2.
        decoder_channels: number of output channels for all feature maps in decoder.
            `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.
        act: activation type and arguments.
        norm: feature normalization type and arguments.
        dropout: dropout ratio.
        bias: whether to have a bias term in convolution blocks in this decoder.
        upsample: upsampling mode, available options are
            ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
        pre_conv: a conv block applied before upsampling.
            Only used in the "nontrainable" or "pixelshuffle" mode.
        interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
            Only used in the "nontrainable" mode.
        align_corners: set the align_corners parameter for upsample. Defaults to True.
            Only used in the "nontrainable" mode.
        is_pad: whether to pad upsampling features to fit the encoder spatial dims.

    """

    def __init__(
        self,
        spatial_dims: int,
        encoder_channels: Sequence[int],
        decoder_channels: Sequence[int],
        act: Union[str, tuple],
        norm: Union[str, tuple],
        dropout: Union[float, tuple],
        bias: bool,
        upsample: str,
        pre_conv: Optional[str],
        interp_mode: str,
        align_corners: Optional[bool],
        is_pad: bool,
    ):
        super().__init__()
        if len(encoder_channels) < 2:
            raise ValueError("the length of `encoder_channels` should be no less than 2.")
        if len(decoder_channels) != len(encoder_channels) - 1:
            raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.")

        in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
        skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
        halves = [True] * (len(skip_channels) - 1)
        halves.append(False)
        blocks = []
        for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves):
            blocks.append(
                UpCat(
                    spatial_dims=spatial_dims,
                    in_chns=in_chn,
                    cat_chns=skip_chn,
                    out_chns=out_chn,
                    act=act,
                    norm=norm,
                    dropout=dropout,
                    bias=bias,
                    upsample=upsample,
                    pre_conv=pre_conv,
                    interp_mode=interp_mode,
                    align_corners=align_corners,
                    halves=halve,
                    is_pad=is_pad,
                )
            )
        self.blocks = nn.ModuleList(blocks)

    def forward(self, features: List[torch.Tensor], skip_connect: int = 3):
        skips = features[:-1][::-1]
        features = features[1:][::-1]

        x = features[0]
        for i, block in enumerate(self.blocks):
            if i < skip_connect:
                skip = skips[i]
            else:
                skip = None
            x = block(x, skip)

        return x


class SegmentationHead(nn.Sequential):
    """
    Segmentation head.
    This class refers to `segmentation_models.pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels for the block.
        out_channels: number of output channels for the block.
        kernel_size: kernel size for the conv layer.
        act: activation type and arguments.
        scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.

    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        act: Optional[Union[Tuple, str]] = None,
        scale_factor: float = 1.0,
    ):
        conv_layer = Conv[Conv.CONV, spatial_dims](
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2
        )
        up_layer: nn.Module = nn.Identity()
        # if scale_factor > 1.0:
        #     up_layer = UpSample(
        #         in_channels=out_channels,
        #         spatial_dims=spatial_dims,
        #         scale_factor=scale_factor,
        #         mode="deconv",
        #         pre_conv=None,
        #         interp_mode=InterpolateMode.LINEAR,
        #     )
        if scale_factor > 1.0:
            up_layer = UpSample(
                spatial_dims=spatial_dims,
                scale_factor=scale_factor,
                mode="nontrainable",
                pre_conv=None,
                interp_mode=InterpolateMode.LINEAR,
            )
        if act is not None:
            act_layer = get_act_layer(act)
        else:
            act_layer = nn.Identity()
        super().__init__(conv_layer, up_layer, act_layer)


class FlexibleUNet_star(nn.Module):
    """
    A flexible implementation of UNet-like encoder-decoder architecture.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        backbone: str,
        pretrained: bool = False,
        decoder_channels: Tuple = (256, 128, 64, 32),
        #decoder_channels: Tuple = (1024, 512, 256, 128),
        spatial_dims: int = 2,
        norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
        act: Union[str, tuple] = ("relu", {"inplace": True}),
        dropout: Union[float, tuple] = 0.0,
        decoder_bias: bool = False,
        upsample: str = "nontrainable",
        interp_mode: str = "nearest",
        is_pad: bool = True,
        n_rays: int = 32,
        prob_out_channels: int = 1,
    ) -> None:
        """
        A flexible implementation of UNet, in which the backbone/encoder can be replaced
        with any efficient network. Currently the input must have 2 or 3 spatial dimensions,
        and the spatial size of each dimension must be a multiple of 32 if the `is_pad`
        parameter is False.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            backbone: name of backbone to initialize, can be from [efficientnet-b0, ...,
                efficientnet-l2, convnext_small, convnext_base, van_b1, van_b2].
            pretrained: whether to initialize pretrained ImageNet weights, only available
                for spatial_dims=2 and batch norm is used, default to False.
            decoder_channels: number of output channels for all feature maps in decoder.
                `len(decoder_channels)` should equal to `len(encoder_channels) - 1`, default
                to (256, 128, 64, 32).
            spatial_dims: number of spatial dimensions, default to 2.
            norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
                "momentum": 0.1}).
            act: activation type and arguments, default to ("relu", {"inplace": True}).
            dropout: dropout ratio, default to 0.0.
            decoder_bias: whether to have a bias term in decoder's convolution blocks.
            upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``,
                ``"nontrainable"``.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
                If this parameter is set to True, the spatial dims of the network input can be of
                arbitrary size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
        """
        super().__init__()

        if backbone not in encoder_feature_channel:
            raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")

        if spatial_dims not in (2, 3):
            raise ValueError("spatial_dims can only be 2 or 3.")

        adv_prop = "ap" in backbone

        self.backbone = backbone
        self.spatial_dims = spatial_dims
        model_name = backbone
        encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)

        # NOTE: the encoder is hard-wired to ConvNeXt-S here, independent of `backbone`
        self.encoder = convnext.convnext_small(pretrained=False, in_22k=True)

        self.decoder = UNetDecoder(
            spatial_dims=spatial_dims,
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            act=act,
            norm=norm,
            dropout=dropout,
            bias=decoder_bias,
            upsample=upsample,
            interp_mode=interp_mode,
            pre_conv=None,
            align_corners=None,
            is_pad=is_pad,
        )
        self.dist_head = SegmentationHead(
            spatial_dims=spatial_dims,
            in_channels=decoder_channels[-1],
            out_channels=n_rays,
            kernel_size=1,
            act='relu',
            scale_factor=2,
        )
        self.prob_head = SegmentationHead(
            spatial_dims=spatial_dims,
            in_channels=decoder_channels[-1],
            out_channels=prob_out_channels,
            kernel_size=1,
            act='sigmoid',
            scale_factor=2,
        )

    def forward(self, inputs: torch.Tensor):
        """
        Do a typical encoder-decoder-header inference.

        Args:
            inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
                N is defined by `dimensions`.

        Returns:
            A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.

        """
        x = inputs
        enc_out = self.encoder(x)
        decoder_out = self.decoder(enc_out)

        dist = self.dist_head(decoder_out)
        prob = self.prob_head(decoder_out)

        return dist, prob


class FlexibleUNet_hv(nn.Module):
    """
    A flexible implementation of UNet-like encoder-decoder architecture.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        backbone: str,
        pretrained: bool = False,
        decoder_channels: Tuple = (1024, 512, 256, 128),
        spatial_dims: int = 2,
        norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}),
        act: Union[str, tuple] = ("relu", {"inplace": True}),
        dropout: Union[float, tuple] = 0.0,
        decoder_bias: bool = False,
        upsample: str = "nontrainable",
        interp_mode: str = "nearest",
        is_pad: bool = True,
```
is_pad: bool = True,
|
351 |
+
n_rays: int = 32,
|
352 |
+
prob_out_channels: int = 1,
|
353 |
+
) -> None:
|
354 |
+
"""
|
355 |
+
A flexible implement of UNet, in which the backbone/encoder can be replaced with
|
356 |
+
any efficient network. Currently the input must have a 2 or 3 spatial dimension
|
357 |
+
and the spatial size of each dimension must be a multiple of 32 if is pad parameter
|
358 |
+
is False
|
359 |
+
|
360 |
+
Args:
|
361 |
+
in_channels: number of input channels.
|
362 |
+
out_channels: number of output channels.
|
363 |
+
backbone: name of backbones to initialize, only support efficientnet right now,
|
364 |
+
can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
|
365 |
+
pretrained: whether to initialize pretrained ImageNet weights, only available
|
366 |
+
for spatial_dims=2 and batch norm is used, default to False.
|
367 |
+
decoder_channels: number of output channels for all feature maps in decoder.
|
368 |
+
`len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
|
369 |
+
to (256, 128, 64, 32, 16).
|
370 |
+
spatial_dims: number of spatial dimensions, default to 2.
|
371 |
+
norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
|
372 |
+
"momentum": 0.1}).
|
373 |
+
act: activation type and arguments, default to ("relu", {"inplace": True}).
|
374 |
+
dropout: dropout ratio, default to 0.0.
|
375 |
+
decoder_bias: whether to have a bias term in decoder's convolution blocks.
|
376 |
+
upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``,
|
377 |
+
``"nontrainable"``.
|
378 |
+
interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
|
379 |
+
Only used in the "nontrainable" mode.
|
380 |
+
is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
|
381 |
+
If this parameter is set to "True", the spatial dim of network input can be arbitary
|
382 |
+
size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
|
383 |
+
"""
|
384 |
+
super().__init__()
|
385 |
+
|
386 |
+
if backbone not in encoder_feature_channel:
|
387 |
+
raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.")
|
388 |
+
|
389 |
+
if spatial_dims not in (2, 3):
|
390 |
+
raise ValueError("spatial_dims can only be 2 or 3.")
|
391 |
+
|
392 |
+
adv_prop = "ap" in backbone
|
393 |
+
|
394 |
+
self.backbone = backbone
|
395 |
+
self.spatial_dims = spatial_dims
|
396 |
+
model_name = backbone
|
397 |
+
encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels)
|
398 |
+
self.encoder = convnext.convnext_small(pretrained=False,in_22k=True)
|
399 |
+
self.decoder = UNetDecoder(
|
400 |
+
spatial_dims=spatial_dims,
|
401 |
+
encoder_channels=encoder_channels,
|
402 |
+
decoder_channels=decoder_channels,
|
403 |
+
act=act,
|
404 |
+
norm=norm,
|
405 |
+
dropout=dropout,
|
406 |
+
bias=decoder_bias,
|
407 |
+
upsample=upsample,
|
408 |
+
interp_mode=interp_mode,
|
409 |
+
pre_conv=None,
|
410 |
+
align_corners=None,
|
411 |
+
is_pad=is_pad,
|
412 |
+
)
|
413 |
+
self.dist_head = SegmentationHead(
|
414 |
+
spatial_dims=spatial_dims,
|
415 |
+
in_channels=decoder_channels[-1],
|
416 |
+
out_channels=n_rays,
|
417 |
+
kernel_size=1,
|
418 |
+
act=None,
|
419 |
+
scale_factor = 2,
|
420 |
+
)
|
421 |
+
self.prob_head = SegmentationHead(
|
422 |
+
spatial_dims=spatial_dims,
|
423 |
+
in_channels=decoder_channels[-1],
|
424 |
+
out_channels=prob_out_channels,
|
425 |
+
kernel_size=1,
|
426 |
+
act='sigmoid',
|
427 |
+
scale_factor = 2,
|
428 |
+
)
|
429 |
+
|
430 |
+
def forward(self, inputs: torch.Tensor):
|
431 |
+
"""
|
432 |
+
Do a typical encoder-decoder-header inference.
|
433 |
+
|
434 |
+
Args:
|
435 |
+
inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
|
436 |
+
N is defined by `dimensions`.
|
437 |
+
|
438 |
+
Returns:
|
439 |
+
A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.
|
440 |
+
|
441 |
+
"""
|
442 |
+
x = inputs
|
443 |
+
enc_out = self.encoder(x)
|
444 |
+
decoder_out = self.decoder(enc_out)
|
445 |
+
dist = self.dist_head(decoder_out)
|
446 |
+
prob = self.prob_head(decoder_out)
|
447 |
+
return dist,prob
|
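Since both heads share one decoder, a quick smoke test makes the output contract visible: `dist` carries one channel per ray and `prob` one foreground-probability channel. This is an illustrative sketch only, not part of the upload; it assumes `models.flexible_unet_convnext` is importable from the repo root and that `'convnext_small'` is a valid key in `encoder_feature_channel`.

```python
# Minimal forward-pass check of the star-convex variant (illustrative sketch).
import torch
from models.flexible_unet_convnext import FlexibleUNet_star

net = FlexibleUNet_star(
    in_channels=3,
    out_channels=33,            # n_rays + 1 prob channel, mirroring the Space's usage
    backbone='convnext_small',  # assumed to be registered in encoder_feature_channel
    n_rays=32,
    prob_out_channels=1,
)
net.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 256, 256)  # spatial size kept a multiple of 32
    dist, prob = net(x)
print(dist.shape, prob.shape)  # 32 radial-distance maps and 1 probability map
```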
overlay.py
ADDED
@@ -0,0 +1,116 @@
#!/usr/bin/env python
# coding: utf-8

### overlay
import cv2
import math
import random
import colorsys
import numpy as np
import itertools
import matplotlib.pyplot as plt
from matplotlib import cm
import os
import scipy.io as io


def get_bounding_box(img):
    """Get bounding box coordinate information."""
    rows = np.any(img, axis=1)
    cols = np.any(img, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # due to python indexing, need to add 1 to max
    # else accessing will be 1px in the box, not out
    rmax += 1
    cmax += 1
    return [rmin, rmax, cmin, cmax]


####
def colorize(ch, vmin, vmax):
    """Clamp values outside the provided range to vmax and vmin."""
    cmap = plt.get_cmap("jet")
    ch = np.squeeze(ch.astype("float32"))
    vmin = vmin if vmin is not None else ch.min()
    vmax = vmax if vmax is not None else ch.max()
    ch[ch > vmax] = vmax  # clamp value
    ch[ch < vmin] = vmin
    ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
    # take RGB from RGBA heat map
    ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
    return ch_cmap


####
def random_colors(N, bright=True):
    """Generate random colors.

    To get visually distinct colors, generate them in HSV space then
    convert to RGB.
    """
    brightness = 1.0 if bright else 0.7
    hsv = [(i / N, 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    random.shuffle(colors)
    return colors


####
def visualize_instances_map(
    input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
):
    """Overlays segmentation results on image as contours.

    Args:
        input_image: input image
        inst_map: instance mask with a unique value for every object
        type_map: type mask with a unique value for every class
        type_colour: a dict of {type: colour}, where `type` is from 0-N
            and `colour` is a tuple of (R, G, B)
        line_thickness: line thickness of contours

    Returns:
        overlay: output image with the segmentation overlay drawn as contours
    """
    overlay = np.copy((input_image).astype(np.uint8))

    inst_list = list(np.unique(inst_map))  # get list of instances
    inst_list.remove(0)  # remove background

    inst_rng_colors = random_colors(len(inst_list))
    inst_rng_colors = np.array(inst_rng_colors) * 255
    inst_rng_colors = inst_rng_colors.astype(np.uint8)

    for inst_idx, inst_id in enumerate(inst_list):
        inst_map_mask = np.array(inst_map == inst_id, np.uint8)  # get single object
        y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
        y1 = y1 - 2 if y1 - 2 >= 0 else y1
        x1 = x1 - 2 if x1 - 2 >= 0 else x1
        x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
        y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
        inst_map_crop = inst_map_mask[y1:y2, x1:x2]
        contours_crop = cv2.findContours(
            inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        # only has 1 instance per map, no need to check #contour detected by opencv
        contours_crop = np.squeeze(
            contours_crop[0][0].astype("int32")
        )  # * opencv protocol format may break

        if len(contours_crop.shape) == 1:
            contours_crop = contours_crop.reshape(1, -1)
        contours_crop += np.asarray([[x1, y1]])  # index correction
        if type_map is not None:
            type_map_crop = type_map[y1:y2, x1:x2]
            type_id = np.unique(type_map_crop).max()  # non-zero
            inst_colour = type_colour[type_id]
        else:
            inst_colour = (inst_rng_colors[inst_idx]).tolist()
        cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
    return overlay
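A self-contained sketch of how this overlay helper is typically called: build a toy label map with two instances, then draw their contours. The arrays below are made up for illustration; in the Space they come from the uploaded image and the predicted instance map.

```python
# Illustrative usage of visualize_instances_map on a synthetic label map.
import numpy as np
from overlay import visualize_instances_map

image = np.zeros((64, 64, 3), dtype=np.uint8)   # stand-in RGB image
inst_map = np.zeros((64, 64), dtype=np.int32)
inst_map[10:30, 10:30] = 1                      # one square "cell"
inst_map[35:55, 35:55] = 2                      # another

result = visualize_instances_map(image, inst_map, line_thickness=2)
print(result.shape, result.dtype)               # (64, 64, 3) uint8 with colored contours
```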
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d978b42e9e63e949f0dcd3685be14146b6c5b5bfb48f703bfbc308b4ac190b64
size 135
requirements.txt
ADDED
@@ -0,0 +1,37 @@
gputools==0.2.13
h5py==3.7.0
huggingface-hub==0.10.1
imagecodecs
imageio==2.22.2
importlib-metadata==5.0.0
kiwisolver==1.4.4
llvmlite==0.39.1
Mako==1.2.3
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.6.1
mkl-fft==1.3.1
mkl-service==2.4.0
monai==1.0.0
networkx==2.8.7
numba==0.56.3
numexpr
numpy
oauthlib==3.2.2
opencv-python==4.6.0.66
packaging
pandas==1.4.4
Pillow==9.2.0
scikit-image==0.19.3
scipy==1.9.2
stardist==0.8.3
tensorboard==2.10.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tifffile==2022.10.10
timm==0.6.11
torch==1.12.1
torchaudio==0.12.1
torchvision==0.13.1
tqdm==4.64.1
sribd_cellseg_models.py
ADDED
@@ -0,0 +1,100 @@
import os
join = os.path.join
import argparse
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
from torchvision import datasets, models, transforms
from classifiers import resnet10, resnet18

# __proc_np_hv is imported under an alias: inside a class body, Python name
# mangling would rewrite a bare __proc_np_hv reference to
# _MultiStreamCellSegModel__proc_np_hv and raise a NameError at runtime.
from utils_modify import sliding_window_inference, sliding_window_inference_large, __proc_np_hv as proc_np_hv
from PIL import Image
import torch.nn.functional as F
from skimage import io, segmentation, morphology, measure, exposure
import tifffile as tif
from models.flexible_unet_convnext import FlexibleUNet_star, FlexibleUNet_hv
from transformers import PretrainedConfig
from typing import List
from transformers import PreTrainedModel
from huggingface_hub import PyTorchModelHubMixin
from torch import nn


class ModelConfig(PretrainedConfig):
    model_type = "cell_sribd"

    def __init__(
        self,
        version=1,
        input_channels: int = 3,
        roi_size: int = 512,
        overlap: float = 0.5,
        device: str = 'cpu',
        **kwargs,
    ):
        self.device = device
        self.roi_size = (roi_size, roi_size)
        self.input_channels = input_channels
        self.overlap = overlap
        self.np_thres, self.ksize, self.overall_thres, self.obj_size_thres = 0.6, 15, 0.4, 100
        self.n_rays = 32
        self.sw_batch_size = 4
        self.num_classes = 4
        self.block_size = 2048
        self.min_overlap = 128
        self.context = 128
        super().__init__(**kwargs)


class MultiStreamCellSegModel(PreTrainedModel):
    config_class = ModelConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # modality classifier that routes each image to one of the four streams
        self.cls_model = resnet18()
        # three StarDist-style streams (classes 0-2) and one HoVer-Net-style stream (class 3)
        self.model0 = FlexibleUNet_star(in_channels=config.input_channels, out_channels=config.n_rays + 1,
                                        backbone='convnext_small', pretrained=False,
                                        n_rays=config.n_rays, prob_out_channels=1)
        self.model1 = FlexibleUNet_star(in_channels=config.input_channels, out_channels=config.n_rays + 1,
                                        backbone='convnext_small', pretrained=False,
                                        n_rays=config.n_rays, prob_out_channels=1)
        self.model2 = FlexibleUNet_star(in_channels=config.input_channels, out_channels=config.n_rays + 1,
                                        backbone='convnext_small', pretrained=False,
                                        n_rays=config.n_rays, prob_out_channels=1)
        self.model3 = FlexibleUNet_hv(in_channels=config.input_channels, out_channels=2 + 2,
                                      backbone='convnext_small', pretrained=False,
                                      n_rays=2, prob_out_channels=2)
        self.preprocess = transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    def load_checkpoints(self, checkpoints):
        self.cls_model.load_state_dict(checkpoints['cls_model'])
        self.model0.load_state_dict(checkpoints['class1_model']['model_state_dict'])
        self.model1.load_state_dict(checkpoints['class2_model']['model_state_dict'])
        self.model2.load_state_dict(checkpoints['class3_model']['model_state_dict'])
        self.model3.load_state_dict(checkpoints['class4_model'])

    def forward(self, pre_img_data):
        inputs = self.preprocess(Image.fromarray(pre_img_data)).unsqueeze(0)
        outputs = self.cls_model(inputs)
        _, preds = torch.max(outputs, 1)
        label = preds[0].cpu().numpy()
        test_npy01 = pre_img_data
        if label in [0, 1, 2]:
            if label == 0:
                output_label = sliding_window_inference_large(test_npy01, self.config.block_size, self.config.min_overlap,
                                                              self.config.context, self.config.roi_size, self.config.sw_batch_size,
                                                              predictor=self.model0, device=self.config.device)
            elif label == 1:
                output_label = sliding_window_inference_large(test_npy01, self.config.block_size, self.config.min_overlap,
                                                              self.config.context, self.config.roi_size, self.config.sw_batch_size,
                                                              predictor=self.model1, device=self.config.device)
            elif label == 2:
                output_label = sliding_window_inference_large(test_npy01, self.config.block_size, self.config.min_overlap,
                                                              self.config.context, self.config.roi_size, self.config.sw_batch_size,
                                                              predictor=self.model2, device=self.config.device)
        else:
            test_tensor = torch.from_numpy(np.expand_dims(test_npy01, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor)

            # fixed from self.config.roi, which is not defined on ModelConfig
            output_hv, output_np = sliding_window_inference(test_tensor, self.config.roi_size, self.config.sw_batch_size,
                                                            self.model3, overlap=self.config.overlap, device=self.config.device)
            pred_dict = {'np': output_np, 'hv': output_hv}
            pred_dict = OrderedDict(
                [[k, v.permute(0, 2, 3, 1).contiguous()] for k, v in pred_dict.items()]  # NHWC
            )
            pred_dict["np"] = F.softmax(pred_dict["np"], dim=-1)[..., 1:]
            pred_output = torch.cat(list(pred_dict.values()), -1).cpu().numpy()  # NHW3
            pred_map = np.squeeze(pred_output)  # HW3
            pred_inst = proc_np_hv(pred_map, self.config.np_thres, self.config.ksize, self.config.overall_thres, self.config.obj_size_thres)
            raw_pred_shape = pred_inst.shape[:2]
            output_label = pred_inst
        return output_label
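Putting the pieces together, inference is: classify the modality, route to the matching stream, and get back an integer instance map. The sketch below is illustrative only; it assumes the repo's `model.pt` (fetched via Git LFS) holds the dict of sub-model state dicts that `load_checkpoints()` expects, and `example_cell_image.png` is a hypothetical RGB input.

```python
# End-to-end inference sketch for MultiStreamCellSegModel (assumptions noted above).
import torch
from skimage import io
from sribd_cellseg_models import ModelConfig, MultiStreamCellSegModel

config = ModelConfig()
model = MultiStreamCellSegModel(config)
checkpoints = torch.load('model.pt', map_location='cpu')  # assumed checkpoint layout
model.load_checkpoints(checkpoints)
model.eval()

img = io.imread('example_cell_image.png')  # hypothetical HxWx3 uint8 image
with torch.no_grad():
    inst_map = model(img)                  # label map: 0 = background, 1..N = cells
print(inst_map.shape, int(inst_map.max()), 'cells')
```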
stardist_pkg/__init__.py
ADDED
@@ -0,0 +1,26 @@
from __future__ import absolute_import, print_function

import warnings
def format_warning(message, category, filename, lineno, line=''):
    import pathlib
    return f"{pathlib.Path(filename).name} ({lineno}): {message}\n"
warnings.formatwarning = format_warning
del warnings

from .version import __version__

# TODO: which functions to expose here? all?
from .nms import non_maximum_suppression
from .utils import edt_prob, fill_label_holes, sample_points, calculate_extents, export_imagej_rois, gputools_available
from .geometry import star_dist, polygons_to_label, relabel_image_stardist, ray_angles, dist_to_coord
from .sample_patches import sample_patches
from .bioimageio_utils import export_bioimageio, import_bioimageio

def _py_deprecation(ver_python=(3,6), ver_stardist='0.9.0'):
    import sys
    from distutils.version import LooseVersion
    if sys.version_info[:2] == ver_python and LooseVersion(__version__) < LooseVersion(ver_stardist):
        print(f"You are using Python {ver_python[0]}.{ver_python[1]}, which will no longer be supported in StarDist {ver_stardist}.\n"
              f"→ Please upgrade to Python {ver_python[0]}.{ver_python[1]+1} or later.", file=sys.stderr, flush=True)
_py_deprecation()
del _py_deprecation
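The names re-exported above are the StarDist ground-truth helpers used during training: `edt_prob` builds the per-pixel object probability target and `star_dist` the radial-distance target. A small sketch under stated assumptions (the pinned `stardist==0.8.3` wheel provides the compiled extensions that the vendored `stardist_pkg` relies on):

```python
# Toy ground-truth computation with the re-exported StarDist helpers.
import numpy as np
from stardist_pkg import edt_prob, star_dist, fill_label_holes

lbl = np.zeros((64, 64), dtype=np.uint16)
lbl[20:40, 20:40] = 1              # one square object
lbl = fill_label_holes(lbl)

prob = edt_prob(lbl)               # (64, 64) in [0, 1], peaking at object centers
dist = star_dist(lbl, n_rays=32)   # (64, 64, 32) radial distances per pixel
print(prob.shape, dist.shape)
```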
stardist_pkg/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (1.62 kB).
stardist_pkg/__pycache__/big.cpython-37.pyc
ADDED
Binary file (20.7 kB).
stardist_pkg/__pycache__/bioimageio_utils.cpython-37.pyc
ADDED
Binary file (15.2 kB).
stardist_pkg/__pycache__/matching.cpython-37.pyc
ADDED
Binary file (16.9 kB).
stardist_pkg/__pycache__/nms.cpython-37.pyc
ADDED
Binary file (9.59 kB).
stardist_pkg/__pycache__/sample_patches.cpython-37.pyc
ADDED
Binary file (4.22 kB).
stardist_pkg/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (15.4 kB).
stardist_pkg/__pycache__/version.cpython-37.pyc
ADDED
Binary file (199 Bytes).
stardist_pkg/big.py
ADDED
@@ -0,0 +1,601 @@
import numpy as np
import warnings
import math
from tqdm import tqdm
from skimage.measure import regionprops
from skimage.draw import polygon
from csbdeep.utils import _raise, axes_check_and_normalize, axes_dict
from itertools import product


OBJECT_KEYS = set(('prob', 'points', 'coord', 'dist', 'class_prob', 'class_id'))
COORD_KEYS = set(('points', 'coord'))


class Block:
    """One-dimensional block as part of a chain.

    There are no explicit start and end positions. Instead, each block is
    aware of its predecessor and successor and derives such things (recursively)
    based on its neighbors.

    Blocks overlap with one another (at least min_overlap + 2*context) and
    have a read region (the entire block) and a write region (ignoring context).
    Given a query interval, Block.is_responsible will return true for only one
    block of a chain (or raise an exception if the interval is larger than
    min_overlap or even the entire block without context).

    """
    def __init__(self, size, min_overlap, context, pred):
        self.size = int(size)
        self.min_overlap = int(min_overlap)
        self.context = int(context)
        self.pred = pred
        self.succ = None
        assert 0 <= self.min_overlap + 2*self.context < self.size
        self.stride = self.size - (self.min_overlap + 2*self.context)
        self._start = 0
        self._frozen = False

    @property
    def start(self):
        return self._start if (self.frozen or self.at_begin) else self.pred.succ_start

    @property
    def end(self):
        return self.start + self.size

    @property
    def succ_start(self):
        return self.start + self.stride

    def add_succ(self):
        assert self.succ is None and not self.frozen
        self.succ = Block(self.size, self.min_overlap, self.context, self)
        return self.succ

    def decrease_stride(self, amount):
        amount = int(amount)
        assert 0 <= amount < self.stride and not self.frozen
        self.stride -= amount

    def freeze(self):
        """Call on first block to freeze entire chain (after construction is done)"""
        assert not self.frozen and (self.at_begin or self.pred.frozen)
        self._start = self.start
        self._frozen = True
        if not self.at_end:
            self.succ.freeze()

    @property
    def slice_read(self):
        return slice(self.start, self.end)

    @property
    def slice_crop_context(self):
        """Crop context relative to read region"""
        return slice(self.context_start, self.size - self.context_end)

    @property
    def slice_write(self):
        return slice(self.start + self.context_start, self.end - self.context_end)

    def is_responsible(self, bbox):
        """Responsibility for query interval bbox, which is assumed to be smaller than min_overlap.

        If the assumption is met, only one block of a chain will return true.
        If violated, one or more blocks of a chain may raise a NotFullyVisible exception.
        The exception will have an argument that is
        False if bbox is larger than min_overlap, and
        True if bbox is even larger than the entire block without context.

        bbox: (int,int)
            1D bounding box interval with coordinates relative to size without context

        """
        bmin, bmax = bbox

        r_start = 0 if self.at_begin else (self.pred.overlap - self.pred.context_end - self.context_start)
        r_end = self.size - self.context_start - self.context_end
        assert 0 <= bmin < bmax <= r_end

        # assert not (bmin == 0 and bmax >= r_start and not self.at_begin), [(r_start,r_end), bbox, self]

        if bmin == 0 and bmax >= r_start:
            if bmax == r_end:
                # object spans the entire block, i.e. is probably larger than size (minus the context)
                raise NotFullyVisible(True)
            if not self.at_begin:
                # object spans the entire overlap region, i.e. is only partially visible here and also by the predecessor block
                raise NotFullyVisible(False)

        # object ends before responsible region start
        if bmax < r_start: return False
        # object touches the end of the responsible region (only take if at end)
        if bmax == r_end and not self.at_end: return False
        return True

    # ------------------------

    @property
    def frozen(self):
        return self._frozen

    @property
    def at_begin(self):
        return self.pred is None

    @property
    def at_end(self):
        return self.succ is None

    @property
    def overlap(self):
        return self.size - self.stride

    @property
    def context_start(self):
        return 0 if self.at_begin else self.context

    @property
    def context_end(self):
        return 0 if self.at_end else self.context

    def __repr__(self):
        shared = f'{self.start:03}:{self.end:03}'
        shared += f', size={self.context_start}-{self.size-self.context_start-self.context_end}-{self.context_end}'
        if self.at_end:
            return f'{self.__class__.__name__}({shared})'
        else:
            return f'{self.__class__.__name__}({shared}, overlap={self.overlap}/{self.overlap-self.context_start-self.context_end})'

    @property
    def chain(self):
        blocks = [self]
        while not blocks[-1].at_end:
            blocks.append(blocks[-1].succ)
        return blocks

    def __iter__(self):
        return iter(self.chain)

    # ------------------------

    @staticmethod
    def cover(size, block_size, min_overlap, context, grid=1, verbose=True):
        """Return chain of grid-aligned blocks to cover the interval [0,size].

        Parameters block_size, min_overlap, and context will be used
        for all blocks of the chain. Only the size of the last block
        may differ.

        Except for the last block, start and end positions of all blocks will
        be multiples of grid. To that end, the provided block parameters may
        be increased to achieve that.

        Note that parameters must be chosen such that the write regions of only
        neighboring blocks are overlapping.

        """
        assert 0 <= min_overlap+2*context < block_size <= size
        assert 0 < grid <= block_size
        block_size = _grid_divisible(grid, block_size, name='block_size', verbose=verbose)
        min_overlap = _grid_divisible(grid, min_overlap, name='min_overlap', verbose=verbose)
        context = _grid_divisible(grid, context, name='context', verbose=verbose)

        # allow size not to be divisible by grid
        size_orig = size
        size = _grid_divisible(grid, size, name='size', verbose=False)

        # divide all sizes by grid
        assert all(v % grid == 0 for v in (size, block_size, min_overlap, context))
        size //= grid
        block_size //= grid
        min_overlap //= grid
        context //= grid

        # compute cover in grid-multiples
        t = first = Block(block_size, min_overlap, context, None)
        while t.end < size:
            t = t.add_succ()
        last = t

        # [print(t) for t in first]

        # move blocks around to make it fit
        excess = last.end - size
        t = first
        while excess > 0:
            t.decrease_stride(1)
            excess -= 1
            t = t.succ
            if (t == last): t = first

        # make a copy of the cover and multiply sizes by grid
        if grid > 1:
            size *= grid
            block_size *= grid
            min_overlap *= grid
            context *= grid
            #
            _t = _first = first
            t = first = Block(block_size, min_overlap, context, None)
            t.stride = _t.stride*grid
            while not _t.at_end:
                _t = _t.succ
                t = t.add_succ()
                t.stride = _t.stride*grid
            last = t

        # change size of last block
        # will be padded internally to the same size
        # as the others by model.predict_instances
        size_delta = size - size_orig
        last.size -= size_delta
        assert 0 <= size_delta < grid

        # for efficiency (to not determine starts recursively from now on)
        first.freeze()

        blocks = first.chain

        # sanity checks
        assert first.start == 0 and last.end == size_orig
        assert all(t.overlap-2*context >= min_overlap for t in blocks if t != last)
        assert all(t.start % grid == 0 and t.end % grid == 0 for t in blocks if t != last)
        # print(); [print(t) for t in first]

        # only neighboring blocks should be overlapping
        if len(blocks) >= 3:
            for t in blocks[:-2]:
                assert t.slice_write.stop <= t.succ.succ.slice_write.start

        return blocks


class BlockND:
    """N-dimensional block.

    Each BlockND simply consists of a 1-dimensional Block per axis and also
    has an id (which should be unique). The n-dimensional region represented
    by each BlockND is the intersection of all 1D Blocks per axis.

    Also see `Block`.

    """
    def __init__(self, id, blocks, axes):
        self.id = id
        self.blocks = tuple(blocks)
        self.axes = axes_check_and_normalize(axes, length=len(self.blocks))
        self.axis_to_block = dict(zip(self.axes,self.blocks))

    def blocks_for_axes(self, axes=None):
        axes = self.axes if axes is None else axes_check_and_normalize(axes)
        return tuple(self.axis_to_block[a] for a in axes)

    def slice_read(self, axes=None):
        return tuple(t.slice_read for t in self.blocks_for_axes(axes))

    def slice_crop_context(self, axes=None):
        return tuple(t.slice_crop_context for t in self.blocks_for_axes(axes))

    def slice_write(self, axes=None):
        return tuple(t.slice_write for t in self.blocks_for_axes(axes))

    def read(self, x, axes=None):
        """Read block "read region" from x (numpy.ndarray or similar)"""
        return x[self.slice_read(axes)]

    def crop_context(self, labels, axes=None):
        return labels[self.slice_crop_context(axes)]

    def write(self, x, labels, axes=None):
        """Write (only entries > 0 of) labels to block "write region" of x (numpy.ndarray or similar)"""
        s = self.slice_write(axes)
        mask = labels > 0
        # x[s][mask] = labels[mask] # doesn't work with zarr
        region = x[s]
        region[mask] = labels[mask]
        x[s] = region

    def is_responsible(self, slices, axes=None):
        return all(t.is_responsible((s.start,s.stop)) for t,s in zip(self.blocks_for_axes(axes),slices))

    def __repr__(self):
        slices = ','.join(f'{a}={t.start:03}:{t.end:03}' for t,a in zip(self.blocks,self.axes))
        return f'{self.__class__.__name__}({self.id}|{slices})'

    def __iter__(self):
        return iter(self.blocks)

    # ------------------------

    def filter_objects(self, labels, polys, axes=None):
        """Filter out objects that block is not responsible for.

        Given label image 'labels' and dictionary 'polys' of polygon/polyhedron objects,
        only retain those objects that this block is responsible for.

        This function will return a pair (labels, polys) of the modified label image and dictionary.
        It will raise a RuntimeError if an object is found in the overlap area
        of neighboring blocks that violates the assumption to be smaller than 'min_overlap'.

        If parameter 'polys' is None, only the filtered label image will be returned.

        Notes
        -----
        - Important: It is assumed that the object label ids in 'labels' and
          the entries in 'polys' are sorted in the same way.
        - Does not modify 'labels' and 'polys', but returns modified copies.

        Example
        -------
        >>> labels, polys = model.predict_instances(block.read(img))
        >>> labels = block.crop_context(labels)
        >>> labels, polys = block.filter_objects(labels, polys)

        """
        # TODO: option to update labels in-place
        assert np.issubdtype(labels.dtype, np.integer)
        ndim = len(self.blocks_for_axes(axes))
        assert ndim in (2,3)
        assert labels.ndim == ndim and labels.shape == tuple(s.stop-s.start for s in self.slice_crop_context(axes))

        labels_filtered = np.zeros_like(labels)
        # problem_ids = []
        for r in regionprops(labels):
            slices = tuple(slice(r.bbox[i],r.bbox[i+labels.ndim]) for i in range(labels.ndim))
            try:
                if self.is_responsible(slices, axes):
                    labels_filtered[slices][r.image] = r.label
            except NotFullyVisible as e:
                # shape_block_write = tuple(s.stop-s.start for s in self.slice_write(axes))
                shape_object = tuple(s.stop-s.start for s in slices)
                shape_min_overlap = tuple(t.min_overlap for t in self.blocks_for_axes(axes))
                raise RuntimeError(f"Found object of shape {shape_object}, which violates the assumption of being smaller than 'min_overlap' {shape_min_overlap}. Increase 'min_overlap' to avoid this problem.")

                # if e.args[0]: # object larger than block write region
                #     assert any(o >= b for o,b in zip(shape_object,shape_block_write))
                #     # problem, since this object will probably be saved by another block too
                #     raise RuntimeError(f"Found object of shape {shape_object}, larger than an entire block's write region of shape {shape_block_write}. Increase 'block_size' to avoid this problem.")
                #     # print("found object larger than 'block_size'")
                # else:
                #     assert any(o >= b for o,b in zip(shape_object,shape_min_overlap))
                #     # print("found object larger than 'min_overlap'")

                # # keep object, because will be dealt with later, i.e.
                # # render the poly again into the label image, but this is not
                # # ideal since the assumption is that the object outside that
                # # region is not reliable because it's in the context
                # labels_filtered[slices][r.image] = r.label
                # problem_ids.append(r.label)

        if polys is None:
            # assert len(problem_ids) == 0
            return labels_filtered
        else:
            # it is assumed that ids in 'labels' map to entries in 'polys'
            assert isinstance(polys,dict) and any(k in polys for k in COORD_KEYS)
            filtered_labels = np.unique(labels_filtered)
            filtered_ind = [i-1 for i in filtered_labels if i > 0]
            polys_out = {k: (v[filtered_ind] if k in OBJECT_KEYS else v) for k,v in polys.items()}
            for k in COORD_KEYS:
                if k in polys_out.keys():
                    polys_out[k] = self.translate_coordinates(polys_out[k], axes=axes)

            return labels_filtered, polys_out#, tuple(problem_ids)

    def translate_coordinates(self, coordinates, axes=None):
        """Translate local block coordinates (of read region) to global ones based on block position"""
        ndim = len(self.blocks_for_axes(axes))
        assert isinstance(coordinates, np.ndarray) and coordinates.ndim >= 2 and coordinates.shape[1] == ndim
        start = [s.start for s in self.slice_read(axes)]
        shape = tuple(1 if d!=1 else ndim for d in range(coordinates.ndim))
        start = np.array(start).reshape(shape)
        return coordinates + start

    # ------------------------

    @staticmethod
    def cover(shape, axes, block_size, min_overlap, context, grid=1):
        """Return grid-aligned n-dimensional blocks to cover region
        of the given shape with axes semantics.

        Parameters block_size, min_overlap, and context can be different per
        dimension/axis (if provided as list) or the same (if provided as
        scalar value).

        Also see `Block.cover`.

        """
        shape = tuple(shape)
        n = len(shape)
        axes = axes_check_and_normalize(axes, length=n)
        if np.isscalar(block_size): block_size = n*[block_size]
        if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
        if np.isscalar(context): context = n*[context]
        if np.isscalar(grid): grid = n*[grid]
        assert n == len(block_size) == len(min_overlap) == len(context) == len(grid)

        # compute cover for each dimension
        cover_1d = [Block.cover(*args) for args in zip(shape, block_size, min_overlap, context, grid)]
        # return cover as Cartesian product of 1-dimensional blocks
        return tuple(BlockND(i,blocks,axes) for i,blocks in enumerate(product(*cover_1d)))


class Polygon:

    def __init__(self, coord, bbox=None, shape_max=None):
        self.bbox = self.coords_bbox(coord, shape_max=shape_max) if bbox is None else bbox
        self.coord = coord - np.array([r[0] for r in self.bbox]).reshape(2,1)
        self.slice = tuple(slice(*r) for r in self.bbox)
        self.shape = tuple(r[1]-r[0] for r in self.bbox)
        rr,cc = polygon(*self.coord, self.shape)
        self.mask = np.zeros(self.shape, bool)
        self.mask[rr,cc] = True

    @staticmethod
    def coords_bbox(*coords, shape_max=None):
        assert all(isinstance(c, np.ndarray) and c.ndim==2 and c.shape[0]==2 for c in coords)
        if shape_max is None:
            shape_max = (np.inf, np.inf)
        coord = np.concatenate(coords, axis=1)
        mins = np.maximum(0, np.floor(np.min(coord,axis=1))).astype(int)
        maxs = np.minimum(shape_max, np.ceil (np.max(coord,axis=1))).astype(int)
        return tuple(zip(tuple(mins),tuple(maxs)))


class Polyhedron:

    def __init__(self, dist, origin, rays, bbox=None, shape_max=None):
        self.bbox = self.coords_bbox((dist, origin), rays=rays, shape_max=shape_max) if bbox is None else bbox
        self.slice = tuple(slice(*r) for r in self.bbox)
        self.shape = tuple(r[1]-r[0] for r in self.bbox)
        _origin = origin.reshape(1,3) - np.array([r[0] for r in self.bbox]).reshape(1,3)
        # NOTE: polyhedron_to_label is not imported anywhere in this file; in upstream
        # StarDist it comes from .geometry (3D). Instantiating Polyhedron as uploaded
        # would therefore raise a NameError.
        self.mask = polyhedron_to_label(dist[np.newaxis], _origin, rays, shape=self.shape, verbose=False).astype(bool)

    @staticmethod
    def coords_bbox(*dist_origin, rays, shape_max=None):
        dists, points = zip(*dist_origin)
        assert all(isinstance(d, np.ndarray) and d.ndim==1 and len(d)==len(rays) for d in dists)
        assert all(isinstance(p, np.ndarray) and p.ndim==1 and len(p)==3 for p in points)
        dists, points, verts = np.stack(dists)[...,np.newaxis], np.stack(points)[:,np.newaxis], rays.vertices[np.newaxis]
        coord = dists * verts + points
        coord = np.concatenate(coord, axis=0)
        if shape_max is None:
            shape_max = (np.inf, np.inf, np.inf)
        mins = np.maximum(0, np.floor(np.min(coord,axis=0))).astype(int)
        maxs = np.minimum(shape_max, np.ceil (np.max(coord,axis=0))).astype(int)
        return tuple(zip(tuple(mins),tuple(maxs)))


# def repaint_labels(output, labels, polys, show_progress=True):
#     """Repaint object instances in correct order based on probability scores.

#     Does modify 'output' and 'polys' in-place, but will only write sparsely to 'output' where needed.

#     output: numpy.ndarray or similar
#         Label image (integer-valued)
#     labels: iterable of int
#         List of integer label ids that occur in output
#     polys: dict
#         Dictionary of polygon/polyhedra properties.
#         Assumption is that the label id (-1) corresponds to the index in the polys dict

#     """
#     assert output.ndim in (2,3)

#     if show_progress:
#         labels = tqdm(labels, leave=True)

#     labels_eliminated = set()

#     # TODO: inelegant to have so much duplicated code here
#     if output.ndim == 2:
#         coord = lambda i: polys['coord'][i-1]
#         prob = lambda i: polys['prob'][i-1]

#         for i in labels:
#             if i in labels_eliminated: continue
#             poly_i = Polygon(coord(i), shape_max=output.shape)

#             # find all labels that overlap with i (including i)
#             overlapping = set(np.unique(output[poly_i.slice][poly_i.mask])) - {0}
#             assert i in overlapping
#             # compute bbox union to find area to crop/replace in large output label image
#             bbox_union = Polygon.coords_bbox(*[coord(j) for j in overlapping], shape_max=output.shape)

#             # crop out label i, including the region that include all overlapping labels
#             poly_i = Polygon(coord(i), bbox=bbox_union)
#             mask = poly_i.mask.copy()

#             # remove pixels from mask that belong to labels with higher probability
#             for j in [j for j in overlapping if prob(j) > prob(i)]:
#                 mask[ Polygon(coord(j), bbox=bbox_union).mask ] = False

#             crop = output[poly_i.slice]
#             crop[crop==i] = 0 # delete all remnants of i in crop
#             crop[mask] = i # paint i where mask still active

#             labels_remaining = set(np.unique(output[poly_i.slice][poly_i.mask])) - {0}
#             labels_eliminated.update(overlapping - labels_remaining)
#     else:

#         dist = lambda i: polys['dist'][i-1]
#         origin = lambda i: polys['points'][i-1]
#         prob = lambda i: polys['prob'][i-1]
#         rays = polys['rays']

#         for i in labels:
#             if i in labels_eliminated: continue
#             poly_i = Polyhedron(dist(i), origin(i), rays, shape_max=output.shape)

#             # find all labels that overlap with i (including i)
#             overlapping = set(np.unique(output[poly_i.slice][poly_i.mask])) - {0}
#             assert i in overlapping
#             # compute bbox union to find area to crop/replace in large output label image
#             bbox_union = Polyhedron.coords_bbox(*[(dist(j),origin(j)) for j in overlapping], rays=rays, shape_max=output.shape)

#             # crop out label i, including the region that include all overlapping labels
#             poly_i = Polyhedron(dist(i), origin(i), rays, bbox=bbox_union)
#             mask = poly_i.mask.copy()

#             # remove pixels from mask that belong to labels with higher probability
#             for j in [j for j in overlapping if prob(j) > prob(i)]:
#                 mask[ Polyhedron(dist(j), origin(j), rays, bbox=bbox_union).mask ] = False

#             crop = output[poly_i.slice]
#             crop[crop==i] = 0 # delete all remnants of i in crop
#             crop[mask] = i # paint i where mask still active

#             labels_remaining = set(np.unique(output[poly_i.slice][poly_i.mask])) - {0}
#             labels_eliminated.update(overlapping - labels_remaining)

#     if len(labels_eliminated) > 0:
#         ind = [i-1 for i in labels_eliminated]
#         for k,v in polys.items():
#             if k in OBJECT_KEYS:
#                 polys[k] = np.delete(v, ind, axis=0)


############


def predict_big(model, *args, **kwargs):
    from .models import StarDist2D, StarDist3D
    if isinstance(model,(StarDist2D,StarDist3D)):
        dst = model.__class__.__name__
    else:
        dst = '{StarDist2D, StarDist3D}'
    raise RuntimeError(f"This function has moved to {dst}.predict_instances_big.")


class NotFullyVisible(Exception):
    pass


def _grid_divisible(grid, size, name=None, verbose=True):
    if size % grid == 0:
        return size
    _size = size
    size = math.ceil(size / grid) * grid
    if bool(verbose):
        print(f"{verbose if isinstance(verbose,str) else ''}increasing '{'value' if name is None else name}' from {_size} to {size} to be evenly divisible by {grid} (grid)", flush=True)
    assert size % grid == 0
    return size


# def render_polygons(polys, shape):
#    return polygons_to_label_coord(polys['coord'], shape=shape)
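The block/chain machinery above is what lets very large images be segmented tile by tile: each block reads its full region (including context), but only its write region (context stripped) contributes to the final label image. A small self-contained sketch that only exercises the pure-Python classes in this file:

```python
# Cover a 700x900 image with overlapping 512-pixel blocks and inspect one block.
from stardist_pkg.big import BlockND

blocks = BlockND.cover(shape=(700, 900), axes='YX',
                       block_size=512, min_overlap=96, context=64, grid=1)
b = blocks[0]
print(len(blocks), 'blocks')
print('read :', b.slice_read())    # full block, context included
print('write:', b.slice_write())   # block minus context, where results are kept
```

These are the same `block_size`, `min_overlap`, and `context` roles that `ModelConfig` in sribd_cellseg_models.py parameterizes (2048 / 128 / 128) for whole-slide inputs.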
stardist_pkg/bioimageio_utils.py
ADDED
@@ -0,0 +1,472 @@
from pathlib import Path
from pkg_resources import get_distribution
from zipfile import ZipFile
import numpy as np
import tempfile
from distutils.version import LooseVersion
from csbdeep.utils import axes_check_and_normalize, normalize, _raise


DEEPIMAGEJ_MACRO = \
"""
//*******************************************************************
// Date: July-2021
// Credits: StarDist, DeepImageJ
// URL:
// https://github.com/stardist/stardist
// https://deepimagej.github.io/deepimagej
// This macro was adapted from
// https://github.com/deepimagej/imagej-macros/blob/648caa867f6ccb459649d4d3799efa1e2e0c5204/StarDist2D_Post-processing.ijm
// Please cite the respective contributions when using this code.
//*******************************************************************
// Macro to run StarDist postprocessing on 2D images.
// StarDist and deepImageJ plugins need to be installed.
// The macro assumes that the image to process is a stack in which
// the first channel corresponds to the object probability map
// and the remaining channels are the radial distances from each
// pixel to the object boundary.
//*******************************************************************

// Get the name of the image to call it
getDimensions(width, height, channels, slices, frames);
name=getTitle();

probThresh={probThresh};
nmsThresh={nmsThresh};

// Isolate the detection probability scores
run("Make Substack...", "channels=1");
rename("scores");

// Isolate the oriented distances
run("Fire");
selectWindow(name);
run("Delete Slice", "delete=channel");
selectWindow(name);
run("Properties...", "channels=" + maxOf(channels, slices) - 1 + " slices=1 frames=1 pixel_width=1.0000 pixel_height=1.0000 voxel_depth=1.0000");
rename("distances");
run("royal");

// Run StarDist plugin
run("Command From Macro", "command=[de.csbdresden.stardist.StarDist2DNMS], args=['prob':'scores', 'dist':'distances', 'probThresh':'" + probThresh + "', 'nmsThresh':'" + nmsThresh + "', 'outputType':'Both', 'excludeBoundary':'2', 'roiPosition':'Stack', 'verbose':'false'], process=[false]");
"""


def _import(error=True):
    try:
        from importlib_metadata import metadata
        from bioimageio.core.build_spec import build_model  # type: ignore
        import xarray as xr
        import bioimageio.core  # type: ignore
    except ImportError:
        if error:
            raise RuntimeError(
                "Required libraries are missing for bioimage.io model export.\n"
                "Please install StarDist as follows: pip install 'stardist[bioimageio]'\n"
                "(You do not need to uninstall StarDist first.)"
            )
        else:
            return None
    return metadata, build_model, bioimageio.core, xr


def _create_stardist_dependencies(outdir):
    from ruamel.yaml import YAML
    from tensorflow import __version__ as tf_version
    from . import __version__ as stardist_version
    pkg_info = get_distribution("stardist")
    # dependencies that start with the name "bioimageio" will be added as conda dependencies
    reqs_conda = [str(req) for req in pkg_info.requires(extras=['bioimageio']) if str(req).startswith('bioimageio')]
    # only stardist and tensorflow as pip dependencies
    tf_major, tf_minor = LooseVersion(tf_version).version[:2]
    reqs_pip = (f"stardist>={stardist_version}", f"tensorflow>={tf_major}.{tf_minor},<{tf_major+1}")
    # conda environment
    env = dict(
        name = 'stardist',
        channels = ['defaults', 'conda-forge'],
        dependencies = [
            ('python>=3.7,<3.8' if tf_major == 1 else 'python>=3.7'),
            *reqs_conda,
            'pip', {'pip': reqs_pip},
        ],
    )
    yaml = YAML(typ='safe')
    path = outdir / "environment.yaml"
    with open(path, "w") as f:
        yaml.dump(env, f)
    return f"conda:{path}"


def _create_stardist_doc(outdir):
    doc_path = outdir / "README.md"
    text = (
        "# StarDist Model\n"
        "This is a model for object detection with star-convex shapes.\n"
        "Please see the [StarDist repository](https://github.com/stardist/stardist) for details."
    )
    with open(doc_path, "w") as f:
        f.write(text)
    return doc_path


def _get_stardist_metadata(outdir, model):
    metadata, *_ = _import()
    package_data = metadata("stardist")
    doi_2d = "https://doi.org/10.1007/978-3-030-00934-2_30"
    doi_3d = "https://doi.org/10.1109/WACV45572.2020.9093435"
    authors = {
        'Martin Weigert': dict(name='Martin Weigert', github_user='maweigert'),
        'Uwe Schmidt': dict(name='Uwe Schmidt', github_user='uschmidt83'),
    }
    data = dict(
        description=package_data["Summary"],
        authors=list(authors.get(name.strip(),dict(name=name.strip())) for name in package_data["Author"].split(",")),
        git_repo=package_data["Home-Page"],
        license=package_data["License"],
        dependencies=_create_stardist_dependencies(outdir),
        cite=[{"text": "Cell Detection with Star-Convex Polygons", "doi": doi_2d},
              {"text": "Star-convex Polyhedra for 3D Object Detection and Segmentation in Microscopy", "doi": doi_3d}],
        tags=[
            'fluorescence-light-microscopy', 'whole-slide-imaging', 'other',  # modality
            f'{model.config.n_dim}d',  # dims
            'cells', 'nuclei',  # content
            'tensorflow',  # framework
            'fiji',  # software
            'unet',  # network
            'instance-segmentation', 'object-detection',  # task
            'stardist',
        ],
        covers=["https://raw.githubusercontent.com/stardist/stardist/master/images/stardist_logo.jpg"],
        documentation=_create_stardist_doc(outdir),
    )
    return data


def _predict_tf(model_path, test_input):
    import tensorflow as tf
    from csbdeep.utils.tf import IS_TF_1
    # need to unzip the model assets
    model_assets = model_path.parent / "tf_model"
    with ZipFile(model_path, "r") as f:
        f.extractall(model_assets)
    if IS_TF_1:
        # make a new graph, i.e. don't use the global default graph
        with tf.Graph().as_default():
            with tf.Session() as sess:
                tf_model = tf.saved_model.load_v2(str(model_assets))
                x = tf.convert_to_tensor(test_input, dtype=tf.float32)
                model = tf_model.signatures["serving_default"]
                y = model(x)
                sess.run(tf.global_variables_initializer())
                output = sess.run(y["output"])
    else:
        tf_model = tf.saved_model.load(str(model_assets))
        x = tf.convert_to_tensor(test_input, dtype=tf.float32)
        model = tf_model.signatures["serving_default"]
        y = model(x)
        output = y["output"].numpy()
    return output


def _get_weights_and_model_metadata(outdir, model, test_input, test_input_axes, test_input_norm_axes, mode, min_percentile, max_percentile):

    # get the path to the exported model assets (saved in outdir)
if mode == "keras_hdf5":
|
175 |
+
raise NotImplementedError("Export to keras format is not supported yet")
|
176 |
+
elif mode == "tensorflow_saved_model_bundle":
|
177 |
+
assets_uri = outdir / "TF_SavedModel.zip"
|
178 |
+
model_csbdeep = model.export_TF(assets_uri, single_output=True, upsample_grid=True)
|
179 |
+
else:
|
180 |
+
raise ValueError(f"Unsupported mode: {mode}")
|
181 |
+
|
182 |
+
# to force "inputs.data_type: float32" in the spec (bonus: disables normalization warning in model._predict_setup)
|
183 |
+
test_input = test_input.astype(np.float32)
|
184 |
+
|
185 |
+
# convert test_input to axes_net semantics and shape, also resize if necessary (to adhere to axes_net_div_by)
|
186 |
+
test_input, axes_img, axes_net, axes_net_div_by, *_ = model._predict_setup(
|
187 |
+
img=test_input,
|
188 |
+
axes=test_input_axes,
|
189 |
+
normalizer=None,
|
190 |
+
n_tiles=None,
|
191 |
+
show_tile_progress=False,
|
192 |
+
predict_kwargs={},
|
193 |
+
)
|
194 |
+
|
195 |
+
# normalization axes string and numeric indices
|
196 |
+
axes_norm = set(axes_net).intersection(set(axes_check_and_normalize(test_input_norm_axes, disallowed='S')))
|
197 |
+
axes_norm = "".join(a for a in axes_net if a in axes_norm) # preserve order of axes_net
|
198 |
+
axes_norm_num = tuple(axes_net.index(a) for a in axes_norm)
|
199 |
+
|
200 |
+
# normalize input image
|
201 |
+
test_input_norm = normalize(test_input, pmin=min_percentile, pmax=max_percentile, axis=axes_norm_num)
|
202 |
+
|
203 |
+
net_axes_in = axes_net.lower()
|
204 |
+
net_axes_out = axes_check_and_normalize(model._axes_out).lower()
|
205 |
+
ndim_tensor = len(net_axes_out) + 1
|
206 |
+
|
207 |
+
input_min_shape = list(axes_net_div_by)
|
208 |
+
input_min_shape[axes_net.index('C')] = model.config.n_channel_in
|
209 |
+
input_step = list(axes_net_div_by)
|
210 |
+
input_step[axes_net.index('C')] = 0
|
211 |
+
|
212 |
+
# add the batch axis to shape and step
|
213 |
+
input_min_shape = [1] + input_min_shape
|
214 |
+
input_step = [0] + input_step
|
215 |
+
|
216 |
+
# the axes strings in bioimageio convention
|
217 |
+
input_axes = "b" + net_axes_in.lower()
|
218 |
+
output_axes = "b" + net_axes_out.lower()
|
219 |
+
|
220 |
+
if mode == "keras_hdf5":
|
221 |
+
output_names = ("prob", "dist") + (("class_prob",) if model._is_multiclass() else ())
|
222 |
+
output_n_channels = (1, model.config.n_rays,) + ((1,) if model._is_multiclass() else ())
|
223 |
+
# the output shape is computed from the input shape using
|
224 |
+
# output_shape[i] = output_scale[i] * input_shape[i] + 2 * output_offset[i]
|
225 |
+
output_scale = [1]+list(1/g for g in model.config.grid) + [0]
|
226 |
+
output_offset = [0]*(ndim_tensor)
|
227 |
+
|
228 |
+
elif mode == "tensorflow_saved_model_bundle":
|
229 |
+
if model._is_multiclass():
|
230 |
+
raise NotImplementedError("Tensorflow SavedModel not supported for multiclass models yet")
|
231 |
+
# regarding input/output names: https://github.com/CSBDeep/CSBDeep/blob/b0d2f5f344ebe65a9b4c3007f4567fe74268c813/csbdeep/utils/tf.py#L193-L194
|
232 |
+
input_names = ["input"]
|
233 |
+
output_names = ["output"]
|
234 |
+
output_n_channels = (1 + model.config.n_rays,)
|
235 |
+
# the output shape is computed from the input shape using
|
236 |
+
# output_shape[i] = output_scale[i] * input_shape[i] + 2 * output_offset[i]
|
237 |
+
# same shape as input except for the channel dimension
|
238 |
+
output_scale = [1]*(ndim_tensor)
|
239 |
+
output_scale[output_axes.index("c")] = 0
|
240 |
+
# no offset, except for the input axes, where it is output channel / 2
|
241 |
+
output_offset = [0.0]*(ndim_tensor)
|
242 |
+
output_offset[output_axes.index("c")] = output_n_channels[0] / 2.0
|
243 |
+
|
244 |
+
assert all(s in (0, 1) for s in output_scale), "halo computation assumption violated"
|
245 |
+
halo = model._axes_tile_overlap(output_axes.replace('b', 's'))
|
246 |
+
halo = [int(np.ceil(v/8)*8) for v in halo] # optional: round up to be divisible by 8
|
247 |
+
|
248 |
+
# the output shape needs to be valid after cropping the halo, so we add the halo to the input min shape
|
249 |
+
input_min_shape = [ms + 2 * ha for ms, ha in zip(input_min_shape, halo)]
|
250 |
+
|
251 |
+
# make sure the input min shape is still divisible by the min axis divisor
|
252 |
+
input_min_shape = input_min_shape[:1] + [ms + (-ms % div_by) for ms, div_by in zip(input_min_shape[1:], axes_net_div_by)]
|
253 |
+
assert all(ms % div_by == 0 for ms, div_by in zip(input_min_shape[1:], axes_net_div_by))
|
254 |
+
|
255 |
+
metadata, *_ = _import()
|
256 |
+
package_data = metadata("stardist")
|
257 |
+
is_2D = model.config.n_dim == 2
|
258 |
+
|
259 |
+
weights_file = outdir / "stardist_weights.h5"
|
260 |
+
model.keras_model.save_weights(str(weights_file))
|
261 |
+
|
262 |
+
config = dict(
|
263 |
+
stardist=dict(
|
264 |
+
python_version=package_data["Version"],
|
265 |
+
thresholds=dict(model.thresholds._asdict()),
|
266 |
+
weights=weights_file.name,
|
267 |
+
config=vars(model.config),
|
268 |
+
)
|
269 |
+
)
|
270 |
+
|
271 |
+
if is_2D:
|
272 |
+
macro_file = outdir / "stardist_postprocessing.ijm"
|
273 |
+
with open(str(macro_file), 'w', encoding='utf-8') as f:
|
274 |
+
f.write(DEEPIMAGEJ_MACRO.format(probThresh=model.thresholds.prob, nmsThresh=model.thresholds.nms))
|
275 |
+
config['stardist'].update(postprocessing_macro=macro_file.name)
|
276 |
+
|
277 |
+
n_inputs = len(input_names)
|
278 |
+
assert n_inputs == 1
|
279 |
+
input_config = dict(
|
280 |
+
input_names=input_names,
|
281 |
+
input_min_shape=[input_min_shape],
|
282 |
+
input_step=[input_step],
|
283 |
+
input_axes=[input_axes],
|
284 |
+
input_data_range=[["-inf", "inf"]],
|
285 |
+
preprocessing=[[dict(
|
286 |
+
name="scale_range",
|
287 |
+
kwargs=dict(
|
288 |
+
mode="per_sample",
|
289 |
+
axes=axes_norm.lower(),
|
290 |
+
min_percentile=min_percentile,
|
291 |
+
max_percentile=max_percentile,
|
292 |
+
))]]
|
293 |
+
)
|
294 |
+
|
295 |
+
n_outputs = len(output_names)
|
296 |
+
output_config = dict(
|
297 |
+
output_names=output_names,
|
298 |
+
output_data_range=[["-inf", "inf"]] * n_outputs,
|
299 |
+
output_axes=[output_axes] * n_outputs,
|
300 |
+
output_reference=[input_names[0]] * n_outputs,
|
301 |
+
output_scale=[output_scale] * n_outputs,
|
302 |
+
output_offset=[output_offset] * n_outputs,
|
303 |
+
halo=[halo] * n_outputs
|
304 |
+
)
|
305 |
+
|
306 |
+
in_path = outdir / "test_input.npy"
|
307 |
+
np.save(in_path, test_input[np.newaxis])
|
308 |
+
|
309 |
+
if mode == "tensorflow_saved_model_bundle":
|
310 |
+
test_outputs = _predict_tf(assets_uri, test_input_norm[np.newaxis])
|
311 |
+
else:
|
312 |
+
test_outputs = model.predict(test_input_norm)
|
313 |
+
|
314 |
+
# out_paths = []
|
315 |
+
# for i, out in enumerate(test_outputs):
|
316 |
+
# p = outdir / f"test_output{i}.npy"
|
317 |
+
# np.save(p, out)
|
318 |
+
# out_paths.append(p)
|
319 |
+
assert n_outputs == 1
|
320 |
+
out_paths = [outdir / "test_output.npy"]
|
321 |
+
np.save(out_paths[0], test_outputs)
|
322 |
+
|
323 |
+
from tensorflow import __version__ as tf_version
|
324 |
+
data = dict(weight_uri=assets_uri, test_inputs=[in_path], test_outputs=out_paths,
|
325 |
+
config=config, tensorflow_version=tf_version)
|
326 |
+
data.update(input_config)
|
327 |
+
data.update(output_config)
|
328 |
+
_files = [str(weights_file)]
|
329 |
+
if is_2D:
|
330 |
+
_files.append(str(macro_file))
|
331 |
+
data.update(attachments=dict(files=_files))
|
332 |
+
|
333 |
+
return data
|
334 |
+
|
335 |
+
|
336 |
+
def export_bioimageio(
|
337 |
+
model,
|
338 |
+
outpath,
|
339 |
+
test_input,
|
340 |
+
test_input_axes=None,
|
341 |
+
test_input_norm_axes='ZYX',
|
342 |
+
name=None,
|
343 |
+
mode="tensorflow_saved_model_bundle",
|
344 |
+
min_percentile=1.0,
|
345 |
+
max_percentile=99.8,
|
346 |
+
overwrite_spec_kwargs=None,
|
347 |
+
):
|
348 |
+
"""Export stardist model into bioimage.io format, https://github.com/bioimage-io/spec-bioimage-io.
|
349 |
+
|
350 |
+
Parameters
|
351 |
+
----------
|
352 |
+
model: StarDist2D, StarDist3D
|
353 |
+
the model to convert
|
354 |
+
outpath: str, Path
|
355 |
+
where to save the model
|
356 |
+
test_input: np.ndarray
|
357 |
+
input image for generating test data
|
358 |
+
test_input_axes: str or None
|
359 |
+
the axes of the test input, for example 'YX' for a 2d image or 'ZYX' for a 3d volume
|
360 |
+
using None assumes that axes of test_input are the same as those of model
|
361 |
+
test_input_norm_axes: str
|
362 |
+
the axes of the test input which will be jointly normalized, for example 'ZYX' for all spatial dimensions ('Z' ignored for 2D input)
|
363 |
+
use 'ZYXC' to also jointly normalize channels (e.g. for RGB input images)
|
364 |
+
name: str
|
365 |
+
the name of this model (default: None)
|
366 |
+
if None, uses the (folder) name of the model (i.e. `model.name`)
|
367 |
+
mode: str
|
368 |
+
the export type for this model (default: "tensorflow_saved_model_bundle")
|
369 |
+
min_percentile: float
|
370 |
+
min percentile to be used for image normalization (default: 1.0)
|
371 |
+
max_percentile: float
|
372 |
+
max percentile to be used for image normalization (default: 99.8)
|
373 |
+
overwrite_spec_kwargs: dict or None
|
374 |
+
spec keywords that should be overloaded (default: None)
|
375 |
+
"""
|
376 |
+
_, build_model, *_ = _import()
|
377 |
+
from .models import StarDist2D, StarDist3D
|
378 |
+
isinstance(model, (StarDist2D, StarDist3D)) or _raise(ValueError("not a valid model"))
|
379 |
+
0 <= min_percentile < max_percentile <= 100 or _raise(ValueError("invalid percentile values"))
|
380 |
+
|
381 |
+
if name is None:
|
382 |
+
name = model.name
|
383 |
+
name = str(name)
|
384 |
+
|
385 |
+
outpath = Path(outpath)
|
386 |
+
if outpath.suffix == "":
|
387 |
+
outdir = outpath
|
388 |
+
zip_path = outdir / f"{name}.zip"
|
389 |
+
elif outpath.suffix == ".zip":
|
390 |
+
outdir = outpath.parent
|
391 |
+
zip_path = outpath
|
392 |
+
else:
|
393 |
+
raise ValueError(f"outpath has to be a folder or zip file, got {outpath}")
|
394 |
+
outdir.mkdir(exist_ok=True, parents=True)
|
395 |
+
|
396 |
+
with tempfile.TemporaryDirectory() as _tmp_dir:
|
397 |
+
tmp_dir = Path(_tmp_dir)
|
398 |
+
kwargs = _get_stardist_metadata(tmp_dir, model)
|
399 |
+
model_kwargs = _get_weights_and_model_metadata(tmp_dir, model, test_input, test_input_axes, test_input_norm_axes, mode,
|
400 |
+
min_percentile=min_percentile, max_percentile=max_percentile)
|
401 |
+
kwargs.update(model_kwargs)
|
402 |
+
if overwrite_spec_kwargs is not None:
|
403 |
+
kwargs.update(overwrite_spec_kwargs)
|
404 |
+
|
405 |
+
build_model(name=name, output_path=zip_path, add_deepimagej_config=(model.config.n_dim==2), root=tmp_dir, **kwargs)
|
406 |
+
print(f"\nbioimage.io model with name '{name}' exported to '{zip_path}'")
|
407 |
+
|
408 |
+
|
409 |
+
def import_bioimageio(source, outpath):
|
410 |
+
"""Import stardist model from bioimage.io format, https://github.com/bioimage-io/spec-bioimage-io.
|
411 |
+
|
412 |
+
Load a model in bioimage.io format from the given `source` (e.g. path to zip file, URL)
|
413 |
+
and convert it to a regular stardist model, which will be saved in the folder `outpath`.
|
414 |
+
|
415 |
+
Parameters
|
416 |
+
----------
|
417 |
+
source: str, Path
|
418 |
+
bioimage.io resource (e.g. path, URL)
|
419 |
+
outpath: str, Path
|
420 |
+
folder to save the stardist model (must not exist previously)
|
421 |
+
|
422 |
+
Returns
|
423 |
+
-------
|
424 |
+
StarDist2D or StarDist3D
|
425 |
+
stardist model loaded from `outpath`
|
426 |
+
|
427 |
+
"""
|
428 |
+
import shutil, uuid
|
429 |
+
from csbdeep.utils import save_json
|
430 |
+
from .models import StarDist2D, StarDist3D
|
431 |
+
*_, bioimageio_core, _ = _import()
|
432 |
+
|
433 |
+
outpath = Path(outpath)
|
434 |
+
not outpath.exists() or _raise(FileExistsError(f"'{outpath}' already exists"))
|
435 |
+
|
436 |
+
with tempfile.TemporaryDirectory() as _tmp_dir:
|
437 |
+
tmp_dir = Path(_tmp_dir)
|
438 |
+
# download the full model content to a temporary folder
|
439 |
+
zip_path = tmp_dir / f"{str(uuid.uuid4())}.zip"
|
440 |
+
bioimageio_core.export_resource_package(source, output_path=zip_path)
|
441 |
+
with ZipFile(zip_path, "r") as zip_ref:
|
442 |
+
zip_ref.extractall(tmp_dir)
|
443 |
+
zip_path.unlink()
|
444 |
+
rdf_path = tmp_dir / "rdf.yaml"
|
445 |
+
biomodel = bioimageio_core.load_resource_description(rdf_path)
|
446 |
+
|
447 |
+
# read the stardist specific content
|
448 |
+
'stardist' in biomodel.config or _raise(RuntimeError("bioimage.io model not compatible"))
|
449 |
+
config = biomodel.config['stardist']['config']
|
450 |
+
thresholds = biomodel.config['stardist']['thresholds']
|
451 |
+
weights = biomodel.config['stardist']['weights']
|
452 |
+
|
453 |
+
# make sure that the keras weights are in the attachments
|
454 |
+
weights_file = None
|
455 |
+
for f in biomodel.attachments.files:
|
456 |
+
if f.name == weights and f.exists():
|
457 |
+
weights_file = f
|
458 |
+
break
|
459 |
+
weights_file is not None or _raise(FileNotFoundError(f"couldn't find weights file '{weights}'"))
|
460 |
+
|
461 |
+
# save the config and threshold to json, and weights to hdf5 to enable loading as stardist model
|
462 |
+
# copy bioimageio files to separate sub-folder
|
463 |
+
outpath.mkdir(parents=True)
|
464 |
+
save_json(config, str(outpath / 'config.json'))
|
465 |
+
save_json(thresholds, str(outpath / 'thresholds.json'))
|
466 |
+
shutil.copy(str(weights_file), str(outpath / "weights_bioimageio.h5"))
|
467 |
+
shutil.copytree(str(tmp_dir), str(outpath / "bioimageio"))
|
468 |
+
|
469 |
+
model_class = (StarDist2D if config['n_dim'] == 2 else StarDist3D)
|
470 |
+
model = model_class(None, outpath.name, basedir=str(outpath.parent))
|
471 |
+
|
472 |
+
return model
|
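For orientation, a minimal round-trip sketch of the two public functions above (not part of the upload): it assumes a trained stardist model `model`, a 2D numpy test image `img`, and the optional 'stardist[bioimageio]' dependencies installed; the repo root must be on sys.path so that stardist_pkg is importable.

from stardist_pkg.bioimageio_utils import export_bioimageio, import_bioimageio

# writes exported/<model.name>.zip with weights, test input/output, and (for 2D) the deepImageJ macro
export_bioimageio(model, "exported", img, test_input_axes="YX")
# converts the zip back into a regular stardist model folder and returns the loaded model
model2 = import_bioimageio(f"exported/{model.name}.zip", "imported_model")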
stardist_pkg/geometry/__init__.py
ADDED
@@ -0,0 +1,9 @@
from __future__ import absolute_import, print_function

# TODO: rethink naming for 2D/3D functions

from .geom2d import star_dist, relabel_image_stardist, ray_angles, dist_to_coord, polygons_to_label, polygons_to_label_coord

from .geom2d import _dist_to_coord_old, _polygons_to_label_old

#, dist_to_volume, dist_to_centroid
stardist_pkg/geometry/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (522 Bytes).
stardist_pkg/geometry/__pycache__/geom2d.cpython-37.pyc
ADDED
Binary file (7.23 kB).
stardist_pkg/geometry/__pycache__/geom3d.cpython-37.pyc
ADDED
Binary file (11 kB).
stardist_pkg/geometry/geom2d.py
ADDED
@@ -0,0 +1,212 @@
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import warnings

from skimage.measure import regionprops
from skimage.draw import polygon
from csbdeep.utils import _raise

from ..utils import path_absolute, _is_power_of_2, _normalize_grid
from ..matching import _check_label_array
from stardist.lib.stardist2d import c_star_dist


def _ocl_star_dist(lbl, n_rays=32, grid=(1,1)):
    from gputools import OCLProgram, OCLArray, OCLImage
    (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError())
    n_rays = int(n_rays)
    # slicing with grid is done with tuple(slice(0, None, g) for g in grid)
    res_shape = tuple((s-1)//g+1 for s, g in zip(lbl.shape, grid))

    src = OCLImage.from_array(lbl.astype(np.uint16,copy=False))
    dst = OCLArray.empty(res_shape+(n_rays,), dtype=np.float32)
    program = OCLProgram(path_absolute("kernels/stardist2d.cl"), build_options=['-D', 'N_RAYS=%d' % n_rays])
    program.run_kernel('star_dist', res_shape[::-1], None, dst.data, src, np.int32(grid[0]),np.int32(grid[1]))
    return dst.get()


def _cpp_star_dist(lbl, n_rays=32, grid=(1,1)):
    (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError())
    return c_star_dist(lbl.astype(np.uint16,copy=False), np.int32(n_rays), np.int32(grid[0]),np.int32(grid[1]))


def _py_star_dist(a, n_rays=32, grid=(1,1)):
    (np.isscalar(n_rays) and 0 < int(n_rays)) or _raise(ValueError())
    if grid != (1,1):
        raise NotImplementedError(grid)

    n_rays = int(n_rays)
    a = a.astype(np.uint16,copy=False)
    dst = np.empty(a.shape+(n_rays,),np.float32)

    for i in range(a.shape[0]):
        for j in range(a.shape[1]):
            value = a[i,j]
            if value == 0:
                dst[i,j] = 0
            else:
                st_rays = np.float32((2*np.pi) / n_rays)
                for k in range(n_rays):
                    phi = np.float32(k*st_rays)
                    dy = np.cos(phi)
                    dx = np.sin(phi)
                    x, y = np.float32(0), np.float32(0)
                    while True:
                        x += dx
                        y += dy
                        ii = int(round(i+x))
                        jj = int(round(j+y))
                        if (ii < 0 or ii >= a.shape[0] or
                            jj < 0 or jj >= a.shape[1] or
                            value != a[ii,jj]):
                            # small correction as we overshoot the boundary
                            t_corr = 1-.5/max(np.abs(dx),np.abs(dy))
                            x -= t_corr*dx
                            y -= t_corr*dy
                            dist = np.sqrt(x**2+y**2)
                            dst[i,j,k] = dist
                            break
    return dst


def star_dist(a, n_rays=32, grid=(1,1), mode='cpp'):
    """'a' assumed to be a label image with integer values that encode object ids. id 0 denotes background."""

    n_rays >= 3 or _raise(ValueError("need 'n_rays' >= 3"))

    if mode == 'python':
        return _py_star_dist(a, n_rays, grid=grid)
    elif mode == 'cpp':
        return _cpp_star_dist(a, n_rays, grid=grid)
    elif mode == 'opencl':
        return _ocl_star_dist(a, n_rays, grid=grid)
    else:
        _raise(ValueError("Unknown mode %s" % mode))


def _dist_to_coord_old(rhos, grid=(1,1)):
    """convert from polar to cartesian coordinates for a single image (3-D array) or multiple images (4-D array)"""

    grid = _normalize_grid(grid,2)
    is_single_image = rhos.ndim == 3
    if is_single_image:
        rhos = np.expand_dims(rhos,0)
    assert rhos.ndim == 4

    n_images,h,w,n_rays = rhos.shape
    coord = np.empty((n_images,h,w,2,n_rays),dtype=rhos.dtype)

    start = np.indices((h,w))
    for i in range(2):
        coord[...,i,:] = grid[i] * np.broadcast_to(start[i].reshape(1,h,w,1), (n_images,h,w,n_rays))

    phis = ray_angles(n_rays).reshape(1,1,1,n_rays)

    coord[...,0,:] += rhos * np.sin(phis) # row coordinate
    coord[...,1,:] += rhos * np.cos(phis) # col coordinate

    return coord[0] if is_single_image else coord


def _polygons_to_label_old(coord, prob, points, shape=None, thr=-np.inf):
    sh = coord.shape[:2] if shape is None else shape
    lbl = np.zeros(sh,np.int32)
    # sort points with increasing probability
    ind = np.argsort([ prob[p[0],p[1]] for p in points ])
    points = points[ind]

    i = 1
    for p in points:
        if prob[p[0],p[1]] < thr:
            continue
        rr,cc = polygon(coord[p[0],p[1],0], coord[p[0],p[1],1], sh)
        lbl[rr,cc] = i
        i += 1

    return lbl


def dist_to_coord(dist, points, scale_dist=(1,1)):
    """convert from polar to cartesian coordinates for a list of distances and center points
    dist.shape = (n_polys, n_rays)
    points.shape = (n_polys, 2)
    len(scale_dist) = 2
    return coord.shape = (n_polys,2,n_rays)
    """
    dist = np.asarray(dist)
    points = np.asarray(points)
    assert dist.ndim==2 and points.ndim==2 and len(dist)==len(points) \
        and points.shape[1]==2 and len(scale_dist)==2
    n_rays = dist.shape[1]
    phis = ray_angles(n_rays)
    coord = (dist[:,np.newaxis]*np.array([np.sin(phis),np.cos(phis)])).astype(np.float32)
    coord *= np.asarray(scale_dist).reshape(1,2,1)
    coord += points[...,np.newaxis]
    return coord


def polygons_to_label_coord(coord, shape, labels=None):
    """renders polygons to image of given shape

    coord.shape = (n_polys, 2, n_rays)
    """
    coord = np.asarray(coord)
    if labels is None: labels = np.arange(len(coord))

    _check_label_array(labels, "labels")
    assert coord.ndim==3 and coord.shape[1]==2 and len(coord)==len(labels)

    lbl = np.zeros(shape,np.int32)

    for i,c in zip(labels,coord):
        rr,cc = polygon(*c, shape)
        lbl[rr,cc] = i+1

    return lbl


def polygons_to_label(dist, points, shape, prob=None, thr=-np.inf, scale_dist=(1,1)):
    """converts distances and center points to label image

    dist.shape = (n_polys, n_rays)
    points.shape = (n_polys, 2)

    label ids will be consecutive and adhere to the order given
    """
    dist = np.asarray(dist)
    points = np.asarray(points)
    prob = np.inf*np.ones(len(points)) if prob is None else np.asarray(prob)

    assert dist.ndim==2 and points.ndim==2 and len(dist)==len(points)
    assert len(points)==len(prob) and points.shape[1]==2 and prob.ndim==1

    n_rays = dist.shape[1]

    ind = prob>thr
    points = points[ind]
    dist = dist[ind]
    prob = prob[ind]

    ind = np.argsort(prob, kind='stable')
    points = points[ind]
    dist = dist[ind]

    coord = dist_to_coord(dist, points, scale_dist=scale_dist)

    return polygons_to_label_coord(coord, shape=shape, labels=ind)


def relabel_image_stardist(lbl, n_rays, **kwargs):
    """relabel each label region in `lbl` with its star representation"""
    _check_label_array(lbl, "lbl")
    if not lbl.ndim==2:
        raise ValueError("lbl image should be 2 dimensional")
    dist = star_dist(lbl, n_rays, **kwargs)
    points = np.array(tuple(np.array(r.centroid).astype(int) for r in regionprops(lbl)))
    dist = dist[tuple(points.T)]
    return polygons_to_label(dist, points, shape=lbl.shape)


def ray_angles(n_rays=32):
    return np.linspace(0,2*np.pi,n_rays,endpoint=False)
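A small usage sketch of the geometry helpers above (not part of the upload): it encodes a toy label image as 32 radial distances and re-renders its star-convex approximation. mode='python' is used so the compiled extension isn't required at call time (note that importing geom2d still needs the stardist package installed for c_star_dist); it also assumes the repo root is on sys.path.

import numpy as np
from stardist_pkg.geometry import star_dist, relabel_image_stardist

lbl = np.zeros((64, 64), np.uint16)
lbl[16:48, 16:48] = 1                                            # one square object
dist = star_dist(lbl, n_rays=32, mode='python')                  # per-pixel distances, shape (64, 64, 32)
approx = relabel_image_stardist(lbl, n_rays=32, mode='python')   # star-convex re-rendering of each object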
stardist_pkg/kernels/stardist2d.cl
ADDED
@@ -0,0 +1,51 @@
#ifndef M_PI
#define M_PI 3.141592653589793
#endif

__constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

inline float2 pol2cart(const float rho, const float phi) {
    const float x = rho * cos(phi);
    const float y = rho * sin(phi);
    return (float2)(x,y);
}

__kernel void star_dist(__global float* dst, read_only image2d_t src, const int grid_y, const int grid_x) {

    const int i = get_global_id(0), j = get_global_id(1);
    const int Nx = get_global_size(0), Ny = get_global_size(1);
    const float2 grid = (float2)(grid_x, grid_y);

    const float2 origin = (float2)(i,j) * grid;
    const int value = read_imageui(src,sampler,origin).x;

    if (value == 0) {
        // background pixel -> nothing to do, write all zeros
        for (int k = 0; k < N_RAYS; k++) {
            dst[k + i*N_RAYS + j*N_RAYS*Nx] = 0;
        }
    } else {
        float st_rays = (2*M_PI) / N_RAYS; // step size for ray angles
        // for all rays
        for (int k = 0; k < N_RAYS; k++) {
            const float phi = k*st_rays; // current ray angle phi
            const float2 dir = pol2cart(1,phi); // small vector in direction of ray
            float2 offset = 0; // offset vector to be added to origin
            // find radius that leaves current object
            while (1) {
                offset += dir;
                const int offset_value = read_imageui(src,sampler,round(origin+offset)).x;
                if (offset_value != value) {
                    // small correction as we overshoot the boundary
                    const float t_corr = .5f/fmax(fabs(dir.x),fabs(dir.y));
                    offset += (t_corr-1.f)*dir;

                    const float dist = sqrt(offset.x*offset.x + offset.y*offset.y);
                    dst[k + i*N_RAYS + j*N_RAYS*Nx] = dist;
                    break;
                }
            }
        }
    }

}
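For reference, a sketch (not part of the upload) of how the kernel's flat output index lines up with the buffer that _ocl_star_dist in geom2d.py allocates: since the kernel is launched on res_shape[::-1], i is the x/column id and j the y/row id (so Nx equals the number of columns), and k + i*N_RAYS + j*N_RAYS*Nx addresses element [j, i, k] of the (rows, cols, n_rays) array.

import numpy as np

rows, cols, n_rays = 4, 5, 8
buf = np.arange(rows * cols * n_rays)     # stand-in for the flat device buffer
j, i, k = 2, 3, 5                         # arbitrary row, column, ray indices
# the kernel's flat index and the host-side 3D view agree
assert buf[k + i*n_rays + j*n_rays*cols] == buf.reshape(rows, cols, n_rays)[j, i, k]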
stardist_pkg/kernels/stardist3d.cl
ADDED
@@ -0,0 +1,63 @@
#ifndef M_PI
#define M_PI 3.141592653589793
#endif

__constant sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

inline int round_to_int(float r) {
    return (int)rint(r);
}


__kernel void stardist3d(read_only image3d_t lbl, __constant float * rays, __global float* dist, const int grid_z, const int grid_y, const int grid_x) {

    const int i = get_global_id(0);
    const int j = get_global_id(1);
    const int k = get_global_id(2);

    const int Nx = get_global_size(0);
    const int Ny = get_global_size(1);
    const int Nz = get_global_size(2);

    const float4 grid = (float4)(grid_x, grid_y, grid_z, 1);
    const float4 origin = (float4)(i,j,k,0) * grid;
    const int value = read_imageui(lbl,sampler,origin).x;

    if (value == 0) {
        // background pixel -> nothing to do, write all zeros
        for (int m = 0; m < N_RAYS; m++) {
            dist[m + i*N_RAYS + j*N_RAYS*Nx+k*N_RAYS*Nx*Ny] = 0;
        }

    }
    else {
        for (int m = 0; m < N_RAYS; m++) {

            const float4 dx = (float4)(rays[3*m+2],rays[3*m+1],rays[3*m],0);
            // if ((i==Nx/2)&&(j==Ny/2)&(k==Nz/2)){
            //     printf("kernel: %.2f %.2f %.2f \n",dx.x,dx.y,dx.z);
            // }
            float4 x = (float4)(0,0,0,0);

            // move along ray
            while (1) {
                x += dx;
                // if ((i==10)&&(j==10)&(k==10)){
                //     printf("kernel run: %.2f %.2f %.2f value %d \n",x.x,x.y,x.z, read_imageui(lbl,sampler,origin+x).x);
                // }

                // to make it equivalent to the cpp version...
                const float4 x_int = (float4)(round_to_int(x.x),
                                              round_to_int(x.y),
                                              round_to_int(x.z),
                                              0);

                if (value != read_imageui(lbl,sampler,origin+x_int).x){

                    dist[m + i*N_RAYS + j*N_RAYS*Nx+k*N_RAYS*Nx*Ny] = length(x_int);
                    break;
                }
            }
        }
    }
}
stardist_pkg/matching.py
ADDED
@@ -0,0 +1,483 @@
import numpy as np

from numba import jit
from tqdm import tqdm
from scipy.optimize import linear_sum_assignment
from skimage.measure import regionprops
from collections import namedtuple
from csbdeep.utils import _raise

matching_criteria = dict()


def label_are_sequential(y):
    """ returns true if y has only sequential labels from 1... """
    labels = np.unique(y)
    return (set(labels)-{0}) == set(range(1,1+labels.max()))


def is_array_of_integers(y):
    return isinstance(y,np.ndarray) and np.issubdtype(y.dtype, np.integer)


def _check_label_array(y, name=None, check_sequential=False):
    err = ValueError("{label} must be an array of {integers}.".format(
        label = 'labels' if name is None else name,
        integers = ('sequential ' if check_sequential else '') + 'non-negative integers',
    ))
    is_array_of_integers(y) or _raise(err)
    if len(y) == 0:
        return True
    if check_sequential:
        label_are_sequential(y) or _raise(err)
    else:
        y.min() >= 0 or _raise(err)
    return True


def label_overlap(x, y, check=True):
    if check:
        _check_label_array(x,'x',True)
        _check_label_array(y,'y',True)
        x.shape == y.shape or _raise(ValueError("x and y must have the same shape"))
    return _label_overlap(x, y)

@jit(nopython=True)
def _label_overlap(x, y):
    x = x.ravel()
    y = y.ravel()
    overlap = np.zeros((1+x.max(),1+y.max()), dtype=np.uint)
    for i in range(len(x)):
        overlap[x[i],y[i]] += 1
    return overlap


def _safe_divide(x,y, eps=1e-10):
    """computes a safe divide which returns 0 if y is zero"""
    if np.isscalar(x) and np.isscalar(y):
        return x/y if np.abs(y)>eps else 0.0
    else:
        out = np.zeros(np.broadcast(x,y).shape, np.float32)
        np.divide(x,y, out=out, where=np.abs(y)>eps)
        return out


def intersection_over_union(overlap):
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    return _safe_divide(overlap, (n_pixels_pred + n_pixels_true - overlap))

matching_criteria['iou'] = intersection_over_union


def intersection_over_true(overlap):
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    return _safe_divide(overlap, n_pixels_true)

matching_criteria['iot'] = intersection_over_true


def intersection_over_pred(overlap):
    _check_label_array(overlap,'overlap')
    if np.sum(overlap) == 0:
        return overlap
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    return _safe_divide(overlap, n_pixels_pred)

matching_criteria['iop'] = intersection_over_pred


def precision(tp,fp,fn):
    return tp/(tp+fp) if tp > 0 else 0
def recall(tp,fp,fn):
    return tp/(tp+fn) if tp > 0 else 0
def accuracy(tp,fp,fn):
    # also known as "average precision" (?)
    # -> https://www.kaggle.com/c/data-science-bowl-2018#evaluation
    return tp/(tp+fp+fn) if tp > 0 else 0
def f1(tp,fp,fn):
    # also known as "dice coefficient"
    return (2*tp)/(2*tp+fp+fn) if tp > 0 else 0


def matching(y_true, y_pred, thresh=0.5, criterion='iou', report_matches=False):
    """Calculate detection/instance segmentation metrics between ground truth and predicted label images.

    Currently, the following metrics are implemented:

    'fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'

    Corresponding objects of y_true and y_pred are counted as true positives (tp), false positives (fp), and false negatives (fn)
    if their intersection over union (IoU) >= thresh (for criterion='iou', which can be changed)

    * mean_matched_score is the mean IoU of matched true positives

    * mean_true_score is the mean IoU of matched true positives, but normalized by the total number of GT objects

    * panoptic_quality defined as in Eq. 1 of Kirillov et al. "Panoptic Segmentation", CVPR 2019

    Parameters
    ----------
    y_true: ndarray
        ground truth label image (integer valued)
    y_pred: ndarray
        predicted label image (integer valued)
    thresh: float
        threshold for matching criterion (default 0.5)
    criterion: string
        matching criterion (default IoU)
    report_matches: bool
        if True, additionally calculate matched_pairs and matched_scores (note that this also returns gt-pred pairs whose scores are below 'thresh')

    Returns
    -------
    Matching object with different metrics as attributes

    Examples
    --------
    >>> y_true = np.zeros((100,100), np.uint16)
    >>> y_true[10:20,10:20] = 1
    >>> y_pred = np.roll(y_true,5,axis = 0)

    >>> stats = matching(y_true, y_pred)
    >>> print(stats)
    Matching(criterion='iou', thresh=0.5, fp=1, tp=0, fn=1, precision=0, recall=0, accuracy=0, f1=0, n_true=1, n_pred=1, mean_true_score=0.0, mean_matched_score=0.0, panoptic_quality=0.0)

    """
    _check_label_array(y_true,'y_true')
    _check_label_array(y_pred,'y_pred')
    y_true.shape == y_pred.shape or _raise(ValueError("y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes".format(y_true=y_true, y_pred=y_pred)))
    criterion in matching_criteria or _raise(ValueError("Matching criterion '%s' not supported." % criterion))
    if thresh is None: thresh = 0
    thresh = float(thresh) if np.isscalar(thresh) else map(float,thresh)

    y_true, _, map_rev_true = relabel_sequential(y_true)
    y_pred, _, map_rev_pred = relabel_sequential(y_pred)

    overlap = label_overlap(y_true, y_pred, check=False)
    scores = matching_criteria[criterion](overlap)
    assert 0 <= np.min(scores) <= np.max(scores) <= 1

    # ignoring background
    scores = scores[1:,1:]
    n_true, n_pred = scores.shape
    n_matched = min(n_true, n_pred)

    def _single(thr):
        # not_trivial = n_matched > 0 and np.any(scores >= thr)
        not_trivial = n_matched > 0
        if not_trivial:
            # compute optimal matching with scores as tie-breaker
            costs = -(scores >= thr).astype(float) - scores / (2*n_matched)
            true_ind, pred_ind = linear_sum_assignment(costs)
            assert n_matched == len(true_ind) == len(pred_ind)
            match_ok = scores[true_ind,pred_ind] >= thr
            tp = np.count_nonzero(match_ok)
        else:
            tp = 0
        fp = n_pred - tp
        fn = n_true - tp
        # assert tp+fp == n_pred
        # assert tp+fn == n_true

        # the score sum over all matched objects (tp)
        sum_matched_score = np.sum(scores[true_ind,pred_ind][match_ok]) if not_trivial else 0.0

        # the score average over all matched objects (tp)
        mean_matched_score = _safe_divide(sum_matched_score, tp)
        # the score average over all gt/true objects
        mean_true_score = _safe_divide(sum_matched_score, n_true)
        panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)

        stats_dict = dict (
            criterion = criterion,
            thresh = thr,
            fp = fp,
            tp = tp,
            fn = fn,
            precision = precision(tp,fp,fn),
            recall = recall(tp,fp,fn),
            accuracy = accuracy(tp,fp,fn),
            f1 = f1(tp,fp,fn),
            n_true = n_true,
            n_pred = n_pred,
            mean_true_score = mean_true_score,
            mean_matched_score = mean_matched_score,
            panoptic_quality = panoptic_quality,
        )
        if bool(report_matches):
            if not_trivial:
                stats_dict.update (
                    # int() to be json serializable
                    matched_pairs = tuple((int(map_rev_true[i]),int(map_rev_pred[j])) for i,j in zip(1+true_ind,1+pred_ind)),
                    matched_scores = tuple(scores[true_ind,pred_ind]),
                    matched_tps = tuple(map(int,np.flatnonzero(match_ok))),
                )
            else:
                stats_dict.update (
                    matched_pairs = (),
                    matched_scores = (),
                    matched_tps = (),
                )
        return namedtuple('Matching',stats_dict.keys())(*stats_dict.values())

    return _single(thresh) if np.isscalar(thresh) else tuple(map(_single,thresh))


def matching_dataset(y_true, y_pred, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):
    """matching metrics for list of images, see `stardist.matching.matching`
    """
    len(y_true) == len(y_pred) or _raise(ValueError("y_true and y_pred must have the same length."))
    return matching_dataset_lazy (
        tuple(zip(y_true,y_pred)), thresh=thresh, criterion=criterion, by_image=by_image, show_progress=show_progress, parallel=parallel,
    )


def matching_dataset_lazy(y_gen, thresh=0.5, criterion='iou', by_image=False, show_progress=True, parallel=False):

    expected_keys = set(('fp', 'tp', 'fn', 'precision', 'recall', 'accuracy', 'f1', 'criterion', 'thresh', 'n_true', 'n_pred', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'))

    single_thresh = False
    if np.isscalar(thresh):
        single_thresh = True
        thresh = (thresh,)

    tqdm_kwargs = {}
    tqdm_kwargs['disable'] = not bool(show_progress)
    if int(show_progress) > 1:
        tqdm_kwargs['total'] = int(show_progress)

    # compute matching stats for every pair of label images
    if parallel:
        from concurrent.futures import ThreadPoolExecutor
        fn = lambda pair: matching(*pair, thresh=thresh, criterion=criterion, report_matches=False)
        with ThreadPoolExecutor() as pool:
            stats_all = tuple(pool.map(fn, tqdm(y_gen,**tqdm_kwargs)))
    else:
        stats_all = tuple (
            matching(y_t, y_p, thresh=thresh, criterion=criterion, report_matches=False)
            for y_t,y_p in tqdm(y_gen,**tqdm_kwargs)
        )

    # accumulate results over all images for each threshold separately
    n_images, n_threshs = len(stats_all), len(thresh)
    accumulate = [{} for _ in range(n_threshs)]
    for stats in stats_all:
        for i,s in enumerate(stats):
            acc = accumulate[i]
            for k,v in s._asdict().items():
                if k == 'mean_true_score' and not bool(by_image):
                    # convert mean_true_score to "sum_matched_score"
                    acc[k] = acc.setdefault(k,0) + v * s.n_true
                else:
                    try:
                        acc[k] = acc.setdefault(k,0) + v
                    except TypeError:
                        pass

    # normalize/compute 'precision', 'recall', 'accuracy', 'f1'
    for thr,acc in zip(thresh,accumulate):
        set(acc.keys()) == expected_keys or _raise(ValueError("unexpected keys"))
        acc['criterion'] = criterion
        acc['thresh'] = thr
        acc['by_image'] = bool(by_image)
        if bool(by_image):
            for k in ('precision', 'recall', 'accuracy', 'f1', 'mean_true_score', 'mean_matched_score', 'panoptic_quality'):
                acc[k] /= n_images
        else:
            tp, fp, fn, n_true = acc['tp'], acc['fp'], acc['fn'], acc['n_true']
            sum_matched_score = acc['mean_true_score']

            mean_matched_score = _safe_divide(sum_matched_score, tp)
            mean_true_score = _safe_divide(sum_matched_score, n_true)
            panoptic_quality = _safe_divide(sum_matched_score, tp+fp/2+fn/2)

            acc.update(
                precision = precision(tp,fp,fn),
                recall = recall(tp,fp,fn),
                accuracy = accuracy(tp,fp,fn),
                f1 = f1(tp,fp,fn),
                mean_true_score = mean_true_score,
                mean_matched_score = mean_matched_score,
                panoptic_quality = panoptic_quality,
            )

    accumulate = tuple(namedtuple('DatasetMatching',acc.keys())(*acc.values()) for acc in accumulate)
    return accumulate[0] if single_thresh else accumulate


# copied from scikit-image master for now (remove when part of a release)
def relabel_sequential(label_field, offset=1):
    """Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}.

    This function also returns the forward map (mapping the original labels to
    the reduced labels) and the inverse map (mapping the reduced labels back
    to the original ones).

    Parameters
    ----------
    label_field : numpy array of int, arbitrary shape
        An array of labels, which must be non-negative integers.
    offset : int, optional
        The return labels will start at `offset`, which should be
        strictly positive.

    Returns
    -------
    relabeled : numpy array of int, same shape as `label_field`
        The input label field with labels mapped to
        {offset, ..., number_of_labels + offset - 1}.
        The data type will be the same as `label_field`, except when
        offset + number_of_labels causes overflow of the current data type.
    forward_map : numpy array of int, shape ``(label_field.max() + 1,)``
        The map from the original label space to the returned label
        space. Can be used to re-apply the same mapping. See examples
        for usage. The data type will be the same as `relabeled`.
    inverse_map : 1D numpy array of int, of length offset + number of labels
        The map from the new label space to the original space. This
        can be used to reconstruct the original label field from the
        relabeled one. The data type will be the same as `relabeled`.

    Notes
    -----
    The label 0 is assumed to denote the background and is never remapped.

    The forward map can be extremely big for some inputs, since its
    length is given by the maximum of the label field. However, in most
    situations, ``label_field.max()`` is much smaller than
    ``label_field.size``, and in these cases the forward map is
    guaranteed to be smaller than either the input or output images.

    Examples
    --------
    >>> from skimage.segmentation import relabel_sequential
    >>> label_field = np.array([1, 1, 5, 5, 8, 99, 42])
    >>> relab, fw, inv = relabel_sequential(label_field)
    >>> relab
    array([1, 1, 2, 2, 3, 5, 4])
    >>> fw
    array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
           0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5])
    >>> inv
    array([ 0, 1, 5, 8, 42, 99])
    >>> (fw[label_field] == relab).all()
    True
    >>> (inv[relab] == label_field).all()
    True
    >>> relab, fw, inv = relabel_sequential(label_field, offset=5)
    >>> relab
    array([5, 5, 6, 6, 7, 9, 8])
    """
    offset = int(offset)
    if offset <= 0:
        raise ValueError("Offset must be strictly positive.")
    if np.min(label_field) < 0:
        raise ValueError("Cannot relabel array that contains negative values.")
    max_label = int(label_field.max()) # Ensure max_label is an integer
    if not np.issubdtype(label_field.dtype, np.integer):
        new_type = np.min_scalar_type(max_label)
        label_field = label_field.astype(new_type)
    labels = np.unique(label_field)
    labels0 = labels[labels != 0]
    new_max_label = offset - 1 + len(labels0)
    new_labels0 = np.arange(offset, new_max_label + 1)
    output_type = label_field.dtype
    required_type = np.min_scalar_type(new_max_label)
    if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize:
        output_type = required_type
    forward_map = np.zeros(max_label + 1, dtype=output_type)
    forward_map[labels0] = new_labels0
    inverse_map = np.zeros(new_max_label + 1, dtype=output_type)
    inverse_map[offset:] = labels0
    relabeled = forward_map[label_field]
    return relabeled, forward_map, inverse_map


def group_matching_labels(ys, thresh=1e-10, criterion='iou'):
    """
    Group matching objects (i.e. assign the same label id) in a
    list of label images (e.g. consecutive frames of a time-lapse).

    Uses function `matching` (with provided `criterion` and `thresh`) to
    iteratively/greedily match and group objects/labels in consecutive images of `ys`.
    To that end, matching objects are grouped together by assigning the same label id,
    whereas unmatched objects are assigned a new label id.
    At the end of this process, each label group will have been assigned a unique id.

    Note that the label images `ys` will not be modified. Instead, they will initially
    be duplicated and converted to data type `np.int32` before objects are grouped and the result
    is returned. (Note that `np.int32` limits the number of label groups to at most 2147483647.)

    Example
    -------
    import numpy as np
    from stardist.data import test_image_nuclei_2d
    from stardist.matching import group_matching_labels

    _y = test_image_nuclei_2d(return_mask=True)[1]
    labels = np.stack([_y, 2*np.roll(_y,10)], axis=0)

    labels_new = group_matching_labels(labels)

    Parameters
    ----------
    ys : np.ndarray or list/tuple of np.ndarray
        list/array of integer labels (2D or 3D)

    """
    # check 'ys' without making a copy
    len(ys) > 1 or _raise(ValueError("'ys' must have 2 or more entries"))
    if isinstance(ys, np.ndarray):
        _check_label_array(ys, 'ys')
        ys.ndim > 1 or _raise(ValueError("'ys' must be at least 2-dimensional"))
        ys_grouped = np.empty_like(ys, dtype=np.int32)
    else:
        all(_check_label_array(y, 'ys') for y in ys) or _raise(ValueError("'ys' must be a list of label images"))
        all(y.shape==ys[0].shape for y in ys) or _raise(ValueError("all label images must have the same shape"))
        ys_grouped = np.empty((len(ys),)+ys[0].shape, dtype=np.int32)

    def _match_single(y_prev, y, next_id):
        y = y.astype(np.int32, copy=False)
        res = matching(y_prev, y, report_matches=True, thresh=thresh, criterion=criterion)
        # relabel dict (for matching labels) that maps label ids from y -> y_prev
        relabel = dict(reversed(res.matched_pairs[i]) for i in res.matched_tps)
        y_grouped = np.zeros_like(y)
        for r in regionprops(y):
            m = (y[r.slice] == r.label)
            if r.label in relabel:
                y_grouped[r.slice][m] = relabel[r.label]
            else:
                y_grouped[r.slice][m] = next_id
                next_id += 1
        return y_grouped, next_id

    ys_grouped[0] = ys[0]
    next_id = ys_grouped[0].max() + 1
    for i in range(len(ys)-1):
        ys_grouped[i+1], next_id = _match_single(ys_grouped[i], ys[i+1], next_id)
    return ys_grouped


def _shuffle_labels(y):
    _check_label_array(y, 'y')
    y2 = np.zeros_like(y)
    ids = tuple(set(np.unique(y)) - {0})
    relabel = dict(zip(ids,np.random.permutation(ids)))
    for r in regionprops(y):
        m = (y[r.slice] == r.label)
        y2[r.slice][m] = relabel[r.label]
    return y2
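A minimal scoring sketch for the module above (not part of the upload), assuming the repo root is on sys.path so that stardist_pkg is importable; the F1 it reports is the instance-level metric listed in this model card.

import numpy as np
from stardist_pkg.matching import matching

y_true = np.zeros((100, 100), np.uint16)
y_true[10:30, 10:30] = 1
y_pred = np.roll(y_true, 3, axis=0)                    # a slightly shifted prediction

stats = matching(y_true, y_pred, thresh=0.5)           # single IoU threshold
print(stats.f1, stats.panoptic_quality)                # instance-level F1 and panoptic quality
multi = matching(y_true, y_pred, thresh=[0.5, 0.75])   # tuple of Matching results, one per threshold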
stardist_pkg/models/__init__.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import, print_function

from .model2d import Config2D, StarDist2D, StarDistData2D

from csbdeep.utils import backend_channels_last
from csbdeep.utils.tf import keras_import
K = keras_import('backend')
if not backend_channels_last():
    raise NotImplementedError(
        "Keras is configured to use the '%s' image data format, which is currently not supported. "
        "Please change it to use 'channels_last' instead: "
        "https://keras.io/getting-started/faq/#where-is-the-keras-configuration-file-stored" % K.image_data_format()
    )
del backend_channels_last, K

from csbdeep.models import register_model, register_aliases, clear_models_and_aliases
# register pre-trained models and aliases (TODO: replace with updatable solution)
# only the 2D models are bundled in this stripped-down package (no StarDist3D),
# so clear only StarDist2D here
clear_models_and_aliases(StarDist2D)
register_model(StarDist2D, '2D_versatile_fluo', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_fluo.zip', '8db40dacb5a1311b8d2c447ad934fb8a')
register_model(StarDist2D, '2D_versatile_he', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_versatile_he.zip', 'bf34cb3c0e5b3435971e18d66778a4ec')
register_model(StarDist2D, '2D_paper_dsb2018', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_paper_dsb2018.zip', '6287bf283f85c058ec3e7094b41039b5')
register_model(StarDist2D, '2D_demo', 'https://github.com/stardist/stardist-models/releases/download/v0.1/python_2D_demo.zip', '31f70402f58c50dd231ec31b4375ea2c')

register_aliases(StarDist2D, '2D_paper_dsb2018', 'DSB 2018 (from StarDist 2D paper)')
register_aliases(StarDist2D, '2D_versatile_fluo', 'Versatile (fluorescent nuclei)')
register_aliases(StarDist2D, '2D_versatile_he', 'Versatile (H&E nuclei)')
del register_model, register_aliases, clear_models_and_aliases
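As a sketch (not part of the upload): the registrations above hook the bundled 2D models into csbdeep's pretrained-model machinery, so they can be fetched by name or alias, assuming network access on first use.

from stardist_pkg.models import StarDist2D

# downloads and caches the registered zip on first use
model = StarDist2D.from_pretrained('2D_versatile_fluo')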
stardist_pkg/models/base.py
ADDED
@@ -0,0 +1,1196 @@
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
import sys
import warnings
import math
from tqdm import tqdm
from collections import namedtuple
from pathlib import Path
import threading
import functools
import scipy.ndimage as ndi
import numbers

from csbdeep.models.base_model import BaseModel
from csbdeep.utils.tf import export_SavedModel, keras_import, IS_TF_1, CARETensorBoard

import tensorflow as tf
K = keras_import('backend')
Sequence = keras_import('utils', 'Sequence')
Adam = keras_import('optimizers', 'Adam')
ReduceLROnPlateau, TensorBoard = keras_import('callbacks', 'ReduceLROnPlateau', 'TensorBoard')

from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict, load_json, save_json
from csbdeep.internals.predict import tile_iterator, total_n_tiles
from csbdeep.internals.train import RollingSequence
from csbdeep.data import Resizer

from ..sample_patches import get_valid_inds
from ..nms import _ind_prob_thresh
from ..utils import _is_power_of_2, _is_floatarray, optimize_threshold

# TODO: helper function to check if receptive field of cnn is sufficient for object sizes in GT

def generic_masked_loss(mask, loss, weights=1, norm_by_mask=True, reg_weight=0, reg_penalty=K.abs):
    def _loss(y_true, y_pred):
        actual_loss = K.mean(mask * weights * loss(y_true, y_pred), axis=-1)
        norm_mask = (K.mean(mask) + K.epsilon()) if norm_by_mask else 1
        if reg_weight > 0:
            reg_loss = K.mean((1-mask) * reg_penalty(y_pred), axis=-1)
            return actual_loss / norm_mask + reg_weight * reg_loss
        else:
            return actual_loss / norm_mask
    return _loss

def masked_loss(mask, penalty, reg_weight, norm_by_mask):
    loss = lambda y_true, y_pred: penalty(y_true - y_pred)
    return generic_masked_loss(mask, loss, reg_weight=reg_weight, norm_by_mask=norm_by_mask)

# TODO: should we use norm_by_mask=True in the loss or only in a metric?
# previous 2D behavior was norm_by_mask=False
# same question for reg_weight? use 1e-4 (as in 3D) or 0 (as in 2D)?

def masked_loss_mae(mask, reg_weight=0, norm_by_mask=True):
    return masked_loss(mask, K.abs, reg_weight=reg_weight, norm_by_mask=norm_by_mask)

def masked_loss_mse(mask, reg_weight=0, norm_by_mask=True):
    return masked_loss(mask, K.square, reg_weight=reg_weight, norm_by_mask=norm_by_mask)

def masked_metric_mae(mask):
    def relevant_mae(y_true, y_pred):
        return masked_loss(mask, K.abs, reg_weight=0, norm_by_mask=True)(y_true, y_pred)
    return relevant_mae

def masked_metric_mse(mask):
    def relevant_mse(y_true, y_pred):
        return masked_loss(mask, K.square, reg_weight=0, norm_by_mask=True)(y_true, y_pred)
    return relevant_mse

def kld(y_true, y_pred):
    y_true = K.clip(y_true, K.epsilon(), 1)
    y_pred = K.clip(y_pred, K.epsilon(), 1)
    return K.mean(K.binary_crossentropy(y_true, y_pred) - K.binary_crossentropy(y_true, y_true), axis=-1)


def masked_loss_iou(mask, reg_weight=0, norm_by_mask=True):
    def iou_loss(y_true, y_pred):
        axis = -1 if backend_channels_last() else 1
        # y_pred can be negative (since not constrained) -> 'inter' can be very large for y_pred << 0
        # - clipping y_pred values at 0 can lead to vanishing gradients
        # - 'K.sign(y_pred)' term fixes issue by enforcing that y_pred values >= 0 always lead to larger 'inter' (lower loss)
        inter = K.mean(K.sign(y_pred)*K.square(K.minimum(y_true,y_pred)), axis=axis)
        union = K.mean(K.square(K.maximum(y_true,y_pred)), axis=axis)
        iou = inter/(union+K.epsilon())
        iou = K.expand_dims(iou,axis)
        loss = 1. - iou # + 0.005*K.abs(y_true-y_pred)
        return loss
    return generic_masked_loss(mask, iou_loss, reg_weight=reg_weight, norm_by_mask=norm_by_mask)

def masked_metric_iou(mask, reg_weight=0, norm_by_mask=True):
    def iou_metric(y_true, y_pred):
        axis = -1 if backend_channels_last() else 1
        y_pred = K.maximum(0., y_pred)
        inter = K.mean(K.square(K.minimum(y_true,y_pred)), axis=axis)
        union = K.mean(K.square(K.maximum(y_true,y_pred)), axis=axis)
        iou = inter/(union+K.epsilon())
        loss = K.expand_dims(iou,axis)
        return loss
    return generic_masked_loss(mask, iou_metric, reg_weight=reg_weight, norm_by_mask=norm_by_mask)


def weighted_categorical_crossentropy(weights, ndim):
    """ ndim = (2,3) """

    axis = -1 if backend_channels_last() else 1
    shape = [1]*(ndim+2)
    shape[axis] = len(weights)
    weights = np.broadcast_to(weights, shape)
    weights = K.constant(weights)

    def weighted_cce(y_true, y_pred):
        # ignore pixels that have y_true (prob_class) < 0
        mask = K.cast(y_true>=0, K.floatx())
        y_pred /= K.sum(y_pred+K.epsilon(), axis=axis, keepdims=True)
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
        loss = - K.sum(weights*mask*y_true*K.log(y_pred), axis=axis)
        return loss

    return weighted_cce
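
# Illustrative sketch (not part of the original file): evaluating one of the
# masked losses above directly. Assumes a TF2 backend in eager mode; all names
# in this helper are made up for the example.
def _example_masked_loss():
    mask   = K.constant([[1., 1., 0.]])   # 1 = annotated pixel, 0 = ignored
    y_true = K.constant([[0.5, 0.2, 0.9]])
    y_pred = K.constant([[0.4, 0.2, 0.0]])
    loss_fn = masked_loss_mae(mask)       # MAE restricted to masked pixels
    return loss_fn(y_true, y_pred)        # normalized by mean(mask) + epsilon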


class StarDistDataBase(RollingSequence):

    def __init__(self, X, Y, n_rays, grid, batch_size, patch_size, length,
                 n_classes=None, classes=None,
                 use_gpu=False, sample_ind_cache=True, maxfilter_patch_size=None, augmenter=None, foreground_prob=0):

        super().__init__(data_size=len(X), batch_size=batch_size, length=length, shuffle=True)

        if isinstance(X, (np.ndarray, tuple, list)):
            X = [x.astype(np.float32, copy=False) for x in X]

        # sanity checks
        len(X)==len(Y) and len(X)>0 or _raise(ValueError("X and Y can't be empty and must have same length"))

        if classes is None:
            # set classes to None for all images (i.e. defaults to every object instance assigned the same class)
            classes = (None,)*len(X)
        else:
            n_classes is not None or warnings.warn("Ignoring classes since n_classes is None")

        len(classes)==len(X) or _raise(ValueError("X and classes must have same length"))

        self.n_classes, self.classes = n_classes, classes

        nD = len(patch_size)
        assert nD in (2,3)
        x_ndim = X[0].ndim
        assert x_ndim in (nD,nD+1)

        if isinstance(X, (np.ndarray, tuple, list)) and \
           isinstance(Y, (np.ndarray, tuple, list)):
            all(y.ndim==nD and x.ndim==x_ndim and x.shape[:nD]==y.shape for x,y in zip(X,Y)) or _raise(ValueError("images and masks should have corresponding shapes/dimensions"))
            all(x.shape[:nD]>=tuple(patch_size) for x in X) or _raise(ValueError("Some images are too small for given patch_size {patch_size}".format(patch_size=patch_size)))

        if x_ndim == nD:
            self.n_channel = None
        else:
            self.n_channel = X[0].shape[-1]
            if isinstance(X, (np.ndarray, tuple, list)):
                assert all(x.shape[-1]==self.n_channel for x in X)

        assert 0 <= foreground_prob <= 1

        self.X, self.Y = X, Y
        # self.batch_size = batch_size
        self.n_rays = n_rays
        self.patch_size = patch_size
        self.ss_grid = (slice(None),) + tuple(slice(0, None, g) for g in grid)
        self.grid = tuple(grid)
        self.use_gpu = bool(use_gpu)
        if augmenter is None:
            augmenter = lambda *args: args
        callable(augmenter) or _raise(ValueError("augmenter must be None or callable"))
        self.augmenter = augmenter
        self.foreground_prob = foreground_prob

        if self.use_gpu:
            from gputools import max_filter
            self.max_filter = lambda y, patch_size: max_filter(y.astype(np.float32), patch_size)
        else:
            from scipy.ndimage.filters import maximum_filter
            self.max_filter = lambda y, patch_size: maximum_filter(y, patch_size, mode='constant')

        self.maxfilter_patch_size = maxfilter_patch_size if maxfilter_patch_size is not None else self.patch_size

        self.sample_ind_cache = sample_ind_cache
        self._ind_cache_fg = {}
        self._ind_cache_all = {}
        self.lock = threading.Lock()


    def get_valid_inds(self, k, foreground_prob=None):
        if foreground_prob is None:
            foreground_prob = self.foreground_prob
        foreground_only = np.random.uniform() < foreground_prob
        _ind_cache = self._ind_cache_fg if foreground_only else self._ind_cache_all
        if k in _ind_cache:
            inds = _ind_cache[k]
        else:
            patch_filter = (lambda y,p: self.max_filter(y, self.maxfilter_patch_size) > 0) if foreground_only else None
            inds = get_valid_inds(self.Y[k], self.patch_size, patch_filter=patch_filter)
            if self.sample_ind_cache:
                with self.lock:
                    _ind_cache[k] = inds
        if foreground_only and len(inds[0])==0:
            # no foreground pixels available
            return self.get_valid_inds(k, foreground_prob=0)
        return inds


    def channels_as_tuple(self, x):
        if self.n_channel is None:
            return (x,)
        else:
            return tuple(x[...,i] for i in range(self.n_channel))


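# Note (not part of the original file): subclasses such as StarDistData2D build
# on this base to sample training patches; e.g. with foreground_prob=0.9, about
# 90% of drawn patches are forced to contain labeled objects, enforced via the
# max_filter-based patch_filter in get_valid_inds above.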
class StarDistBase(BaseModel):

    def __init__(self, config, name=None, basedir='.'):
        super().__init__(config=config, name=name, basedir=basedir)
        threshs = dict(prob=None, nms=None)
        if basedir is not None:
            try:
                threshs = load_json(str(self.logdir / 'thresholds.json'))
                print("Loading thresholds from 'thresholds.json'.")
                if threshs.get('prob') is None or not (0 < threshs.get('prob') < 1):
                    print("- Invalid 'prob' threshold (%s), using default value." % str(threshs.get('prob')))
                    threshs['prob'] = None
                if threshs.get('nms') is None or not (0 < threshs.get('nms') < 1):
                    print("- Invalid 'nms' threshold (%s), using default value." % str(threshs.get('nms')))
                    threshs['nms'] = None
            except FileNotFoundError:
                if config is None and len(tuple(self.logdir.glob('*.h5'))) > 0:
                    print("Couldn't load thresholds from 'thresholds.json', using default values. "
                          "(Call 'optimize_thresholds' to change that.)")

        self.thresholds = dict(
            prob = 0.5 if threshs['prob'] is None else threshs['prob'],
            nms  = 0.4 if threshs['nms']  is None else threshs['nms'],
        )
        print("Using default values: prob_thresh={prob:g}, nms_thresh={nms:g}.".format(prob=self.thresholds.prob, nms=self.thresholds.nms))


    @property
    def thresholds(self):
        return self._thresholds

    def _is_multiclass(self):
        return (self.config.n_classes is not None)

    def _parse_classes_arg(self, classes, length):
        """ creates a proper classes tuple from different possible "classes" arguments in model.train()

        classes can be
           "auto"               -> all objects will be assigned to the first foreground class (unless n_classes is None)
           single integer       -> all objects will be assigned that class
           tuple, list, ndarray -> do nothing (needs to be of given length)

        returns a tuple of given length
        """
        if isinstance(classes, str):
            classes == "auto" or _raise(ValueError(f"classes = '{classes}': only 'auto' supported as string argument for classes"))
            if self.config.n_classes is None:
                classes = None
            elif self.config.n_classes == 1:
                classes = (1,)*length
            else:
                raise ValueError("using classes = 'auto' for n_classes > 1 not supported")
        elif isinstance(classes, (tuple, list, np.ndarray)):
            len(classes) == length or _raise(ValueError(f"len(classes) should be {length}!"))
        else:
            raise ValueError("classes should either be 'auto' or a list of scalars/label dicts")
        return classes

    @thresholds.setter
    def thresholds(self, d):
        self._thresholds = namedtuple('Thresholds',d.keys())(*d.values())


    def prepare_for_training(self, optimizer=None):
        """Prepare for neural network training.

        Compiles the model and creates
        `Keras Callbacks <https://keras.io/callbacks/>`_ to be used for training.

        Note that this method will be implicitly called once by :func:`train`
        (with default arguments) if not done so explicitly beforehand.

        Parameters
        ----------
        optimizer : obj or None
            Instance of a `Keras Optimizer <https://keras.io/optimizers/>`_ to be used for training.
            If ``None`` (default), uses ``Adam`` with the learning rate specified in ``config``.

        """
        if optimizer is None:
            optimizer = Adam(self.config.train_learning_rate)

        masked_dist_loss = {'mse': masked_loss_mse,
                            'mae': masked_loss_mae,
                            'iou': masked_loss_iou,
                            }[self.config.train_dist_loss]
        prob_loss = 'binary_crossentropy'


        def split_dist_true_mask(dist_true_mask):
            return tf.split(dist_true_mask, num_or_size_splits=[self.config.n_rays,-1], axis=-1)

        def dist_loss(dist_true_mask, dist_pred):
            dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
            return masked_dist_loss(dist_mask, reg_weight=self.config.train_background_reg)(dist_true, dist_pred)

        def dist_iou_metric(dist_true_mask, dist_pred):
            dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
            return masked_metric_iou(dist_mask, reg_weight=0)(dist_true, dist_pred)

        def relevant_mae(dist_true_mask, dist_pred):
            dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
            return masked_metric_mae(dist_mask)(dist_true, dist_pred)

        def relevant_mse(dist_true_mask, dist_pred):
            dist_true, dist_mask = split_dist_true_mask(dist_true_mask)
            return masked_metric_mse(dist_mask)(dist_true, dist_pred)


        if self._is_multiclass():
            prob_class_loss = weighted_categorical_crossentropy(self.config.train_class_weights, ndim=self.config.n_dim)
            loss = [prob_loss, dist_loss, prob_class_loss]
        else:
            loss = [prob_loss, dist_loss]

        self.keras_model.compile(optimizer, loss=loss,
                                 loss_weights=list(self.config.train_loss_weights),
                                 metrics={'prob': kld,
                                          'dist': [relevant_mae, relevant_mse, dist_iou_metric]})

        self.callbacks = []
        if self.basedir is not None:
            self.callbacks += self._checkpoint_callbacks()

            if self.config.train_tensorboard:
                if IS_TF_1:
                    self.callbacks.append(CARETensorBoard(log_dir=str(self.logdir), prefix_with_timestamp=False, n_images=3, write_images=True, prob_out=False))
                else:
                    self.callbacks.append(TensorBoard(log_dir=str(self.logdir/'logs'), write_graph=False, profile_batch=0))

        if self.config.train_reduce_lr is not None:
            rlrop_params = self.config.train_reduce_lr
            if 'verbose' not in rlrop_params:
                rlrop_params['verbose'] = True
            # TF2: add as first callback to put 'lr' in the logs for TensorBoard
            self.callbacks.insert(0, ReduceLROnPlateau(**rlrop_params))

        self._model_prepared = True


    def _predict_setup(self, img, axes, normalizer, n_tiles, show_tile_progress, predict_kwargs):
        """ Shared setup code between `predict` and `predict_sparse` """
        if n_tiles is None:
            n_tiles = [1]*img.ndim
        try:
            n_tiles = tuple(n_tiles)
            img.ndim == len(n_tiles) or _raise(TypeError())
        except TypeError:
            raise ValueError("n_tiles must be an iterable of length %d" % img.ndim)
        all(np.isscalar(t) and 1<=t and int(t)==t for t in n_tiles) or _raise(
            ValueError("all values of n_tiles must be integer values >= 1"))

        n_tiles = tuple(map(int,n_tiles))

        axes = self._normalize_axes(img, axes)
        axes_net = self.config.axes

        _permute_axes = self._make_permute_axes(axes, axes_net)
        x = _permute_axes(img)  # x has axes_net semantics

        channel = axes_dict(axes_net)['C']
        self.config.n_channel_in == x.shape[channel] or _raise(ValueError())
        axes_net_div_by = self._axes_div_by(axes_net)

        grid = tuple(self.config.grid)
        len(grid) == len(axes_net)-1 or _raise(ValueError())
        grid_dict = dict(zip(axes_net.replace('C',''),grid))

        normalizer = self._check_normalizer_resizer(normalizer, None)[0]
        resizer = StarDistPadAndCropResizer(grid=grid_dict)

        x = normalizer.before(x, axes_net)
        x = resizer.before(x, axes_net, axes_net_div_by)

        if not _is_floatarray(x):
            warnings.warn("Predicting on non-float input... ( forgot to normalize? )")

        def predict_direct(x):
            ys = self.keras_model.predict(x[np.newaxis], **predict_kwargs)
            return tuple(y[0] for y in ys)

        def tiling_setup():
            assert np.prod(n_tiles) > 1
            tiling_axes = axes_net.replace('C','')  # axes eligible for tiling
            x_tiling_axis = tuple(axes_dict(axes_net)[a] for a in tiling_axes)  # numerical axis ids for x
            axes_net_tile_overlaps = self._axes_tile_overlap(axes_net)
            # hack: permute tiling axis in the same way as img -> x was permuted
            _n_tiles = _permute_axes(np.empty(n_tiles,bool)).shape
            (all(_n_tiles[i] == 1 for i in range(x.ndim) if i not in x_tiling_axis) or
                _raise(ValueError("entry of n_tiles > 1 only allowed for axes '%s'" % tiling_axes)))

            sh = [s//grid_dict.get(a,1) for a,s in zip(axes_net,x.shape)]
            sh[channel] = None
            def create_empty_output(n_channel, dtype=np.float32):
                sh[channel] = n_channel
                return np.empty(sh,dtype)

            if callable(show_tile_progress):
                progress, _show_tile_progress = show_tile_progress, True
            else:
                progress, _show_tile_progress = tqdm, show_tile_progress

            n_block_overlaps = [int(np.ceil(overlap/blocksize)) for overlap, blocksize
                                in zip(axes_net_tile_overlaps, axes_net_div_by)]

            num_tiles_used = total_n_tiles(x, _n_tiles, block_sizes=axes_net_div_by, n_block_overlaps=n_block_overlaps)

            tile_generator = progress(tile_iterator(x, _n_tiles, block_sizes=axes_net_div_by, n_block_overlaps=n_block_overlaps),
                                      disable=(not _show_tile_progress), total=num_tiles_used)

            return tile_generator, tuple(sh), create_empty_output

        return x, axes, axes_net, axes_net_div_by, _permute_axes, resizer, n_tiles, grid, grid_dict, channel, predict_direct, tiling_setup


    def _predict_generator(self, img, axes=None, normalizer=None, n_tiles=None, show_tile_progress=True, **predict_kwargs):
        """Predict.

        Parameters
        ----------
        img : :class:`numpy.ndarray`
            Input image
        axes : str or None
            Axes of the input ``img``.
            ``None`` denotes that axes of img are the same as denoted in the config.
        normalizer : :class:`csbdeep.data.Normalizer` or None
            (Optional) normalization of input image before prediction.
            Note that the default (``None``) assumes ``img`` to be already normalized.
        n_tiles : iterable or None
            Out of memory (OOM) errors can occur if the input image is too large.
            To avoid this problem, the input image is broken up into (overlapping) tiles
            that are processed independently and re-assembled.
            This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
            ``None`` denotes that no tiling should be used.
        show_tile_progress: bool or callable
            If boolean, indicates whether to show progress (via tqdm) during tiled prediction.
            If callable, must be a drop-in replacement for tqdm.
        predict_kwargs: dict
            Keyword arguments for ``predict`` function of Keras model.

        Returns
        -------
        (:class:`numpy.ndarray`, :class:`numpy.ndarray`, [:class:`numpy.ndarray`])
            Returns the tuple (`prob`, `dist`, [`prob_class`]) of per-pixel object probabilities and star-convex polygon/polyhedra distances.
            In multiclass prediction mode, `prob_class` is the probability map for each of the 1+'n_classes' classes (first class is background)

        """

        x, axes, axes_net, axes_net_div_by, _permute_axes, resizer, n_tiles, grid, grid_dict, channel, predict_direct, tiling_setup = \
            self._predict_setup(img, axes, normalizer, n_tiles, show_tile_progress, predict_kwargs)

        if np.prod(n_tiles) > 1:
            tile_generator, output_shape, create_empty_output = tiling_setup()

            prob = create_empty_output(1)
            dist = create_empty_output(self.config.n_rays)
            if self._is_multiclass():
                prob_class = create_empty_output(self.config.n_classes+1)
                result = (prob, dist, prob_class)
            else:
                result = (prob, dist)

            for tile, s_src, s_dst in tile_generator:
                # predict_direct -> prob, dist, [prob_class if multi_class]
                result_tile = predict_direct(tile)
                # account for grid
                s_src = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_src,axes_net)]
                s_dst = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_dst,axes_net)]
                # prob and dist have different channel dimensionality than image x
                s_src[channel] = slice(None)
                s_dst[channel] = slice(None)
                s_src, s_dst = tuple(s_src), tuple(s_dst)
                # print(s_src,s_dst)
                for part, part_tile in zip(result, result_tile):
                    part[s_dst] = part_tile[s_src]
                yield  # yield None after each processed tile
        else:
            # predict_direct -> prob, dist, [prob_class if multi_class]
            result = predict_direct(x)

        result = [resizer.after(part, axes_net) for part in result]

        # result = (prob, dist) for legacy or (prob, dist, prob_class) for multiclass

        # prob
        result[0] = np.take(result[0],0,axis=channel)
        # dist
        result[1] = np.maximum(1e-3, result[1])  # avoid small dist values to prevent problems with Qhull
        result[1] = np.moveaxis(result[1],channel,-1)

        if self._is_multiclass():
            # prob_class
            result[2] = np.moveaxis(result[2],channel,-1)

        # last "yield" is the actual output that would have been "return"ed if this was a regular function
        yield tuple(result)


    @functools.wraps(_predict_generator)
    def predict(self, *args, **kwargs):
        # return last "yield"ed value of generator
        r = None
        for r in self._predict_generator(*args, **kwargs):
            pass
        return r
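
    # Illustrative sketch (not part of the original file): dense prediction on a
    # single image; 'model' and 'img' are placeholders, 'normalize' is csbdeep's
    # percentile normalization.
    #
    #     from csbdeep.utils import normalize
    #     x = normalize(img, 1, 99.8)        # scale intensities to roughly [0,1]
    #     prob, dist = model.predict(x)[:2]  # object probabilities + ray distances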


    def _predict_sparse_generator(self, img, prob_thresh=None, axes=None, normalizer=None, n_tiles=None, show_tile_progress=True, b=2, **predict_kwargs):
        """ Sparse version of model.predict()
        Returns
        -------
        (prob, dist, [prob_class], points)   flat list of probs, dists, (optional prob_class) and points
        """
        if prob_thresh is None: prob_thresh = self.thresholds.prob

        x, axes, axes_net, axes_net_div_by, _permute_axes, resizer, n_tiles, grid, grid_dict, channel, predict_direct, tiling_setup = \
            self._predict_setup(img, axes, normalizer, n_tiles, show_tile_progress, predict_kwargs)

        def _prep(prob, dist):
            prob = np.take(prob,0,axis=channel)
            dist = np.moveaxis(dist,channel,-1)
            dist = np.maximum(1e-3, dist)
            return prob, dist

        proba, dista, pointsa, prob_classa = [],[],[],[]

        if np.prod(n_tiles) > 1:
            tile_generator, output_shape, create_empty_output = tiling_setup()

            sh = list(output_shape)
            sh[channel] = 1

            proba, dista, pointsa, prob_classa = [], [], [], []

            for tile, s_src, s_dst in tile_generator:

                results_tile = predict_direct(tile)

                # account for grid
                s_src = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_src,axes_net)]
                s_dst = [slice(s.start//grid_dict.get(a,1),s.stop//grid_dict.get(a,1)) for s,a in zip(s_dst,axes_net)]
                s_src[channel] = slice(None)
                s_dst[channel] = slice(None)
                s_src, s_dst = tuple(s_src), tuple(s_dst)

                prob_tile, dist_tile = results_tile[:2]
                prob_tile, dist_tile = _prep(prob_tile[s_src], dist_tile[s_src])

                bs = list((b if s.start==0 else -1, b if s.stop==_sh else -1) for s,_sh in zip(s_dst, sh))
                bs.pop(channel)
                inds = _ind_prob_thresh(prob_tile, prob_thresh, b=bs)
                proba.extend(prob_tile[inds].copy())
                dista.extend(dist_tile[inds].copy())
                _points = np.stack(np.where(inds), axis=1)
                offset = list(s.start for i,s in enumerate(s_dst))
                offset.pop(channel)
                _points = _points + np.array(offset).reshape((1,len(offset)))
                _points = _points * np.array(self.config.grid).reshape((1,len(self.config.grid)))
                pointsa.extend(_points)

                if self._is_multiclass():
                    p = results_tile[2][s_src].copy()
                    p = np.moveaxis(p,channel,-1)
                    prob_classa.extend(p[inds])
                yield  # yield None after each processed tile

        else:
            # predict_direct -> prob, dist, [prob_class if multi_class]
            results = predict_direct(x)
            prob, dist = results[:2]
            prob, dist = _prep(prob, dist)
            inds = _ind_prob_thresh(prob, prob_thresh, b=b)
            proba = prob[inds].copy()
            dista = dist[inds].copy()
            _points = np.stack(np.where(inds), axis=1)
            pointsa = (_points * np.array(self.config.grid).reshape((1,len(self.config.grid))))

            if self._is_multiclass():
                p = np.moveaxis(results[2],channel,-1)
                prob_classa = p[inds].copy()


        proba = np.asarray(proba)
        dista = np.asarray(dista).reshape((-1,self.config.n_rays))
        pointsa = np.asarray(pointsa).reshape((-1,self.config.n_dim))

        idx = resizer.filter_points(x.ndim, pointsa, axes_net)
        proba = proba[idx]
        dista = dista[idx]
        pointsa = pointsa[idx]

        # last "yield" is the actual output that would have been "return"ed if this was a regular function
        if self._is_multiclass():
            prob_classa = np.asarray(prob_classa).reshape((-1,self.config.n_classes+1))
            prob_classa = prob_classa[idx]
            yield proba, dista, prob_classa, pointsa
        else:
            prob_classa = None
            yield proba, dista, pointsa


    @functools.wraps(_predict_sparse_generator)
    def predict_sparse(self, *args, **kwargs):
        # return last "yield"ed value of generator
        r = None
        for r in self._predict_sparse_generator(*args, **kwargs):
            pass
        return r


    def _predict_instances_generator(self, img, axes=None, normalizer=None,
                                     sparse=True,
                                     prob_thresh=None, nms_thresh=None,
                                     scale=None,
                                     n_tiles=None, show_tile_progress=True,
                                     verbose=False,
                                     return_labels=True,
                                     predict_kwargs=None, nms_kwargs=None,
                                     overlap_label=None, return_predict=False):
        """Predict instance segmentation from input image.

        Parameters
        ----------
        img : :class:`numpy.ndarray`
            Input image
        axes : str or None
            Axes of the input ``img``.
            ``None`` denotes that axes of img are the same as denoted in the config.
        normalizer : :class:`csbdeep.data.Normalizer` or None
            (Optional) normalization of input image before prediction.
            Note that the default (``None``) assumes ``img`` to be already normalized.
        sparse: bool
            If true, aggregate probabilities/distances sparsely during tiled
            prediction to save memory (recommended).
        prob_thresh : float or None
            Consider only object candidates from pixels with predicted object probability
            above this threshold (also see `optimize_thresholds`).
        nms_thresh : float or None
            Perform non-maximum suppression that considers two objects to be the same
            when their area/surface overlap exceeds this threshold (also see `optimize_thresholds`).
        scale: None or float or iterable
            Scale the input image internally by this factor and rescale the output accordingly.
            All spatial axes (X,Y,Z) will be scaled if a scalar value is provided.
            Alternatively, multiple scale values (compatible with input `axes`) can be used
            for more fine-grained control (scale values for non-spatial axes must be 1).
        n_tiles : iterable or None
            Out of memory (OOM) errors can occur if the input image is too large.
            To avoid this problem, the input image is broken up into (overlapping) tiles
            that are processed independently and re-assembled.
            This parameter denotes a tuple of the number of tiles for every image axis (see ``axes``).
            ``None`` denotes that no tiling should be used.
        show_tile_progress: bool
            Whether to show progress during tiled prediction.
        verbose: bool
            Whether to print some info messages.
        return_labels: bool
            Whether to create a label image, otherwise return None in its place.
        predict_kwargs: dict
            Keyword arguments for ``predict`` function of Keras model.
        nms_kwargs: dict
            Keyword arguments for non-maximum suppression.
        overlap_label: scalar or None
            if not None, label the regions where polygons overlap with that value
        return_predict: bool
            Also return the outputs of :func:`predict` (in a separate tuple)
            If True, implies sparse = False

        Returns
        -------
        (:class:`numpy.ndarray`, dict), (optional: return tuple of :func:`predict`)
            Returns a tuple of the label instances image and also
            a dictionary with the details (coordinates, etc.) of all remaining polygons/polyhedra.

        """
        if predict_kwargs is None:
            predict_kwargs = {}
        if nms_kwargs is None:
            nms_kwargs = {}

        if return_predict and sparse:
            sparse = False
            warnings.warn("Setting sparse to False because return_predict is True")

        nms_kwargs.setdefault("verbose", verbose)

        _axes = self._normalize_axes(img, axes)
        _axes_net = self.config.axes
        _permute_axes = self._make_permute_axes(_axes, _axes_net)
        _shape_inst = tuple(s for s,a in zip(_permute_axes(img).shape, _axes_net) if a != 'C')

        if scale is not None:
            if isinstance(scale, numbers.Number):
                scale = tuple(scale if a in 'XYZ' else 1 for a in _axes)
            scale = tuple(scale)
            len(scale) == len(_axes) or _raise(ValueError(f"scale {scale} must be of length {len(_axes)}, i.e. one value for each of the axes {_axes}"))
            for s,a in zip(scale,_axes):
                s > 0 or _raise(ValueError("scale values must be greater than 0"))
                (s in (1,None) or a in 'XYZ') or warnings.warn(f"replacing scale value {s} for non-spatial axis {a} with 1")
            scale = tuple(s if a in 'XYZ' else 1 for s,a in zip(scale,_axes))
            verbose and print(f"scaling image by factors {scale} for axes {_axes}")
            img = ndi.zoom(img, scale, order=1)

        yield 'predict'  # indicate that prediction is starting
        res = None
        if sparse:
            for res in self._predict_sparse_generator(img, axes=axes, normalizer=normalizer, n_tiles=n_tiles,
                                                      prob_thresh=prob_thresh, show_tile_progress=show_tile_progress, **predict_kwargs):
                if res is None:
                    yield 'tile'  # yield 'tile' each time a tile has been processed
        else:
            for res in self._predict_generator(img, axes=axes, normalizer=normalizer, n_tiles=n_tiles,
                                               show_tile_progress=show_tile_progress, **predict_kwargs):
                if res is None:
                    yield 'tile'  # yield 'tile' each time a tile has been processed
            res = tuple(res) + (None,)

        if self._is_multiclass():
            prob, dist, prob_class, points = res
        else:
            prob, dist, points = res
            prob_class = None

        yield 'nms'  # indicate that non-maximum suppression is starting
        res_instances = self._instances_from_prediction(_shape_inst, prob, dist,
                                                        points=points,
                                                        prob_class=prob_class,
                                                        prob_thresh=prob_thresh,
                                                        nms_thresh=nms_thresh,
                                                        scale=(None if scale is None else dict(zip(_axes,scale))),
                                                        return_labels=return_labels,
                                                        overlap_label=overlap_label,
                                                        **nms_kwargs)

        # last "yield" is the actual output that would have been "return"ed if this was a regular function
        if return_predict:
            yield res_instances, tuple(res[:-1])
        else:
            yield res_instances


    @functools.wraps(_predict_instances_generator)
    def predict_instances(self, *args, **kwargs):
        # the reason why the actual computation happens as a generator function
        # (in '_predict_instances_generator') is that the generator is called
        # from the stardist napari plugin, which has its benefits regarding
        # control flow and progress display. however, typical use cases should
        # almost always use this function ('predict_instances'), and shouldn't
        # even notice (thanks to @functools.wraps) that it wraps the generator
        # function. note that similar reasoning applies to 'predict' and
        # 'predict_sparse'.

        # return last "yield"ed value of generator
        r = None
        for r in self._predict_instances_generator(*args, **kwargs):
            pass
        return r
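
    # Illustrative sketch (not part of the original file): the typical entry point
    # for inference; 'model' and 'img' (already normalized) are placeholders.
    #
    #     labels, details = model.predict_instances(img)
    #     # 'labels' is an int32 instance label image; 'details' holds the polygon
    #     # coordinates, center points, and per-object probabilities.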


    # def _predict_instances_old(self, img, axes=None, normalizer=None,
    #                            sparse=False,
    #                            prob_thresh=None, nms_thresh=None,
    #                            n_tiles=None, show_tile_progress=True,
    #                            verbose=False,
    #                            predict_kwargs=None, nms_kwargs=None, overlap_label=None):
    #     """
    #     old version, should be removed....
    #     """
    #     if predict_kwargs is None:
    #         predict_kwargs = {}
    #     if nms_kwargs is None:
    #         nms_kwargs = {}

    #     nms_kwargs.setdefault("verbose", verbose)

    #     _axes = self._normalize_axes(img, axes)
    #     _axes_net = self.config.axes
    #     _permute_axes = self._make_permute_axes(_axes, _axes_net)
    #     _shape_inst = tuple(s for s,a in zip(_permute_axes(img).shape, _axes_net) if a != 'C')

    #     res = self.predict(img, axes=axes, normalizer=normalizer,
    #                        n_tiles=n_tiles,
    #                        show_tile_progress=show_tile_progress,
    #                        **predict_kwargs)

    #     res = tuple(res) + (None,)

    #     if self._is_multiclass():
    #         prob, dist, prob_class, points = res
    #     else:
    #         prob, dist, points = res
    #         prob_class = None

    #     return self._instances_from_prediction_old(_shape_inst, prob, dist,
    #                                                points=points,
    #                                                prob_class=prob_class,
    #                                                prob_thresh=prob_thresh,
    #                                                nms_thresh=nms_thresh,
    #                                                overlap_label=overlap_label,
    #                                                **nms_kwargs)


    def predict_instances_big(self, img, axes, block_size, min_overlap, context=None,
                              labels_out=None, labels_out_dtype=np.int32, show_progress=True, **kwargs):
        """Predict instance segmentation from very large input images.

        Intended to be used when `predict_instances` cannot be used due to memory limitations.
        This function will break the input image into blocks and process them individually
        via `predict_instances` and assemble all the partial results. If used as intended, the result
        should be the same as if `predict_instances` was used directly on the whole image.

        **Important**: The crucial assumption is that all predicted object instances are smaller than
        the provided `min_overlap`. Also, it must hold that: min_overlap + 2*context < block_size.

        Example
        -------
        >>> img.shape
        (20000, 20000)
        >>> labels, polys = model.predict_instances_big(img, axes='YX', block_size=4096,
                                                        min_overlap=128, context=128, n_tiles=(4,4))

        Parameters
        ----------
        img: :class:`numpy.ndarray` or similar
            Input image
        axes: str
            Axes of the input ``img`` (such as 'YX', 'ZYX', 'YXC', etc.)
        block_size: int or iterable of int
            Process input image in blocks of the provided shape.
            (If a scalar value is given, it is used for all spatial image dimensions.)
        min_overlap: int or iterable of int
            Amount of guaranteed overlap between blocks.
            (If a scalar value is given, it is used for all spatial image dimensions.)
        context: int or iterable of int, or None
            Amount of image context on all sides of a block, which is discarded.
            If None, uses an automatic estimate that should work in many cases.
            (If a scalar value is given, it is used for all spatial image dimensions.)
        labels_out: :class:`numpy.ndarray` or similar, or None, or False
            numpy array or similar (must be of correct shape) to which the label image is written.
            If None, will allocate a numpy array of the correct shape and data type ``labels_out_dtype``.
            If False, will not write the label image (useful if only the dictionary is needed).
        labels_out_dtype: str or dtype
            Data type of returned label image if ``labels_out=None`` (has no effect otherwise).
        show_progress: bool
            Show progress bar for block processing.
        kwargs: dict
            Keyword arguments for ``predict_instances``.

        Returns
        -------
        (:class:`numpy.ndarray` or False, dict)
            Returns the label image and a dictionary with the details (coordinates, etc.) of the polygons/polyhedra.

        """
        from ..big import _grid_divisible, BlockND, OBJECT_KEYS  # , repaint_labels
        from ..matching import relabel_sequential

        n = img.ndim
        axes = axes_check_and_normalize(axes, length=n)
        grid = self._axes_div_by(axes)
        axes_out = self._axes_out.replace('C','')
        shape_dict = dict(zip(axes,img.shape))
        shape_out = tuple(shape_dict[a] for a in axes_out)

        if context is None:
            context = self._axes_tile_overlap(axes)

        if np.isscalar(block_size): block_size = n*[block_size]
        if np.isscalar(min_overlap): min_overlap = n*[min_overlap]
        if np.isscalar(context): context = n*[context]
        block_size, min_overlap, context = list(block_size), list(min_overlap), list(context)
        assert n == len(block_size) == len(min_overlap) == len(context)

        if 'C' in axes:
            # single block for channel axis
            i = axes_dict(axes)['C']
            # if (block_size[i], min_overlap[i], context[i]) != (None, None, None):
            #     print("Ignoring values of 'block_size', 'min_overlap', and 'context' for channel axis " +
            #           "(set to 'None' to avoid this warning).", file=sys.stderr, flush=True)
            block_size[i] = img.shape[i]
            min_overlap[i] = context[i] = 0

        block_size = tuple(_grid_divisible(g, v, name='block_size', verbose=False) for v,g,a in zip(block_size, grid,axes))
        min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v,g,a in zip(min_overlap,grid,axes))
        context = tuple(_grid_divisible(g, v, name='context', verbose=False) for v,g,a in zip(context, grid,axes))

        # print(f"input: shape {img.shape} with axes {axes}")
        print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)

        for a,c,o in zip(axes,context,self._axes_tile_overlap(axes)):
            if c < o:
                print(f"{a}: context of {c} is small, recommended to use at least {o}", flush=True)

        # create block cover
        blocks = BlockND.cover(img.shape, axes, block_size, min_overlap, context, grid)

        if np.isscalar(labels_out) and bool(labels_out) is False:
            labels_out = None
        else:
            if labels_out is None:
                labels_out = np.zeros(shape_out, dtype=labels_out_dtype)
            else:
                labels_out.shape == shape_out or _raise(ValueError(f"'labels_out' must have shape {shape_out} (axes {axes_out})."))

        polys_all = {}
        # problem_ids = []
        label_offset = 1

        kwargs_override = dict(axes=axes, overlap_label=None, return_labels=True, return_predict=False)
        if show_progress:
            kwargs_override['show_tile_progress'] = False  # disable progress for predict_instances
        for k,v in kwargs_override.items():
            if k in kwargs: print(f"changing '{k}' from {kwargs[k]} to {v}", flush=True)
            kwargs[k] = v

        blocks = tqdm(blocks, disable=(not show_progress))
        # actual computation
        for block in blocks:
            labels, polys = self.predict_instances(block.read(img, axes=axes), **kwargs)
            labels = block.crop_context(labels, axes=axes_out)
            labels, polys = block.filter_objects(labels, polys, axes=axes_out)
            # TODO: relabel_sequential is not very memory-efficient (will allocate memory proportional to label_offset)
            # this should not change the order of labels
            labels = relabel_sequential(labels, label_offset)[0]
            # labels, fwd_map, _ = relabel_sequential(labels, label_offset)
            # if len(incomplete) > 0:
            #     problem_ids.extend([fwd_map[i] for i in incomplete])
            #     if show_progress:
            #         blocks.set_postfix_str(f"found {len(problem_ids)} problematic {'object' if len(problem_ids)==1 else 'objects'}")
            if labels_out is not None:
                block.write(labels_out, labels, axes=axes_out)

            for k,v in polys.items():
                polys_all.setdefault(k,[]).append(v)

            label_offset += len(polys['prob'])
            del labels

        polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k,v in polys_all.items()}

        # if labels_out is not None and len(problem_ids) > 0:
        #     # if show_progress:
        #     #     blocks.write('')
        #     # print(f"Found {len(problem_ids)} objects that violate the 'min_overlap' assumption.", file=sys.stderr, flush=True)
        #     repaint_labels(labels_out, problem_ids, polys_all, show_progress=False)

        return labels_out, polys_all  # , tuple(problem_ids)


    def optimize_thresholds(self, X_val, Y_val, nms_threshs=[0.3,0.4,0.5], iou_threshs=[0.3,0.5,0.7], predict_kwargs=None, optimize_kwargs=None, save_to_json=True):
        """Optimize two thresholds (probability, NMS overlap) necessary for predicting object instances.

        Note that the default thresholds yield good results in many cases, but optimizing
        the thresholds for a particular dataset can further improve performance.

        The optimized thresholds are automatically used for all further predictions
        and also written to the model directory.

        See ``utils.optimize_threshold`` for details and possible choices for ``optimize_kwargs``.

        Parameters
        ----------
        X_val : list of ndarray
            (Validation) input images (must be normalized) to use for threshold tuning.
        Y_val : list of ndarray
            (Validation) label images to use for threshold tuning.
        nms_threshs : list of float
            List of overlap thresholds to be considered for NMS.
            For each value in this list, optimization is run to find a corresponding prob_thresh value.
        iou_threshs : list of float
            List of intersection over union (IOU) thresholds for which
            the (average) matching performance is considered to tune the thresholds.
        predict_kwargs: dict
            Keyword arguments for ``predict`` function of this class.
            (If not provided, will guess value for `n_tiles` to prevent out of memory errors.)
        optimize_kwargs: dict
            Keyword arguments for ``utils.optimize_threshold`` function.

        """
        if predict_kwargs is None:
            predict_kwargs = {}
        if optimize_kwargs is None:
            optimize_kwargs = {}

        def _predict_kwargs(x):
            if 'n_tiles' in predict_kwargs:
                return predict_kwargs
            else:
                return {**predict_kwargs, 'n_tiles': self._guess_n_tiles(x), 'show_tile_progress': False}

        # only take first two elements of predict in case multi class is activated
        Yhat_val = [self.predict(x, **_predict_kwargs(x))[:2] for x in X_val]

        opt_prob_thresh, opt_measure, opt_nms_thresh = None, -np.inf, None
        for _opt_nms_thresh in nms_threshs:
            _opt_prob_thresh, _opt_measure = optimize_threshold(Y_val, Yhat_val, model=self, nms_thresh=_opt_nms_thresh, iou_threshs=iou_threshs, **optimize_kwargs)
            if _opt_measure > opt_measure:
                opt_prob_thresh, opt_measure, opt_nms_thresh = _opt_prob_thresh, _opt_measure, _opt_nms_thresh
        opt_threshs = dict(prob=opt_prob_thresh, nms=opt_nms_thresh)

        self.thresholds = opt_threshs
        print(end='', file=sys.stderr, flush=True)
        print("Using optimized values: prob_thresh={prob:g}, nms_thresh={nms:g}.".format(prob=self.thresholds.prob, nms=self.thresholds.nms))
        if save_to_json and self.basedir is not None:
            print("Saving to 'thresholds.json'.")
            save_json(opt_threshs, str(self.logdir / 'thresholds.json'))
        return opt_threshs
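
    # Illustrative sketch (not part of the original file): tuning both thresholds
    # on a held-out split; 'model', 'X_val', 'Y_val' are placeholders.
    #
    #     threshs = model.optimize_thresholds(X_val, Y_val)  # writes thresholds.json
    #     print(threshs['prob'], threshs['nms'])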


    def _guess_n_tiles(self, img):
        axes = self._normalize_axes(img, axes=None)
        shape = list(img.shape)
        if 'C' in axes:
            del shape[axes_dict(axes)['C']]
        b = self.config.train_batch_size**(1.0/self.config.n_dim)
        n_tiles = [int(np.ceil(s/(p*b))) for s,p in zip(shape,self.config.train_patch_size)]
        if 'C' in axes:
            n_tiles.insert(axes_dict(axes)['C'],1)
        return tuple(n_tiles)


    def _normalize_axes(self, img, axes):
        if axes is None:
            axes = self.config.axes
            assert 'C' in axes
            if img.ndim == len(axes)-1 and self.config.n_channel_in == 1:
                # img has no dedicated channel axis, but 'C' always part of config axes
                axes = axes.replace('C','')
        return axes_check_and_normalize(axes, img.ndim)


    def _compute_receptive_field(self, img_size=None):
        # TODO: good enough?
        from scipy.ndimage import zoom
        if img_size is None:
            img_size = tuple(g*(128 if self.config.n_dim==2 else 64) for g in self.config.grid)
        if np.isscalar(img_size):
            img_size = (img_size,) * self.config.n_dim
        img_size = tuple(img_size)
        # print(img_size)
        assert all(_is_power_of_2(s) for s in img_size)
        mid = tuple(s//2 for s in img_size)
        x = np.zeros((1,)+img_size+(self.config.n_channel_in,), dtype=np.float32)
        z = np.zeros_like(x)
        x[(0,)+mid+(slice(None),)] = 1
        y = self.keras_model.predict(x)[0][0,...,0]
        y0 = self.keras_model.predict(z)[0][0,...,0]
        grid = tuple((np.array(x.shape[1:-1])/np.array(y.shape)).astype(int))
        assert grid == self.config.grid
        y = zoom(y, grid, order=0)
        y0 = zoom(y0, grid, order=0)
        ind = np.where(np.abs(y-y0)>0)
        return [(m-np.min(i), np.max(i)-m) for (m,i) in zip(mid,ind)]


    def _axes_tile_overlap(self, query_axes):
        query_axes = axes_check_and_normalize(query_axes)
        try:
            self._tile_overlap
        except AttributeError:
            self._tile_overlap = self._compute_receptive_field()
        overlap = dict(zip(
            self.config.axes.replace('C',''),
            tuple(max(rf) for rf in self._tile_overlap)
        ))
        return tuple(overlap.get(a,0) for a in query_axes)


    def export_TF(self, fname=None, single_output=True, upsample_grid=True):
        """Export model to TensorFlow's SavedModel format that can be used e.g. in the Fiji plugin

        Parameters
        ----------
        fname : str
            Path of the zip file to store the model
            If None, the default path "<modeldir>/TF_SavedModel.zip" is used
        single_output: bool
            If set, concatenates the two model outputs into a single output (note: this is currently mandatory for further use in Fiji)
        upsample_grid: bool
            If set, upsamples the output to the input shape (note: this is currently mandatory for further use in Fiji)
        """
        Concatenate, UpSampling2D, UpSampling3D, Conv2DTranspose, Conv3DTranspose = keras_import('layers', 'Concatenate', 'UpSampling2D', 'UpSampling3D', 'Conv2DTranspose', 'Conv3DTranspose')
        Model = keras_import('models', 'Model')

        if self.basedir is None and fname is None:
            raise ValueError("Need explicit 'fname', since model directory not available (basedir=None).")

        if self._is_multiclass():
            warnings.warn("multi-class mode not supported yet, removing classification output from exported model")

        grid = self.config.grid
        prob = self.keras_model.outputs[0]
        dist = self.keras_model.outputs[1]
        assert self.config.n_dim in (2,3)

        if upsample_grid and any(g>1 for g in grid):
            # CSBDeep Fiji plugin needs same size input/output
            # -> we need to upsample the outputs if grid > (1,1)
            # note: upsampling prob with a transposed convolution creates sparse
            #       prob output with less candidates than with standard upsampling
            conv_transpose = Conv2DTranspose if self.config.n_dim==2 else Conv3DTranspose
            upsampling = UpSampling2D if self.config.n_dim==2 else UpSampling3D
            prob = conv_transpose(1, (1,)*self.config.n_dim,
                                  strides=grid, padding='same',
                                  kernel_initializer='ones', use_bias=False)(prob)
            dist = upsampling(grid)(dist)

        inputs = self.keras_model.inputs[0]
        outputs = Concatenate()([prob,dist]) if single_output else [prob,dist]
        csbdeep_model = Model(inputs, outputs)

        fname = (self.logdir / 'TF_SavedModel.zip') if fname is None else Path(fname)
        export_SavedModel(csbdeep_model, str(fname))
        return csbdeep_model
1138 |
+
|
1139 |
+
|
1140 |
+
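    # Typical export usage (illustration only, not part of the original file):
    # with a trained model, `model.export_TF()` writes "<logdir>/TF_SavedModel.zip",
    # which the CSBDeep/StarDist Fiji plugin can import; pass an explicit `fname`
    # when the model was created with basedir=None.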

class StarDistPadAndCropResizer(Resizer):

    # TODO: check correctness
    def __init__(self, grid, mode='reflect', **kwargs):
        assert isinstance(grid, dict)
        self.mode = mode
        self.grid = grid
        self.kwargs = kwargs


    def before(self, x, axes, axes_div_by):
        assert all(a%g==0 for g,a in zip((self.grid.get(a,1) for a in axes), axes_div_by))
        axes = axes_check_and_normalize(axes, x.ndim)
        def _split(v):
            return 0, v # only pad at the end
        self.pad = {
            a : _split((div_n-s%div_n)%div_n)
            for a, div_n, s in zip(axes, axes_div_by, x.shape)
        }
        x_pad = np.pad(x, tuple(self.pad[a] for a in axes), mode=self.mode, **self.kwargs)
        self.padded_shape = dict(zip(axes, x_pad.shape))
        if 'C' in self.padded_shape: del self.padded_shape['C']
        return x_pad


    def after(self, x, axes):
        # axes can include 'C', which may not have been present in before()
        axes = axes_check_and_normalize(axes, x.ndim)
        assert all(s_pad == s * g for s,s_pad,g in zip(x.shape,
                                                       (self.padded_shape.get(a,_s) for a,_s in zip(axes,x.shape)),
                                                       (self.grid.get(a,1) for a in axes)))
        # print(self.padded_shape)
        # print(self.pad)
        # print(self.grid)
        crop = tuple(
            slice(0, -(math.floor(p[1]/g)) if p[1]>=g else None)
            for p,g in zip((self.pad.get(a,(0,0)) for a in axes), (self.grid.get(a,1) for a in axes))
        )
        # print(crop)
        return x[crop]


    def filter_points(self, ndim, points, axes):
        """ returns indices of points inside crop region """
        assert points.ndim == 2
        axes = axes_check_and_normalize(axes, ndim)

        bounds = np.array(tuple(self.padded_shape[a]-self.pad[a][1] for a in axes if a.lower() in ('z','y','x')))
        idx = np.where(np.all(points < bounds, 1))
        return idx



def _tf_version_at_least(version_string="1.0.0"):
    from packaging import version
    return version.parse(tf.__version__) >= version.parse(version_string)
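As a side note before the next file: the pad-and-crop bookkeeping of `StarDistPadAndCropResizer` reduces to a few lines of NumPy. A minimal, self-contained sketch of the same round trip (hypothetical sizes; the helper names are made up for illustration, not part of the package):

    import math
    import numpy as np

    def pad_to_divisible(x, div_by, mode='reflect'):
        # pad only at the end of each axis, like StarDistPadAndCropResizer.before
        pad = [(0, (d - s % d) % d) for s, d in zip(x.shape, div_by)]
        return np.pad(x, pad, mode=mode), pad

    def crop_back(y, pad, grid):
        # y lives on the subsampled grid, so the padding to remove shrinks by
        # the grid factor, like StarDistPadAndCropResizer.after
        crop = tuple(slice(0, -(math.floor(p[1] / g)) if p[1] >= g else None)
                     for p, g in zip(pad, grid))
        return y[crop]

    x = np.random.rand(250, 300)             # image not divisible by 16
    x_pad, pad = pad_to_divisible(x, (16, 16))
    assert x_pad.shape == (256, 304)
    y = x_pad[::2, ::2]                      # stand-in for a grid=(2,2) prediction
    print(crop_back(y, pad, (2, 2)).shape)   # -> (125, 150)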
stardist_pkg/models/model2d.py
ADDED
@@ -0,0 +1,570 @@
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
import warnings
import math
from tqdm import tqdm

from csbdeep.models import BaseConfig
from csbdeep.internals.blocks import unet_block
from csbdeep.utils import _raise, backend_channels_last, axes_check_and_normalize, axes_dict
from csbdeep.utils.tf import keras_import, IS_TF_1, CARETensorBoard, CARETensorBoardImage
from skimage.segmentation import clear_border
from skimage.measure import regionprops
from scipy.ndimage import zoom
from distutils.version import LooseVersion

keras = keras_import()
K = keras_import('backend')
Input, Conv2D, MaxPooling2D = keras_import('layers', 'Input', 'Conv2D', 'MaxPooling2D')
Model = keras_import('models', 'Model')

from .base import StarDistBase, StarDistDataBase, _tf_version_at_least
from ..sample_patches import sample_patches
from ..utils import edt_prob, _normalize_grid, mask_to_categorical
from ..geometry import star_dist, dist_to_coord, polygons_to_label
from ..nms import non_maximum_suppression, non_maximum_suppression_sparse


class StarDistData2D(StarDistDataBase):

    def __init__(self, X, Y, batch_size, n_rays, length,
                 n_classes=None, classes=None,
                 patch_size=(256,256), b=32, grid=(1,1), shape_completion=False, augmenter=None, foreground_prob=0, **kwargs):

        super().__init__(X=X, Y=Y, n_rays=n_rays, grid=grid,
                         n_classes=n_classes, classes=classes,
                         batch_size=batch_size, patch_size=patch_size, length=length,
                         augmenter=augmenter, foreground_prob=foreground_prob, **kwargs)

        self.shape_completion = bool(shape_completion)
        if self.shape_completion and b > 0:
            self.b = slice(b,-b),slice(b,-b)
        else:
            self.b = slice(None),slice(None)

        self.sd_mode = 'opencl' if self.use_gpu else 'cpp'


    def __getitem__(self, i):
        idx = self.batch(i)
        arrays = [sample_patches((self.Y[k],) + self.channels_as_tuple(self.X[k]),
                                 patch_size=self.patch_size, n_samples=1,
                                 valid_inds=self.get_valid_inds(k)) for k in idx]

        if self.n_channel is None:
            X, Y = list(zip(*[(x[0][self.b], y[0]) for y,x in arrays]))
        else:
            X, Y = list(zip(*[(np.stack([_x[0] for _x in x],axis=-1)[self.b], y[0]) for y,*x in arrays]))

        X, Y = tuple(zip(*tuple(self.augmenter(_x, _y) for _x, _y in zip(X,Y))))

        prob = np.stack([edt_prob(lbl[self.b][self.ss_grid[1:3]]) for lbl in Y])
        # prob = np.stack([edt_prob(lbl[self.b]) for lbl in Y])
        # prob = prob[self.ss_grid]

        if self.shape_completion:
            Y_cleared = [clear_border(lbl) for lbl in Y]
            _dist = np.stack([star_dist(lbl,self.n_rays,mode=self.sd_mode)[self.b+(slice(None),)] for lbl in Y_cleared])
            dist = _dist[self.ss_grid]
            dist_mask = np.stack([edt_prob(lbl[self.b][self.ss_grid[1:3]]) for lbl in Y_cleared])
        else:
            # directly subsample with grid
            dist = np.stack([star_dist(lbl,self.n_rays,mode=self.sd_mode, grid=self.grid) for lbl in Y])
            dist_mask = prob

        X = np.stack(X)
        if X.ndim == 3: # input image has no channel axis
            X = np.expand_dims(X,-1)
        prob = np.expand_dims(prob,-1)
        dist_mask = np.expand_dims(dist_mask,-1)

        # subsample with given grid
        # dist_mask = dist_mask[self.ss_grid]
        # prob = prob[self.ss_grid]

        # append dist_mask to dist as additional channel
        # dist_and_mask = np.concatenate([dist,dist_mask],axis=-1)
        # faster than concatenate
        dist_and_mask = np.empty(dist.shape[:-1]+(self.n_rays+1,), np.float32)
        dist_and_mask[...,:-1] = dist
        dist_and_mask[...,-1:] = dist_mask

        if self.n_classes is None:
            return [X], [prob,dist_and_mask]
        else:
            prob_class = np.stack(tuple((mask_to_categorical(y, self.n_classes, self.classes[k]) for y,k in zip(Y, idx))))

            # TODO: investigate downsampling via simple indexing vs. using 'zoom'
            # prob_class = prob_class[self.ss_grid]
            # 'zoom' might lead to better registered maps (especially if upscaled later)
            prob_class = zoom(prob_class, (1,)+tuple(1/g for g in self.grid)+(1,), order=0)

            return [X], [prob,dist_and_mask, prob_class]

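# A quick shape orientation for one batch from the generator above
# (hypothetical numbers, not part of the original file): with batch_size=4,
# patch_size=(256,256), grid=(2,2), n_rays=32 and n_classes=None, the returned
# pair `[X], [prob, dist_and_mask]` has
#   X.shape             == (4, 256, 256, 1)
#   prob.shape          == (4, 128, 128, 1)    # subsampled by grid
#   dist_and_mask.shape == (4, 128, 128, 33)   # n_rays distances + 1 mask channel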

class Config2D(BaseConfig):
    """Configuration for a :class:`StarDist2D` model.

    Parameters
    ----------
    axes : str or None
        Axes of the input images.
    n_rays : int
        Number of radial directions for the star-convex polygon.
        Recommended to use a power of 2 (default: 32).
    n_channel_in : int
        Number of channels of given input image (default: 1).
    grid : (int,int)
        Subsampling factors (must be powers of 2) for each of the axes.
        Model will predict on a subsampled grid for increased efficiency and larger field of view.
    n_classes : None or int
        Number of object classes to use for multi-class prediction (use None to disable).
    backbone : str
        Name of the neural network architecture to be used as backbone.
    kwargs : dict
        Overwrite (or add) configuration attributes (see below).


    Attributes
    ----------
    unet_n_depth : int
        Number of U-Net resolution levels (down/up-sampling layers).
    unet_kernel_size : (int,int)
        Convolution kernel size for all (U-Net) convolution layers.
    unet_n_filter_base : int
        Number of convolution kernels (feature channels) for first U-Net layer.
        Doubled after each down-sampling layer.
    unet_pool : (int,int)
        Maxpooling size for all (U-Net) convolution layers.
    net_conv_after_unet : int
        Number of filters of the extra convolution layer after U-Net (0 to disable).
    unet_* : *
        Additional parameters for U-Net backbone.
    train_shape_completion : bool
        Train model to predict complete shapes for partially visible objects at image boundary.
    train_completion_crop : int
        If 'train_shape_completion' is set to True, specify number of pixels to crop at boundary of training patches.
        Should be chosen based on (largest) object sizes.
    train_patch_size : (int,int)
        Size of patches to be cropped from provided training images.
    train_background_reg : float
        Regularizer to encourage distance predictions on background regions to be 0.
    train_foreground_only : float
        Fraction (0..1) of patches that will only be sampled from regions that contain foreground pixels.
    train_sample_cache : bool
        Activate caching of valid patch regions for all training images (disable to save memory for large datasets).
    train_dist_loss : str
        Training loss for star-convex polygon distances ('mse' or 'mae').
    train_loss_weights : tuple of float
        Weights for losses relating to (probability, distance).
    train_epochs : int
        Number of training epochs.
    train_steps_per_epoch : int
        Number of parameter update steps per epoch.
    train_learning_rate : float
        Learning rate for training.
    train_batch_size : int
        Batch size for training.
    train_n_val_patches : int
        Number of patches to be extracted from validation images (``None`` = one patch per image).
    train_tensorboard : bool
        Enable TensorBoard for monitoring training progress.
    train_reduce_lr : dict
        Parameter :class:`dict` of ReduceLROnPlateau_ callback; set to ``None`` to disable.
    use_gpu : bool
        Indicate that the data generator should use OpenCL to do computations on the GPU.

    .. _ReduceLROnPlateau: https://keras.io/api/callbacks/reduce_lr_on_plateau/
    """

    def __init__(self, axes='YX', n_rays=32, n_channel_in=1, grid=(1,1), n_classes=None, backbone='unet', **kwargs):
        """See class docstring."""

        super().__init__(axes=axes, n_channel_in=n_channel_in, n_channel_out=1+n_rays)

        # directly set by parameters
        self.n_rays = int(n_rays)
        self.grid = _normalize_grid(grid,2)
        self.backbone = str(backbone).lower()
        self.n_classes = None if n_classes is None else int(n_classes)

        # default config (can be overwritten by kwargs below)
        if self.backbone == 'unet':
            self.unet_n_depth = 3
            self.unet_kernel_size = 3,3
            self.unet_n_filter_base = 32
            self.unet_n_conv_per_depth = 2
            self.unet_pool = 2,2
            self.unet_activation = 'relu'
            self.unet_last_activation = 'relu'
            self.unet_batch_norm = False
            self.unet_dropout = 0.0
            self.unet_prefix = ''
            self.net_conv_after_unet = 128
        else:
            # TODO: resnet backbone for 2D model?
            raise ValueError("backbone '%s' not supported." % self.backbone)

        # net_mask_shape not needed but kept for legacy reasons
        if backend_channels_last():
            self.net_input_shape = None,None,self.n_channel_in
            self.net_mask_shape = None,None,1
        else:
            self.net_input_shape = self.n_channel_in,None,None
            self.net_mask_shape = 1,None,None

        self.train_shape_completion = False
        self.train_completion_crop = 32
        self.train_patch_size = 256,256
        self.train_background_reg = 1e-4
        self.train_foreground_only = 0.9
        self.train_sample_cache = True

        self.train_dist_loss = 'mae'
        self.train_loss_weights = (1,0.2) if self.n_classes is None else (1,0.2,1)
        self.train_class_weights = (1,1) if self.n_classes is None else (1,)*(self.n_classes+1)
        self.train_epochs = 400
        self.train_steps_per_epoch = 100
        self.train_learning_rate = 0.0003
        self.train_batch_size = 4
        self.train_n_val_patches = None
        self.train_tensorboard = True
        # the parameter 'min_delta' was called 'epsilon' for keras<=2.1.5
        min_delta_key = 'epsilon' if LooseVersion(keras.__version__)<=LooseVersion('2.1.5') else 'min_delta'
        self.train_reduce_lr = {'factor': 0.5, 'patience': 40, min_delta_key: 0}

        self.use_gpu = False

        # remove derived attributes that shouldn't be overwritten
        for k in ('n_dim', 'n_channel_out'):
            try: del kwargs[k]
            except KeyError: pass

        self.update_parameters(False, **kwargs)

        # FIXME: put into is_valid()
        if not len(self.train_loss_weights) == (2 if self.n_classes is None else 3):
            raise ValueError(f"train_loss_weights {self.train_loss_weights} not compatible with n_classes ({self.n_classes}): must be 3 weights if n_classes is not None, otherwise 2")

        if not len(self.train_class_weights) == (2 if self.n_classes is None else self.n_classes+1):
            raise ValueError(f"train_class_weights {self.train_class_weights} not compatible with n_classes ({self.n_classes}): must be 'n_classes + 1' weights if n_classes is not None, otherwise 2")



class StarDist2D(StarDistBase):
    """StarDist2D model.

    Parameters
    ----------
    config : :class:`Config` or None
        Will be saved to disk as JSON (``config.json``).
        If set to ``None``, will be loaded from disk (must exist).
    name : str or None
        Model name. Uses a timestamp if set to ``None`` (default).
    basedir : str
        Directory that contains (or will contain) a folder with the given model name.

    Raises
    ------
    FileNotFoundError
        If ``config=None`` and config cannot be loaded from disk.
    ValueError
        Illegal arguments, including invalid configuration.

    Attributes
    ----------
    config : :class:`Config`
        Configuration, as provided during instantiation.
    keras_model : `Keras model <https://keras.io/getting-started/functional-api-guide/>`_
        Keras neural network model.
    name : str
        Model name.
    logdir : :class:`pathlib.Path`
        Path to model folder (which stores configuration, weights, etc.)
    """

    def __init__(self, config=Config2D(), name=None, basedir='.'):
        """See class docstring."""
        super().__init__(config, name=name, basedir=basedir)


    def _build(self):
        self.config.backbone == 'unet' or _raise(NotImplementedError())
        unet_kwargs = {k[len('unet_'):]:v for (k,v) in vars(self.config).items() if k.startswith('unet_')}

        input_img = Input(self.config.net_input_shape, name='input')

        # maxpool input image to grid size
        pooled = np.array([1,1])
        pooled_img = input_img
        while tuple(pooled) != tuple(self.config.grid):
            pool = 1 + (np.asarray(self.config.grid) > pooled)
            pooled *= pool
            for _ in range(self.config.unet_n_conv_per_depth):
                pooled_img = Conv2D(self.config.unet_n_filter_base, self.config.unet_kernel_size,
                                    padding='same', activation=self.config.unet_activation)(pooled_img)
            pooled_img = MaxPooling2D(pool)(pooled_img)

        unet_base = unet_block(**unet_kwargs)(pooled_img)

        if self.config.net_conv_after_unet > 0:
            unet = Conv2D(self.config.net_conv_after_unet, self.config.unet_kernel_size,
                          name='features', padding='same', activation=self.config.unet_activation)(unet_base)
        else:
            unet = unet_base

        output_prob = Conv2D(                 1, (1,1), name='prob', padding='same', activation='sigmoid')(unet)
        output_dist = Conv2D(self.config.n_rays, (1,1), name='dist', padding='same', activation='linear')(unet)

        # attach extra classification head when self.n_classes is given
        if self._is_multiclass():
            if self.config.net_conv_after_unet > 0:
                unet_class = Conv2D(self.config.net_conv_after_unet, self.config.unet_kernel_size,
                                    name='features_class', padding='same', activation=self.config.unet_activation)(unet_base)
            else:
                unet_class = unet_base

            output_prob_class = Conv2D(self.config.n_classes+1, (1,1), name='prob_class', padding='same', activation='softmax')(unet_class)
            return Model([input_img], [output_prob,output_dist,output_prob_class])
        else:
            return Model([input_img], [output_prob,output_dist])


    def train(self, X, Y, validation_data, classes='auto', augmenter=None, seed=None, epochs=None, steps_per_epoch=None, workers=1):
        """Train the neural network with the given data.

        Parameters
        ----------
        X : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
            Input images
        Y : tuple, list, `numpy.ndarray`, `keras.utils.Sequence`
            Label masks
        classes (optional): 'auto' or iterable of same length as X
            label id -> class id mapping for each label mask of Y if multi-class prediction is activated (n_classes > 0);
            list of dicts with label id -> class id (1,...,n_classes);
            'auto' -> all objects will be assigned to the first non-background class,
                      or will be ignored if config.n_classes is None
        validation_data : tuple(:class:`numpy.ndarray`, :class:`numpy.ndarray`) or triple (if multiclass)
            Tuple (triple if multiclass) of X,Y,[classes] validation data.
        augmenter : None or callable
            Function with expected signature ``xt, yt = augmenter(x, y)``
            that takes in a single pair of input/label image (x,y) and returns
            the transformed images (xt, yt) for the purpose of data augmentation
            during training. Not applied to validation images.
            Example:
            def simple_augmenter(x,y):
                x = x + 0.05*np.random.normal(0,1,x.shape)
                return x,y
        seed : int
            Convenience to set ``np.random.seed(seed)``. (To obtain reproducible validation patches, etc.)
        epochs : int
            Optional argument to use instead of the value from ``config``.
        steps_per_epoch : int
            Optional argument to use instead of the value from ``config``.

        Returns
        -------
        ``History`` object
            See `Keras training history <https://keras.io/models/model/#fit>`_.

        """
        if seed is not None:
            # https://keras.io/getting-started/faq/#how-can-i-obtain-reproducible-results-using-keras-during-development
            np.random.seed(seed)
        if epochs is None:
            epochs = self.config.train_epochs
        if steps_per_epoch is None:
            steps_per_epoch = self.config.train_steps_per_epoch

        classes = self._parse_classes_arg(classes, len(X))

        if not self._is_multiclass() and classes is not None:
            warnings.warn("Ignoring given classes as n_classes is set to None")

        isinstance(validation_data,(list,tuple)) or _raise(ValueError())
        if self._is_multiclass() and len(validation_data) == 2:
            validation_data = tuple(validation_data) + ('auto',)
        ((len(validation_data) == (3 if self._is_multiclass() else 2))
            or _raise(ValueError(f'len(validation_data) = {len(validation_data)}, but should be {3 if self._is_multiclass() else 2}')))

        patch_size = self.config.train_patch_size
        axes = self.config.axes.replace('C','')
        b = self.config.train_completion_crop if self.config.train_shape_completion else 0
        div_by = self._axes_div_by(axes)
        [(p-2*b) % d == 0 or _raise(ValueError(
            "'train_patch_size' - 2*'train_completion_crop' must be divisible by {d} along axis '{a}'".format(a=a,d=d) if self.config.train_shape_completion else
            "'train_patch_size' must be divisible by {d} along axis '{a}'".format(a=a,d=d)
         )) for p,d,a in zip(patch_size,div_by,axes)]

        if not self._model_prepared:
            self.prepare_for_training()

        data_kwargs = dict(
            n_rays           = self.config.n_rays,
            patch_size       = self.config.train_patch_size,
            grid             = self.config.grid,
            shape_completion = self.config.train_shape_completion,
            b                = self.config.train_completion_crop,
            use_gpu          = self.config.use_gpu,
            foreground_prob  = self.config.train_foreground_only,
            n_classes        = self.config.n_classes,
            sample_ind_cache = self.config.train_sample_cache,
        )

        # generate validation data and store in numpy arrays
        n_data_val = len(validation_data[0])
        classes_val = self._parse_classes_arg(validation_data[2], n_data_val) if self._is_multiclass() else None
        n_take = self.config.train_n_val_patches if self.config.train_n_val_patches is not None else n_data_val
        _data_val = StarDistData2D(validation_data[0], validation_data[1], classes=classes_val, batch_size=n_take, length=1, **data_kwargs)
        data_val = _data_val[0]

        # expose data generator as member for general diagnostics
        self.data_train = StarDistData2D(X, Y, classes=classes, batch_size=self.config.train_batch_size,
                                         augmenter=augmenter, length=epochs*steps_per_epoch, **data_kwargs)

        if self.config.train_tensorboard:
            # show dist for three rays
            _n = min(3, self.config.n_rays)
            channel = axes_dict(self.config.axes)['C']
            output_slices = [[slice(None)]*4,[slice(None)]*4]
            output_slices[1][1+channel] = slice(0,(self.config.n_rays//_n)*_n, self.config.n_rays//_n)
            if self._is_multiclass():
                _n = min(3, self.config.n_classes)
                output_slices += [[slice(None)]*4]
                output_slices[2][1+channel] = slice(1,1+(self.config.n_classes//_n)*_n, self.config.n_classes//_n)

            if IS_TF_1:
                for cb in self.callbacks:
                    if isinstance(cb,CARETensorBoard):
                        cb.output_slices = output_slices
                        # target image for dist includes dist_mask and thus has more channels than dist output
                        cb.output_target_shapes = [None,[None]*4,None]
                        cb.output_target_shapes[1][1+channel] = data_val[1][1].shape[1+channel]
            elif self.basedir is not None and not any(isinstance(cb,CARETensorBoardImage) for cb in self.callbacks):
                self.callbacks.append(CARETensorBoardImage(model=self.keras_model, data=data_val, log_dir=str(self.logdir/'logs'/'images'),
                                                           n_images=3, prob_out=False, output_slices=output_slices))

        fit = self.keras_model.fit_generator if IS_TF_1 else self.keras_model.fit
        history = fit(iter(self.data_train), validation_data=data_val,
                      epochs=epochs, steps_per_epoch=steps_per_epoch,
                      workers=workers, use_multiprocessing=workers>1,
                      callbacks=self.callbacks, verbose=1,
                      # set validation batchsize to training batchsize (only works for tf >= 2.2)
                      **(dict(validation_batch_size = self.config.train_batch_size) if _tf_version_at_least("2.2.0") else {}))
        self._training_finished()

        return history


    # def _instances_from_prediction_old(self, img_shape, prob, dist, points=None, prob_class=None, prob_thresh=None, nms_thresh=None, overlap_label=None, **nms_kwargs):
    #     from stardist.geometry.geom2d import _polygons_to_label_old, _dist_to_coord_old
    #     from stardist.nms import _non_maximum_suppression_old

    #     if prob_thresh is None: prob_thresh = self.thresholds.prob
    #     if nms_thresh  is None: nms_thresh  = self.thresholds.nms
    #     if overlap_label is not None: raise NotImplementedError("overlap_label not supported for 2D yet!")

    #     coord = _dist_to_coord_old(dist, grid=self.config.grid)
    #     inds = _non_maximum_suppression_old(coord, prob, grid=self.config.grid,
    #                                         prob_thresh=prob_thresh, nms_thresh=nms_thresh, **nms_kwargs)
    #     labels = _polygons_to_label_old(coord, prob, inds, shape=img_shape)
    #     # sort 'inds' such that ids in 'labels' map to entries in polygon dictionary entries
    #     inds = inds[np.argsort(prob[inds[:,0],inds[:,1]])]
    #     # adjust for grid
    #     points = inds*np.array(self.config.grid)

    #     res_dict = dict(coord=coord[inds[:,0],inds[:,1]], points=points, prob=prob[inds[:,0],inds[:,1]])

    #     if prob_class is not None:
    #         prob_class = np.asarray(prob_class)
    #         res_dict.update(dict(class_prob=prob_class))

    #     return labels, res_dict


    def _instances_from_prediction(self, img_shape, prob, dist, points=None, prob_class=None, prob_thresh=None, nms_thresh=None, overlap_label=None, return_labels=True, scale=None, **nms_kwargs):
        """
        if points is None     -> dense prediction
        if points is not None -> sparse prediction

        if prob_class is None     -> single class prediction
        if prob_class is not None -> multi class prediction
        """
        if prob_thresh is None: prob_thresh = self.thresholds.prob
        if nms_thresh  is None: nms_thresh  = self.thresholds.nms
        if overlap_label is not None: raise NotImplementedError("overlap_label not supported for 2D yet!")

        # sparse prediction
        if points is not None:
            points, probi, disti, indsi = non_maximum_suppression_sparse(dist, prob, points, nms_thresh=nms_thresh, **nms_kwargs)
            if prob_class is not None:
                prob_class = prob_class[indsi]

        # dense prediction
        else:
            points, probi, disti = non_maximum_suppression(dist, prob, grid=self.config.grid,
                                                           prob_thresh=prob_thresh, nms_thresh=nms_thresh, **nms_kwargs)
            if prob_class is not None:
                inds = tuple(p//g for p,g in zip(points.T, self.config.grid))
                prob_class = prob_class[inds]

        if scale is not None:
            # need to undo the scaling given by the scale dict, e.g. scale = dict(X=0.5,Y=0.5):
            # 1. re-scale points (origins of polygons)
            # 2. re-scale coordinates (computed from distances) of (zero-origin) polygons
            if not (isinstance(scale,dict) and 'X' in scale and 'Y' in scale):
                raise ValueError("scale must be a dictionary with entries for 'X' and 'Y'")
            rescale = (1/scale['Y'],1/scale['X'])
            points = points * np.array(rescale).reshape(1,2)
        else:
            rescale = (1,1)

        if return_labels:
            labels = polygons_to_label(disti, points, prob=probi, shape=img_shape, scale_dist=rescale)
        else:
            labels = None

        coord = dist_to_coord(disti, points, scale_dist=rescale)
        res_dict = dict(coord=coord, points=points, prob=probi)

        # multi class prediction
        if prob_class is not None:
            prob_class = np.asarray(prob_class)
            class_id = np.argmax(prob_class, axis=-1)
            res_dict.update(dict(class_prob=prob_class, class_id=class_id))

        return labels, res_dict


    def _axes_div_by(self, query_axes):
        self.config.backbone == 'unet' or _raise(NotImplementedError())
        query_axes = axes_check_and_normalize(query_axes)
        assert len(self.config.unet_pool) == len(self.config.grid)
        div_by = dict(zip(
            self.config.axes.replace('C',''),
            tuple(p**self.config.unet_n_depth * g for p,g in zip(self.config.unet_pool,self.config.grid))
        ))
        return tuple(div_by.get(a,1) for a in query_axes)


    # def _axes_tile_overlap(self, query_axes):
    #     self.config.backbone == 'unet' or _raise(NotImplementedError())
    #     query_axes = axes_check_and_normalize(query_axes)
    #     assert len(self.config.unet_pool) == len(self.config.grid) == len(self.config.unet_kernel_size)
    #     # TODO: compute this properly when any value of grid > 1
    #     # all(g==1 for g in self.config.grid) or warnings.warn('FIXME')
    #     overlap = dict(zip(
    #         self.config.axes.replace('C',''),
    #         tuple(tile_overlap(self.config.unet_n_depth + int(np.log2(g)), k, p)
    #               for p,k,g in zip(self.config.unet_pool,self.config.unet_kernel_size,self.config.grid))
    #     ))
    #     return tuple(overlap.get(a,0) for a in query_axes)


    @property
    def _config_class(self):
        return Config2D
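Putting `Config2D` and `StarDist2D` together, here is a minimal training sketch. It is illustrative only: the random arrays stand in for real image/label pairs, and it assumes TensorFlow/CSBDeep are installed and that the usual StarDist base-class methods (`optimize_thresholds`, `predict_instances`) are available as in upstream StarDist:

    import numpy as np
    from stardist_pkg.models.model2d import Config2D, StarDist2D

    # toy data: replace the random arrays with real images and instance label masks
    rng = np.random.default_rng(0)
    X  = [rng.random((256, 256)).astype(np.float32) for _ in range(8)]
    Y  = [rng.integers(0, 5, (256, 256)).astype(np.uint16) for _ in range(8)]
    Xv = [rng.random((256, 256)).astype(np.float32) for _ in range(2)]
    Yv = [rng.integers(0, 5, (256, 256)).astype(np.uint16) for _ in range(2)]

    conf  = Config2D(axes='YX', n_rays=32, grid=(2, 2), n_channel_in=1)
    model = StarDist2D(conf, name='demo', basedir='models')

    model.train(X, Y, validation_data=(Xv, Yv), epochs=2, steps_per_epoch=10)
    model.optimize_thresholds(Xv, Yv)            # stores prob/nms thresholds
    labels, details = model.predict_instances(X[0])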
stardist_pkg/nms.py
ADDED
@@ -0,0 +1,387 @@
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
from time import time
from .utils import _normalize_grid

def _ind_prob_thresh(prob, prob_thresh, b=2):
    if b is not None and np.isscalar(b):
        b = ((b,b),)*prob.ndim

    ind_thresh = prob > prob_thresh
    if b is not None:
        _ind_thresh = np.zeros_like(ind_thresh)
        ss = tuple(slice(_bs[0] if _bs[0]>0 else None,
                         -_bs[1] if _bs[1]>0 else None) for _bs in b)
        _ind_thresh[ss] = True
        ind_thresh &= _ind_thresh
    return ind_thresh


def _non_maximum_suppression_old(coord, prob, grid=(1,1), b=2, nms_thresh=0.5, prob_thresh=0.5, verbose=False, max_bbox_search=True):
    """2D coordinates of the polys that survive from a given prediction (prob, coord)

    prob.shape = (Ny,Nx)
    coord.shape = (Ny,Nx,2,n_rays)

    b: don't use pixels closer than b pixels to the image boundary

    returns retained points
    """
    from .lib.stardist2d import c_non_max_suppression_inds_old

    # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary

    assert prob.ndim == 2
    assert coord.ndim == 4
    grid = _normalize_grid(grid,2)

    # mask = prob > prob_thresh
    # if b is not None and b > 0:
    #     _mask = np.zeros_like(mask)
    #     _mask[b:-b,b:-b] = True
    #     mask &= _mask

    mask = _ind_prob_thresh(prob, prob_thresh, b)

    polygons = coord[mask]
    scores   = prob[mask]

    # sort scores descendingly
    ind = np.argsort(scores)[::-1]
    survivors = np.zeros(len(ind), bool)
    polygons = polygons[ind]
    scores = scores[ind]

    if max_bbox_search:
        # map pixel indices to ids of sorted polygons (-1 => polygon at that pixel not a candidate)
        mapping = -np.ones(mask.shape,np.int32)
        mapping.flat[ np.flatnonzero(mask)[ind] ] = range(len(ind))
    else:
        mapping = np.empty((0,0),np.int32)

    if verbose:
        t = time()

    survivors[ind] = c_non_max_suppression_inds_old(np.ascontiguousarray(polygons.astype(np.int32)),
                                                    mapping, np.float32(nms_thresh), np.int32(max_bbox_search),
                                                    np.int32(grid[0]), np.int32(grid[1]), np.int32(verbose))

    if verbose:
        print("keeping %s/%s polygons" % (np.count_nonzero(survivors), len(polygons)))
        print("NMS took %.4f s" % (time() - t))

    points = np.stack([ii[survivors] for ii in np.nonzero(mask)], axis=-1)
    return points


def non_maximum_suppression(dist, prob, grid=(1,1), b=2, nms_thresh=0.5, prob_thresh=0.5,
                            use_bbox=True, use_kdtree=True, verbose=False, cut=False):
    """Non-Maximum-Suppression of 2D polygons

    Retains only polygons whose overlap is smaller than nms_thresh

    dist.shape = (Ny,Nx, n_rays)
    prob.shape = (Ny,Nx)

    returns the retained points, probabilities, and distances:

    points, prob, dist = non_maximum_suppression(dist, prob, ....

    """

    # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary

    assert prob.ndim == 2 and dist.ndim == 3 and prob.shape == dist.shape[:2]
    dist = np.asarray(dist)
    prob = np.asarray(prob)
    n_rays = dist.shape[-1]

    grid = _normalize_grid(grid,2)

    # mask = prob > prob_thresh
    # if b is not None and b > 0:
    #     _mask = np.zeros_like(mask)
    #     _mask[b:-b,b:-b] = True
    #     mask &= _mask

    mask = _ind_prob_thresh(prob, prob_thresh, b)
    points = np.stack(np.where(mask), axis=1)

    dist   = dist[mask]
    scores = prob[mask]

    # sort scores descendingly
    ind = np.argsort(scores)[::-1]
    if cut is True and ind.shape[0] > 20000:
        ind = ind[:round(ind.shape[0]*0.5)]
    dist   = dist[ind]
    scores = scores[ind]
    points = points[ind]

    points = (points * np.array(grid).reshape((1,2)))

    if verbose:
        t = time()

    inds = non_maximum_suppression_inds(dist, points.astype(np.int32, copy=False), scores=scores,
                                        use_bbox=use_bbox, use_kdtree=use_kdtree,
                                        thresh=nms_thresh, verbose=verbose)

    if verbose:
        print("keeping %s/%s polygons" % (np.count_nonzero(inds), len(inds)))
        print("NMS took %.4f s" % (time() - t))

    return points[inds], scores[inds], dist[inds]


def non_maximum_suppression_sparse(dist, prob, points, b=2, nms_thresh=0.5,
                                   use_bbox=True, use_kdtree=True, verbose=False):
    """Non-Maximum-Suppression of 2D polygons from a list of dists, probs (scores), and points

    Retains only polygons whose overlap is smaller than nms_thresh

    dist.shape = (n_polys, n_rays)
    prob.shape = (n_polys,)
    points.shape = (n_polys,2)

    returns the retained instances

    (pointsi, probi, disti, indsi)

    with
        pointsi = points[indsi] ...

    """

    # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary

    dist = np.asarray(dist)
    prob = np.asarray(prob)
    points = np.asarray(points)
    n_rays = dist.shape[-1]

    assert dist.ndim == 2 and prob.ndim == 1 and points.ndim == 2 and \
        points.shape[-1]==2 and len(prob) == len(dist) == len(points)

    verbose and print("predicting instances with nms_thresh = {nms_thresh}".format(nms_thresh=nms_thresh), flush=True)

    inds_original = np.arange(len(prob))
    _sorted = np.argsort(prob)[::-1]
    probi = prob[_sorted]
    disti = dist[_sorted]
    pointsi = points[_sorted]
    inds_original = inds_original[_sorted]

    if verbose:
        print("non-maximum suppression...")
        t = time()

    inds = non_maximum_suppression_inds(disti, pointsi, scores=probi, thresh=nms_thresh, use_kdtree=use_kdtree, verbose=verbose)

    if verbose:
        print("keeping %s/%s polygons" % (np.count_nonzero(inds), len(inds)))
        print("NMS took %.4f s" % (time() - t))

    return pointsi[inds], probi[inds], disti[inds], inds_original[inds]


def non_maximum_suppression_inds(dist, points, scores, thresh=0.5, use_bbox=True, use_kdtree=True, verbose=1):
    """
    Applies non-maximum suppression to ray-convex polygons given by dists and points
    sorted by scores and IoU threshold

    P1 will suppress P2, if IoU(P1,P2) > thresh

    with IoU(P1,P2) = Ainter(P1,P2) / min(A(P1),A(P2))

    i.e. the smaller thresh, the more polygons will be suppressed

    dist.shape = (n_poly, n_rays)
    point.shape = (n_poly, 2)
    score.shape = (n_poly,)

    returns indices of selected polygons
    """

    from stardist.lib.stardist2d import c_non_max_suppression_inds

    assert dist.ndim == 2
    assert points.ndim == 2

    n_poly = dist.shape[0]

    if scores is None:
        scores = np.ones(n_poly)

    assert len(scores) == n_poly
    assert points.shape[0] == n_poly

    def _prep(x, dtype):
        return np.ascontiguousarray(x.astype(dtype, copy=False))

    inds = c_non_max_suppression_inds(_prep(dist, np.float32),
                                      _prep(points, np.float32),
                                      int(use_kdtree),
                                      int(use_bbox),
                                      int(verbose),
                                      np.float32(thresh))

    return inds


#########


def non_maximum_suppression_3d(dist, prob, rays, grid=(1,1,1), b=2, nms_thresh=0.5, prob_thresh=0.5, use_bbox=True, use_kdtree=True, verbose=False):
    """Non-Maximum-Suppression of 3D polyhedra

    Retains only polyhedra whose overlap is smaller than nms_thresh

    dist.shape = (Nz,Ny,Nx, n_rays)
    prob.shape = (Nz,Ny,Nx)

    returns the retained points, probabilities, and distances:

    points, prob, dist = non_maximum_suppression_3d(dist, prob, ....
    """

    # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary

    dist = np.asarray(dist)
    prob = np.asarray(prob)

    assert prob.ndim == 3 and dist.ndim == 4 and dist.shape[-1] == len(rays) and prob.shape == dist.shape[:3]

    grid = _normalize_grid(grid,3)

    verbose and print("predicting instances with prob_thresh = {prob_thresh} and nms_thresh = {nms_thresh}".format(prob_thresh=prob_thresh, nms_thresh=nms_thresh), flush=True)

    # ind_thresh = prob > prob_thresh
    # if b is not None and b > 0:
    #     _ind_thresh = np.zeros_like(ind_thresh)
    #     _ind_thresh[b:-b,b:-b,b:-b] = True
    #     ind_thresh &= _ind_thresh

    ind_thresh = _ind_prob_thresh(prob, prob_thresh, b)
    points = np.stack(np.where(ind_thresh), axis=1)
    verbose and print("found %s candidates" % len(points))
    probi = prob[ind_thresh]
    disti = dist[ind_thresh]

    _sorted = np.argsort(probi)[::-1]
    probi = probi[_sorted]
    disti = disti[_sorted]
    points = points[_sorted]

    verbose and print("non-maximum suppression...")
    points = (points * np.array(grid).reshape((1,3)))

    inds = non_maximum_suppression_3d_inds(disti, points, rays=rays, scores=probi, thresh=nms_thresh,
                                           use_bbox=use_bbox, use_kdtree=use_kdtree,
                                           verbose=verbose)

    verbose and print("keeping %s/%s polyhedra" % (np.count_nonzero(inds), len(inds)))
    return points[inds], probi[inds], disti[inds]


def non_maximum_suppression_3d_sparse(dist, prob, points, rays, b=2, nms_thresh=0.5, use_kdtree=True, verbose=False):
    """Non-Maximum-Suppression of 3D polyhedra from a list of dists, probs and points

    Retains only polyhedra whose overlap is smaller than nms_thresh

    dist.shape = (n_polys, n_rays)
    prob.shape = (n_polys,)
    points.shape = (n_polys,3)

    returns the retained instances

    (pointsi, probi, disti, indsi)

    with
        pointsi = points[indsi] ...
    """

    # TODO: using b>0 with grid>1 can suppress small/cropped objects at the image boundary

    dist = np.asarray(dist)
    prob = np.asarray(prob)
    points = np.asarray(points)

    assert dist.ndim == 2 and prob.ndim == 1 and points.ndim == 2 and \
        dist.shape[-1] == len(rays) and points.shape[-1]==3 and len(prob) == len(dist) == len(points)

    verbose and print("predicting instances with nms_thresh = {nms_thresh}".format(nms_thresh=nms_thresh), flush=True)

    inds_original = np.arange(len(prob))
    _sorted = np.argsort(prob)[::-1]
    probi = prob[_sorted]
    disti = dist[_sorted]
    pointsi = points[_sorted]
    inds_original = inds_original[_sorted]

    verbose and print("non-maximum suppression...")

    inds = non_maximum_suppression_3d_inds(disti, pointsi, rays=rays, scores=probi, thresh=nms_thresh, use_kdtree=use_kdtree, verbose=verbose)

    verbose and print("keeping %s/%s polyhedra" % (np.count_nonzero(inds), len(inds)))
    return pointsi[inds], probi[inds], disti[inds], inds_original[inds]


def non_maximum_suppression_3d_inds(dist, points, rays, scores, thresh=0.5, use_bbox=True, use_kdtree=True, verbose=1):
    """
    Applies non-maximum suppression to ray-convex polyhedra given by dists and rays
    sorted by scores and IoU threshold

    P1 will suppress P2, if IoU(P1,P2) > thresh

    with IoU(P1,P2) = Ainter(P1,P2) / min(A(P1),A(P2))

    i.e. the smaller thresh, the more polyhedra will be suppressed

    dist.shape = (n_poly, n_rays)
    point.shape = (n_poly, 3)
    score.shape = (n_poly,)

    returns indices of selected polyhedra
    """
    from .lib.stardist3d import c_non_max_suppression_inds

    assert dist.ndim == 2
    assert points.ndim == 2
    assert dist.shape[1] == len(rays)

    n_poly = dist.shape[0]

    if scores is None:
        scores = np.ones(n_poly)

    assert len(scores) == n_poly
    assert points.shape[0] == n_poly

    # sort scores descendingly
    ind = np.argsort(scores)[::-1]
    survivors = np.ones(n_poly, bool)
    dist = dist[ind]
    points = points[ind]
    scores = scores[ind]

    def _prep(x, dtype):
        return np.ascontiguousarray(x.astype(dtype, copy=False))

    if verbose:
        t = time()

    survivors[ind] = c_non_max_suppression_inds(_prep(dist, np.float32),
                                                _prep(points, np.float32),
                                                _prep(rays.vertices, np.float32),
                                                _prep(rays.faces, np.int32),
                                                _prep(scores, np.float32),
                                                int(use_bbox),
                                                int(use_kdtree),
                                                int(verbose),
                                                np.float32(thresh))

    if verbose:
        print("NMS took %.4f s" % (time() - t))

    return survivors
stardist_pkg/rays3d.py
ADDED
@@ -0,0 +1,373 @@
"""
Ray factory

classes that provide vertex and triangle information for rays on spheres

Example:

    rays = Rays_Tetra(n_level = 4)

    print(rays.vertices)
    print(rays.faces)

"""
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
from scipy.spatial import ConvexHull
import copy
import warnings

class Rays_Base(object):
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self._vertices, self._faces = self.setup_vertices_faces()
        self._vertices = np.asarray(self._vertices, np.float32)
        self._faces = np.asarray(self._faces, int)
        self._faces = np.asanyarray(self._faces)

    def setup_vertices_faces(self):
        """has to return

        verts, faces

        verts = ( (z_1,y_1,x_1), ... )
        faces = ( (0,1,2), (2,3,4), ... )

        """
        raise NotImplementedError()

    @property
    def vertices(self):
        """read-only property"""
        return self._vertices.copy()

    @property
    def faces(self):
        """read-only property"""
        return self._faces.copy()

    def __getitem__(self, i):
        return self.vertices[i]

    def __len__(self):
        return len(self._vertices)

    def __repr__(self):
        def _conv(x):
            if isinstance(x,(tuple, list, np.ndarray)):
                return "_".join(_conv(_x) for _x in x)
            if isinstance(x,float):
                return "%.2f"%x
            return str(x)
        return "%s_%s" % (self.__class__.__name__, "_".join("%s_%s" % (k, _conv(v)) for k, v in sorted(self.kwargs.items())))

    def to_json(self):
        return {
            "name": self.__class__.__name__,
            "kwargs": self.kwargs
        }

    def dist_loss_weights(self, anisotropy=(1,1,1)):
        """returns the anisotropy-corrected weights for each ray"""
        anisotropy = np.array(anisotropy)
        assert anisotropy.shape == (3,)
        return np.linalg.norm(self.vertices*anisotropy, axis=-1)

    def volume(self, dist=None):
        """volume of the star-convex polyhedron spanned by dist (if None, uses dist=1)
        dist can be a nD array, but the last dimension has to be of length n_rays
        """
        if dist is None: dist = np.ones_like(self.vertices)

        dist = np.asarray(dist)

        if not dist.shape[-1]==len(self.vertices):
            raise ValueError("last dimension of dist should have length len(rays.vertices)")
        # all the shuffling below is to allow dist to be an arbitrarily sized array (with last dim n_rays)
        # self.vertices -> (n_rays,3)
        # dist -> (m,n,..., n_rays)

        # dist  -> (m,n,..., n_rays, 3)
        dist = np.repeat(np.expand_dims(dist,-1), 3, axis=-1)
        # verts -> (m,n,..., n_rays, 3)
        verts = np.broadcast_to(self.vertices, dist.shape)

        # dist, verts -> (n_rays, m,n, ..., 3)
        dist = np.moveaxis(dist,-2,0)
        verts = np.moveaxis(verts,-2,0)

        # vs -> (n_faces, 3, m, n, ..., 3)
        vs = (dist*verts)[self.faces]
        # vs -> (n_faces, m, n, ..., 3, 3)
        vs = np.moveaxis(vs, 1,-2)
        # vs -> (n_faces * m * n, 3, 3)
        vs = vs.reshape((len(self.faces)*int(np.prod(dist.shape[1:-1])),3,3))
        d = np.linalg.det(list(vs)).reshape((len(self.faces),)+dist.shape[1:-1])

        return -1./6*np.sum(d, axis=0)

    def surface(self, dist=None):
        """surface area of the star-convex polyhedron spanned by dist (if None, uses dist=1)"""
        dist = np.asarray(dist)

        if not dist.shape[-1]==len(self.vertices):
            raise ValueError("last dimension of dist should have length len(rays.vertices)")

        # self.vertices -> (n_rays,3)
        # dist -> (m,n,..., n_rays)

        # all the shuffling below is to allow dist to be an arbitrarily sized array (with last dim n_rays)

        # dist  -> (m,n,..., n_rays, 3)
        dist = np.repeat(np.expand_dims(dist,-1), 3, axis=-1)
        # verts -> (m,n,..., n_rays, 3)
        verts = np.broadcast_to(self.vertices, dist.shape)

        # dist, verts -> (n_rays, m,n, ..., 3)
        dist = np.moveaxis(dist,-2,0)
        verts = np.moveaxis(verts,-2,0)

        # vs -> (n_faces, 3, m, n, ..., 3)
        vs = (dist*verts)[self.faces]
        # vs -> (n_faces, m, n, ..., 3, 3)
        vs = np.moveaxis(vs, 1,-2)
        # vs -> (n_faces * m * n, 3, 3)
        vs = vs.reshape((len(self.faces)*int(np.prod(dist.shape[1:-1])),3,3))

        pa = vs[...,1,:]-vs[...,0,:]
        pb = vs[...,2,:]-vs[...,0,:]

        d = .5*np.linalg.norm(np.cross(list(pa), list(pb)), axis=-1)
        d = d.reshape((len(self.faces),)+dist.shape[1:-1])
        return np.sum(d, axis=0)


    def copy(self, scale=(1,1,1)):
        """returns a copy whose vertices are scaled by the given factor"""
        scale = np.asarray(scale)
        assert scale.shape == (3,)
        res = copy.deepcopy(self)
        res._vertices *= scale[np.newaxis]
        return res




def rays_from_json(d):
    return eval(d["name"])(**d["kwargs"])


################################################################

class Rays_Explicit(Rays_Base):
    def __init__(self, vertices0, faces0):
        self.vertices0, self.faces0 = vertices0, faces0
        super().__init__(vertices0=list(vertices0), faces0=list(faces0))

    def setup_vertices_faces(self):
        return self.vertices0, self.faces0


class Rays_Cartesian(Rays_Base):
    def __init__(self, n_rays_x=11, n_rays_z=5):
        super().__init__(n_rays_x=n_rays_x, n_rays_z=n_rays_z)

    def setup_vertices_faces(self):
        """has to return list of ( (z_1,y_1,x_1), ... )"""
        n_rays_x, n_rays_z = self.kwargs["n_rays_x"], self.kwargs["n_rays_z"]
        dphi = np.float32(2. * np.pi / n_rays_x)
        dtheta = np.float32(np.pi / n_rays_z)

        verts = []
        for mz in range(n_rays_z):
            for mx in range(n_rays_x):
                phi = mx * dphi
                theta = mz * dtheta
                if mz == 0:
                    theta = 1e-12
                if mz == n_rays_z - 1:
                    theta = np.pi - 1e-12
                dx = np.cos(phi) * np.sin(theta)
                dy = np.sin(phi) * np.sin(theta)
                dz = np.cos(theta)
                if mz == 0 or mz == n_rays_z - 1:
                    dx += 1e-12
                    dy += 1e-12
                verts.append([dz, dy, dx])

        verts = np.array(verts)

        def _ind(mz, mx):
            return mz * n_rays_x + mx

        faces = []

        for mz in range(n_rays_z - 1):
            for mx in range(n_rays_x):
                faces.append([_ind(mz, mx), _ind(mz + 1, (mx + 1) % n_rays_x), _ind(mz, (mx + 1) % n_rays_x)])
                faces.append([_ind(mz, mx), _ind(mz + 1, mx), _ind(mz + 1, (mx + 1) % n_rays_x)])

        faces = np.array(faces)

        return verts, faces

|
215 |
+
class Rays_SubDivide(Rays_Base):
|
216 |
+
"""
|
217 |
+
Subdivision polyehdra
|
218 |
+
|
219 |
+
n_level = 1 -> base polyhedra
|
220 |
+
n_level = 2 -> 1x subdivision
|
221 |
+
n_level = 3 -> 2x subdivision
|
222 |
+
...
|
223 |
+
"""
|
224 |
+
|
225 |
+
def __init__(self, n_level=4):
|
226 |
+
super().__init__(n_level=n_level)
|
227 |
+
|
228 |
+
def base_polyhedron(self):
|
229 |
+
raise NotImplementedError()
|
230 |
+
|
231 |
+
def setup_vertices_faces(self):
|
232 |
+
n_level = self.kwargs["n_level"]
|
233 |
+
verts0, faces0 = self.base_polyhedron()
|
234 |
+
return self._recursive_split(verts0, faces0, n_level)
|
235 |
+
|
236 |
+
def _recursive_split(self, verts, faces, n_level):
|
237 |
+
if n_level <= 1:
|
238 |
+
return verts, faces
|
239 |
+
else:
|
240 |
+
verts, faces = Rays_SubDivide.split(verts, faces)
|
241 |
+
return self._recursive_split(verts, faces, n_level - 1)
|
242 |
+
|
243 |
+
@classmethod
|
244 |
+
def split(self, verts0, faces0):
|
245 |
+
"""split a level"""
|
246 |
+
|
247 |
+
split_edges = dict()
|
248 |
+
verts = list(verts0[:])
|
249 |
+
faces = []
|
250 |
+
|
251 |
+
def _add(a, b):
|
252 |
+
""" returns index of middle point and adds vertex if not already added"""
|
253 |
+
edge = tuple(sorted((a, b)))
|
254 |
+
if not edge in split_edges:
|
255 |
+
v = .5 * (verts[a] + verts[b])
|
256 |
+
v *= 1. / np.linalg.norm(v)
|
257 |
+
verts.append(v)
|
258 |
+
split_edges[edge] = len(verts) - 1
|
259 |
+
return split_edges[edge]
|
260 |
+
|
261 |
+
for v1, v2, v3 in faces0:
|
262 |
+
ind1 = _add(v1, v2)
|
263 |
+
ind2 = _add(v2, v3)
|
264 |
+
ind3 = _add(v3, v1)
|
265 |
+
faces.append([v1, ind1, ind3])
|
266 |
+
faces.append([v2, ind2, ind1])
|
267 |
+
faces.append([v3, ind3, ind2])
|
268 |
+
faces.append([ind1, ind2, ind3])
|
269 |
+
|
270 |
+
return verts, faces
|
271 |
+
|
272 |
+
|
273 |
+
class Rays_Tetra(Rays_SubDivide):
|
274 |
+
"""
|
275 |
+
Subdivision of a tetrahedron
|
276 |
+
|
277 |
+
n_level = 1 -> normal tetrahedron (4 vertices)
|
278 |
+
n_level = 2 -> 1x subdivision (10 vertices)
|
279 |
+
n_level = 3 -> 2x subdivision (34 vertices)
|
280 |
+
...
|
281 |
+
"""
|
282 |
+
|
283 |
+
def base_polyhedron(self):
|
284 |
+
verts = np.array([
|
285 |
+
[np.sqrt(8. / 9), 0., -1. / 3],
|
286 |
+
[-np.sqrt(2. / 9), np.sqrt(2. / 3), -1. / 3],
|
287 |
+
[-np.sqrt(2. / 9), -np.sqrt(2. / 3), -1. / 3],
|
288 |
+
[0., 0., 1.]
|
289 |
+
])
|
290 |
+
faces = [[0, 1, 2],
|
291 |
+
[0, 3, 1],
|
292 |
+
[0, 2, 3],
|
293 |
+
[1, 3, 2]]
|
294 |
+
|
295 |
+
return verts, faces
|
296 |
+
|
297 |
+
|
298 |
+
class Rays_Octo(Rays_SubDivide):
|
299 |
+
"""
|
300 |
+
Subdivision of a tetrahedron
|
301 |
+
|
302 |
+
n_level = 1 -> normal Octahedron (6 vertices)
|
303 |
+
n_level = 2 -> 1x subdivision (18 vertices)
|
304 |
+
n_level = 3 -> 2x subdivision (66 vertices)
|
305 |
+
|
306 |
+
"""
|
307 |
+
|
308 |
+
def base_polyhedron(self):
|
309 |
+
verts = np.array([
|
310 |
+
[0, 0, 1],
|
311 |
+
[0, 1, 0],
|
312 |
+
[0, 0, -1],
|
313 |
+
[0, -1, 0],
|
314 |
+
[1, 0, 0],
|
315 |
+
[-1, 0, 0]])
|
316 |
+
|
317 |
+
faces = [[0, 1, 4],
|
318 |
+
[0, 5, 1],
|
319 |
+
[1, 2, 4],
|
320 |
+
[1, 5, 2],
|
321 |
+
[2, 3, 4],
|
322 |
+
[2, 5, 3],
|
323 |
+
[3, 0, 4],
|
324 |
+
[3, 5, 0],
|
325 |
+
]
|
326 |
+
|
327 |
+
return verts, faces
|
328 |
+
|
329 |
+
|
330 |
+
def reorder_faces(verts, faces):
|
331 |
+
"""reorder faces such that their orientation points outward"""
|
332 |
+
def _single(face):
|
333 |
+
return face[::-1] if np.linalg.det(verts[face])>0 else face
|
334 |
+
return tuple(map(_single, faces))
|
335 |
+
|
336 |
+
|
337 |
+
class Rays_GoldenSpiral(Rays_Base):
|
338 |
+
def __init__(self, n=70, anisotropy = None):
|
339 |
+
if n<4:
|
340 |
+
raise ValueError("At least 4 points have to be given!")
|
341 |
+
super().__init__(n=n, anisotropy = anisotropy if anisotropy is None else tuple(anisotropy))
|
342 |
+
|
343 |
+
def setup_vertices_faces(self):
|
344 |
+
n = self.kwargs["n"]
|
345 |
+
anisotropy = self.kwargs["anisotropy"]
|
346 |
+
if anisotropy is None:
|
347 |
+
anisotropy = np.ones(3)
|
348 |
+
else:
|
349 |
+
anisotropy = np.array(anisotropy)
|
350 |
+
|
351 |
+
# the smaller golden angle = 2pi * 0.3819...
|
352 |
+
g = (3. - np.sqrt(5.)) * np.pi
|
353 |
+
phi = g * np.arange(n)
|
354 |
+
# z = np.linspace(-1, 1, n + 2)[1:-1]
|
355 |
+
# rho = np.sqrt(1. - z ** 2)
|
356 |
+
# verts = np.stack([rho*np.cos(phi), rho*np.sin(phi),z]).T
|
357 |
+
#
|
358 |
+
z = np.linspace(-1, 1, n)
|
359 |
+
rho = np.sqrt(1. - z ** 2)
|
360 |
+
verts = np.stack([z, rho * np.sin(phi), rho * np.cos(phi)]).T
|
361 |
+
|
362 |
+
# warnings.warn("ray definition has changed! Old results are invalid!")
|
363 |
+
|
364 |
+
# correct for anisotropy
|
365 |
+
verts = verts/anisotropy
|
366 |
+
#verts /= np.linalg.norm(verts, axis=-1, keepdims=True)
|
367 |
+
|
368 |
+
hull = ConvexHull(verts)
|
369 |
+
faces = reorder_faces(verts,hull.simplices)
|
370 |
+
|
371 |
+
verts /= np.linalg.norm(verts, axis=-1, keepdims=True)
|
372 |
+
|
373 |
+
return verts, faces
|
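
For orientation, a minimal usage sketch of the ray classes above (illustrative only, not part of the upload; it assumes the package is importable as `stardist_pkg.rays3d` and that `scipy` is installed for the `ConvexHull` call):

```python
import numpy as np
from stardist_pkg.rays3d import Rays_GoldenSpiral, rays_from_json

rays = Rays_GoldenSpiral(n=96)           # 96 unit rays on a golden-spiral sphere
dist = np.ones(len(rays))                # unit distance per ray ~ unit sphere
print(rays.volume(dist))                 # close to 4/3*pi (~4.19) for large n
print(rays.surface(dist))                # close to 4*pi (~12.57) for large n

rays2 = rays_from_json(rays.to_json())   # round-trip the ray configuration
assert repr(rays) == repr(rays2)
```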
stardist_pkg/sample_patches.py
ADDED
@@ -0,0 +1,65 @@
"""provides a faster sampling function"""

import numpy as np
from csbdeep.utils import _raise, choice


def sample_patches(datas, patch_size, n_samples, valid_inds=None, verbose=False):
    """optimized version of csbdeep.data.sample_patches_from_multiple_stacks
    """

    len(patch_size) == datas[0].ndim or _raise(ValueError())

    if not all((a.shape == datas[0].shape for a in datas)):
        raise ValueError("all input shapes must be the same: %s" % (" / ".join(str(a.shape) for a in datas)))

    if not all((0 < s <= d for s, d in zip(patch_size, datas[0].shape))):
        raise ValueError("patch_size %s negative or larger than data shape %s along some dimensions" % (str(patch_size), str(datas[0].shape)))

    if valid_inds is None:
        valid_inds = tuple(_s.ravel() for _s in np.meshgrid(*tuple(np.arange(p // 2, s - p // 2 + 1) for s, p in zip(datas[0].shape, patch_size))))

    n_valid = len(valid_inds[0])

    if n_valid == 0:
        raise ValueError("no regions to sample from!")

    idx = choice(range(n_valid), n_samples, replace=(n_valid < n_samples))
    rand_inds = [v[idx] for v in valid_inds]
    res = [np.stack([data[tuple(slice(_r - (_p // 2), _r + _p - (_p // 2)) for _r, _p in zip(r, patch_size))] for r in zip(*rand_inds)]) for data in datas]

    return res


def get_valid_inds(img, patch_size, patch_filter=None):
    """
    Returns all indices of an image that
    - can be used as center points for sampling patches of a given patch_size, and
    - are part of the boolean mask given by the function patch_filter (if provided)

    img: np.ndarray
    patch_size: tuple of ints
        the width of patches per img dimension,
    patch_filter: None or callable
        a function with signature patch_filter(img, patch_size) returning a boolean mask
    """

    len(patch_size) == img.ndim or _raise(ValueError())

    if not all((0 < s <= d for s, d in zip(patch_size, img.shape))):
        raise ValueError("patch_size %s negative or larger than image shape %s along some dimensions" % (str(patch_size), str(img.shape)))

    if patch_filter is None:
        # only cut border indices (which is faster)
        patch_mask = np.ones(img.shape, dtype=bool)
        valid_inds = tuple(np.arange(p // 2, s - p + p // 2 + 1).astype(np.uint32) for p, s in zip(patch_size, img.shape))
        valid_inds = tuple(s.ravel() for s in np.meshgrid(*valid_inds, indexing='ij'))
    else:
        patch_mask = patch_filter(img, patch_size)

        # get the valid indices
        border_slices = tuple([slice(p // 2, s - p + p // 2 + 1) for p, s in zip(patch_size, img.shape)])
        valid_inds = np.where(patch_mask[border_slices])
        valid_inds = tuple((v + s.start).astype(np.uint32) for s, v in zip(border_slices, valid_inds))

    return valid_inds
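
A quick sketch of how the two functions above combine (illustrative; `csbdeep` must be installed since `_raise` and `choice` come from there, and the toy image/labels below are hypothetical):

```python
import numpy as np
from stardist_pkg.sample_patches import get_valid_inds, sample_patches

img = np.random.rand(256, 256).astype(np.float32)
lbl = (img > 0.95).astype(np.uint16)     # hypothetical sparse foreground mask
patch_size = (64, 64)

# restrict patch centers to foreground pixels away from the border
inds = get_valid_inds(lbl, patch_size, patch_filter=lambda y, p: y > 0)
img_p, lbl_p = sample_patches((img, lbl), patch_size, n_samples=8, valid_inds=inds)
print(img_p.shape, lbl_p.shape)          # (8, 64, 64) (8, 64, 64)
```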
stardist_pkg/utils.py
ADDED
@@ -0,0 +1,394 @@
from __future__ import print_function, unicode_literals, absolute_import, division

import numpy as np
import warnings
import os
import datetime
from tqdm import tqdm
from collections import defaultdict
from zipfile import ZipFile, ZIP_DEFLATED
from scipy.ndimage.morphology import distance_transform_edt, binary_fill_holes
from scipy.ndimage.measurements import find_objects
from scipy.optimize import minimize_scalar
from skimage.measure import regionprops
from csbdeep.utils import _raise
from csbdeep.utils.six import Path
from collections.abc import Iterable

from .matching import matching_dataset, _check_label_array


try:
    from edt import edt
    _edt_available = True
    try:    _edt_parallel_max = len(os.sched_getaffinity(0))
    except: _edt_parallel_max = 128
    _edt_parallel_default = 4
    _edt_parallel = os.environ.get('STARDIST_EDT_NUM_THREADS', _edt_parallel_default)
    try:
        _edt_parallel = min(_edt_parallel_max, int(_edt_parallel))
    except ValueError as e:
        warnings.warn(f"Invalid value ({_edt_parallel}) for STARDIST_EDT_NUM_THREADS. Using default value ({_edt_parallel_default}) instead.")
        _edt_parallel = _edt_parallel_default
    del _edt_parallel_default, _edt_parallel_max
except ImportError:
    _edt_available = False
    # warnings.warn("Could not find package edt... \nConsider installing it with \n  pip install edt\nto improve training data generation performance.")
    pass


def gputools_available():
    try:
        import gputools
    except:
        return False
    return True


def path_absolute(path_relative):
    """ Get absolute path to resource"""
    base_path = os.path.abspath(os.path.dirname(__file__))
    return os.path.join(base_path, path_relative)


def _is_power_of_2(i):
    assert i > 0
    e = np.log2(i)
    return e == int(e)


def _normalize_grid(grid, n):
    try:
        grid = tuple(grid)
        (len(grid) == n and
         all(map(np.isscalar, grid)) and
         all(map(_is_power_of_2, grid))) or _raise(TypeError())
        return tuple(int(g) for g in grid)
    except (TypeError, AssertionError):
        raise ValueError("grid = {grid} must be a list/tuple of length {n} with values that are power of 2".format(grid=grid, n=n))


def edt_prob(lbl_img, anisotropy=None):
    if _edt_available:
        return _edt_prob_edt(lbl_img, anisotropy=anisotropy)
    else:
        # warnings.warn("Could not find package edt... \nConsider installing it with \n  pip install edt\nto improve training data generation performance.")
        return _edt_prob_scipy(lbl_img, anisotropy=anisotropy)

def _edt_prob_edt(lbl_img, anisotropy=None):
    """Perform EDT on each labeled object and normalize.
    Internally uses https://github.com/seung-lab/euclidean-distance-transform-3d
    that can handle multiple labels at once
    """
    lbl_img = np.ascontiguousarray(lbl_img)
    constant_img = lbl_img.min() == lbl_img.max() and lbl_img.flat[0] > 0
    if constant_img:
        warnings.warn("EDT of constant label image is ill-defined. (Assuming background around it.)")
    # we just need to compute the edt once but then normalize it for each object
    prob = edt(lbl_img, anisotropy=anisotropy, black_border=constant_img, parallel=_edt_parallel)
    objects = find_objects(lbl_img)
    for i, sl in enumerate(objects, 1):
        # i: object label id, sl: slices of object in lbl_img
        if sl is None: continue
        _mask = lbl_img[sl] == i
        # normalize it
        prob[sl][_mask] /= np.max(prob[sl][_mask] + 1e-10)
    return prob

def _edt_prob_scipy(lbl_img, anisotropy=None):
    """Perform EDT on each labeled object and normalize."""
    def grow(sl, interior):
        return tuple(slice(s.start - int(w[0]), s.stop + int(w[1])) for s, w in zip(sl, interior))
    def shrink(interior):
        return tuple(slice(int(w[0]), (-1 if w[1] else None)) for w in interior)
    constant_img = lbl_img.min() == lbl_img.max() and lbl_img.flat[0] > 0
    if constant_img:
        lbl_img = np.pad(lbl_img, ((1, 1),) * lbl_img.ndim, mode='constant')
        warnings.warn("EDT of constant label image is ill-defined. (Assuming background around it.)")
    objects = find_objects(lbl_img)
    prob = np.zeros(lbl_img.shape, np.float32)
    for i, sl in enumerate(objects, 1):
        # i: object label id, sl: slices of object in lbl_img
        if sl is None: continue
        interior = [(s.start > 0, s.stop < sz) for s, sz in zip(sl, lbl_img.shape)]
        # 1. grow object slice by 1 for all interior object bounding boxes
        # 2. perform (correct) EDT for object with label id i
        # 3. extract EDT for object of original slice and normalize
        # 4. store edt for object only for pixels of given label id i
        shrink_slice = shrink(interior)
        grown_mask = lbl_img[grow(sl, interior)] == i
        mask = grown_mask[shrink_slice]
        edt = distance_transform_edt(grown_mask, sampling=anisotropy)[shrink_slice][mask]
        prob[sl][mask] = edt / (np.max(edt) + 1e-10)
    if constant_img:
        prob = prob[(slice(1, -1),) * lbl_img.ndim].copy()
    return prob


def _fill_label_holes(lbl_img, **kwargs):
    lbl_img_filled = np.zeros_like(lbl_img)
    for l in (set(np.unique(lbl_img)) - set([0])):
        mask = lbl_img == l
        mask_filled = binary_fill_holes(mask, **kwargs)
        lbl_img_filled[mask_filled] = l
    return lbl_img_filled


def fill_label_holes(lbl_img, **kwargs):
    """Fill small holes in label image."""
    # TODO: refactor 'fill_label_holes' and 'edt_prob' to share code
    def grow(sl, interior):
        return tuple(slice(s.start - int(w[0]), s.stop + int(w[1])) for s, w in zip(sl, interior))
    def shrink(interior):
        return tuple(slice(int(w[0]), (-1 if w[1] else None)) for w in interior)
    objects = find_objects(lbl_img)
    lbl_img_filled = np.zeros_like(lbl_img)
    for i, sl in enumerate(objects, 1):
        if sl is None: continue
        interior = [(s.start > 0, s.stop < sz) for s, sz in zip(sl, lbl_img.shape)]
        shrink_slice = shrink(interior)
        grown_mask = lbl_img[grow(sl, interior)] == i
        mask_filled = binary_fill_holes(grown_mask, **kwargs)[shrink_slice]
        lbl_img_filled[sl][mask_filled] = i
    return lbl_img_filled
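
A small sanity-check sketch for `fill_label_holes` and `edt_prob` on a hypothetical toy label image (the scipy fallback is used automatically if the optional `edt` package is missing):

```python
import numpy as np
from stardist_pkg.utils import edt_prob, fill_label_holes

lbl = np.zeros((64, 64), np.uint16)
lbl[10:30, 10:30] = 1
lbl[18:22, 18:22] = 0          # punch a hole into object 1
lbl[40:60, 40:60] = 2

lbl = fill_label_holes(lbl)    # the hole in object 1 is closed again
prob = edt_prob(lbl)           # per-object EDT, normalized to [0, 1]
assert prob.shape == lbl.shape and prob.max() <= 1.0
```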
def sample_points(n_samples, mask, prob=None, b=2):
    """sample points to draw some of the associated polygons"""
    if b is not None and b > 0:
        # ignore image boundary, since predictions may not be reliable
        mask_b = np.zeros_like(mask)
        mask_b[b:-b, b:-b] = True
    else:
        mask_b = True

    points = np.nonzero(mask & mask_b)

    if prob is not None:
        # weighted sampling via prob
        w = prob[points[0], points[1]].astype(np.float64)
        w /= np.sum(w)
        ind = np.random.choice(len(points[0]), n_samples, replace=True, p=w)
    else:
        ind = np.random.choice(len(points[0]), n_samples, replace=True)

    points = points[0][ind], points[1][ind]
    points = np.stack(points, axis=-1)
    return points


def calculate_extents(lbl, func=np.median):
    """ Aggregate bounding box sizes of objects in label images. """
    if (isinstance(lbl, np.ndarray) and lbl.ndim == 4) or (not isinstance(lbl, np.ndarray) and isinstance(lbl, Iterable)):
        return func(np.stack([calculate_extents(_lbl, func) for _lbl in lbl], axis=0), axis=0)

    n = lbl.ndim
    n in (2, 3) or _raise(ValueError("label image should be 2- or 3-dimensional (or pass a list of these)"))

    regs = regionprops(lbl)
    if len(regs) == 0:
        return np.zeros(n)
    else:
        extents = np.array([np.array(r.bbox[n:]) - np.array(r.bbox[:n]) for r in regs])
        return func(extents, axis=0)


def polyroi_bytearray(x, y, pos=None, subpixel=True):
    """ Byte array of polygon roi with provided x and y coordinates
        See https://github.com/imagej/imagej1/blob/master/ij/io/RoiDecoder.java
    """
    import struct
    def _int16(x):
        return int(x).to_bytes(2, byteorder='big', signed=True)
    def _uint16(x):
        return int(x).to_bytes(2, byteorder='big', signed=False)
    def _int32(x):
        return int(x).to_bytes(4, byteorder='big', signed=True)
    def _float(x):
        return struct.pack(">f", x)

    subpixel = bool(subpixel)
    # add offset since pixel center is at (0.5,0.5) in ImageJ
    x_raw = np.asarray(x).ravel() + 0.5
    y_raw = np.asarray(y).ravel() + 0.5
    x = np.round(x_raw)
    y = np.round(y_raw)
    assert len(x) == len(y)
    top, left, bottom, right = y.min(), x.min(), y.max(), x.max()  # bbox

    n_coords = len(x)
    bytes_header = 64
    bytes_total = bytes_header + n_coords * 2 * 2 + subpixel * n_coords * 2 * 4
    B = [0] * bytes_total
    B[0:4] = map(ord, 'Iout')     # magic start
    B[4:6] = _int16(227)          # version
    B[6:8] = _int16(0)            # roi type (0 = polygon)
    B[8:10] = _int16(top)         # bbox top
    B[10:12] = _int16(left)       # bbox left
    B[12:14] = _int16(bottom)     # bbox bottom
    B[14:16] = _int16(right)      # bbox right
    B[16:18] = _uint16(n_coords)  # number of coordinates
    if subpixel:
        B[50:52] = _int16(128)    # subpixel resolution (option flag)
    if pos is not None:
        B[56:60] = _int32(pos)    # position (C, Z, or T)

    for i, (_x, _y) in enumerate(zip(x, y)):
        xs = bytes_header + 2 * i
        ys = xs + 2 * n_coords
        B[xs:xs + 2] = _int16(_x - left)
        B[ys:ys + 2] = _int16(_y - top)

    if subpixel:
        base1 = bytes_header + n_coords * 2 * 2
        base2 = base1 + n_coords * 4
        for i, (_x, _y) in enumerate(zip(x_raw, y_raw)):
            xs = base1 + 4 * i
            ys = base2 + 4 * i
            B[xs:xs + 4] = _float(_x)
            B[ys:ys + 4] = _float(_y)

    return bytearray(B)


def export_imagej_rois(fname, polygons, set_position=True, subpixel=True, compression=ZIP_DEFLATED):
    """ polygons assumed to be a list of arrays with shape (id,2,c) """

    if isinstance(polygons, np.ndarray):
        polygons = (polygons,)

    fname = Path(fname)
    if fname.suffix == '.zip':
        fname = fname.with_suffix('')

    with ZipFile(str(fname) + '.zip', mode='w', compression=compression) as roizip:
        for pos, polygroup in enumerate(polygons, start=1):
            for i, poly in enumerate(polygroup, start=1):
                roi = polyroi_bytearray(poly[1], poly[0], pos=(pos if set_position else None), subpixel=subpixel)
                roizip.writestr('{pos:03d}_{i:03d}.roi'.format(pos=pos, i=i), roi)


def optimize_threshold(Y, Yhat, model, nms_thresh, measure='accuracy', iou_threshs=[0.3, 0.5, 0.7], bracket=None, tol=1e-2, maxiter=20, verbose=1):
    """ Tune prob_thresh for provided (fixed) nms_thresh to maximize matching score (for given measure and averaged over iou_threshs). """
    np.isscalar(nms_thresh) or _raise(ValueError("nms_thresh must be a scalar"))
    iou_threshs = [iou_threshs] if np.isscalar(iou_threshs) else iou_threshs
    values = dict()

    if bracket is None:
        max_prob = max([np.max(prob) for prob, dist in Yhat])
        bracket = max_prob / 2, max_prob
    # print("bracket =", bracket)

    with tqdm(total=maxiter, disable=(verbose != 1), desc="NMS threshold = %g" % nms_thresh) as progress:

        def fn(thr):
            prob_thresh = np.clip(thr, *bracket)
            value = values.get(prob_thresh)
            if value is None:
                Y_instances = [model._instances_from_prediction(y.shape, *prob_dist, prob_thresh=prob_thresh, nms_thresh=nms_thresh)[0] for y, prob_dist in zip(Y, Yhat)]
                stats = matching_dataset(Y, Y_instances, thresh=iou_threshs, show_progress=False, parallel=True)
                values[prob_thresh] = value = np.mean([s._asdict()[measure] for s in stats])
            if verbose > 1:
                print("{now}   thresh: {prob_thresh:f}   {measure}: {value:f}".format(
                    now=datetime.datetime.now().strftime('%H:%M:%S'),
                    prob_thresh=prob_thresh,
                    measure=measure,
                    value=value,
                ), flush=True)
            else:
                progress.update()
                progress.set_postfix_str("{prob_thresh:.3f} -> {value:.3f}".format(prob_thresh=prob_thresh, value=value))
                progress.refresh()
            return -value

        opt = minimize_scalar(fn, method='golden', bracket=bracket, tol=tol, options={'maxiter': maxiter})

    verbose > 1 and print('\n', opt, flush=True)
    return opt.x, -opt.fun


def _invert_dict(d):
    """ return v -> [k_1,k_2,k_3....] for k,v in d"""
    res = defaultdict(list)
    for k, v in d.items():
        res[v].append(k)
    return res


def mask_to_categorical(y, n_classes, classes, return_cls_dict=False):
    """generates a multi-channel categorical class map

    Parameters
    ----------
    y : n-dimensional ndarray
        integer label array
    n_classes : int
        Number of different classes (without background)
    classes: dict, integer, or None
        the label to class assignment
        can be
        - dict {label -> class_id}
            the value of class_id can be
              0             -> background class
              1...n_classes -> the respective object class (1 ... n_classes)
              None          -> ignore object (prob is set to -1 for the pixels of the object, except for background class)
        - single integer value or None -> broadcast value to all labels

    Returns
    -------
    probability map of shape y.shape+(n_classes+1,) (first channel is background)

    """

    _check_label_array(y, 'y')
    if not (np.issubdtype(type(n_classes), np.integer) and n_classes >= 1):
        raise ValueError(f"n_classes is '{n_classes}' but should be a positive integer")

    y_labels = np.unique(y[y > 0]).tolist()

    # build dict class_id -> labels (inverse of classes)
    if np.issubdtype(type(classes), np.integer) or classes is None:
        classes = dict((k, classes) for k in y_labels)
    elif isinstance(classes, dict):
        pass
    else:
        raise ValueError("classes should be dict, single scalar, or None!")

    if not set(y_labels).issubset(set(classes.keys())):
        raise ValueError(f"all gt labels should be present in class dict provided \ngt_labels found\n{set(y_labels)}\nclass dict labels provided\n{set(classes.keys())}")

    cls_dict = _invert_dict(classes)

    # prob map
    y_mask = np.zeros(y.shape + (n_classes + 1,), np.float32)

    for cls, labels in cls_dict.items():
        if cls is None:
            # prob == -1 will be used in the loss to ignore object
            y_mask[np.isin(y, labels)] = -1
        elif np.issubdtype(type(cls), np.integer) and 0 <= cls <= n_classes:
            y_mask[..., cls] = np.isin(y, labels)
        else:
            raise ValueError(f"Wrong class id '{cls}' (for n_classes={n_classes})")

    # set 0/1 background prob (unaffected by None values for class ids)
    y_mask[..., 0] = (y == 0)

    if return_cls_dict:
        return y_mask, cls_dict
    else:
        return y_mask


def _is_floatarray(x):
    return isinstance(x.dtype.type(0), np.floating)


def abspath(root, relpath):
    from pathlib import Path
    root = Path(root)
    if root.is_dir():
        path = root / relpath
    else:
        path = root.parent / relpath
    return str(path.absolute())
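
And a sketch of `mask_to_categorical` and `calculate_extents` on the same kind of toy labels (the class assignment is hypothetical):

```python
import numpy as np
from stardist_pkg.utils import mask_to_categorical, calculate_extents

lbl = np.zeros((64, 64), np.uint16)
lbl[10:30, 10:30] = 1
lbl[40:60, 40:60] = 2

# label 1 -> class 1, label 2 -> class 2; channel 0 is background
y = mask_to_categorical(lbl, n_classes=2, classes={1: 1, 2: 2})
print(y.shape)                  # (64, 64, 3)
print(calculate_extents(lbl))   # median bbox extent per axis, here [20. 20.]
```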
stardist_pkg/version.py
ADDED
@@ -0,0 +1 @@
__version__ = '0.8.3'
utils_modify.py
ADDED
@@ -0,0 +1,743 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Any, Callable, Dict, List, Mapping, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from stardist_pkg.big import _grid_divisible, BlockND, OBJECT_KEYS  #, repaint_labels
from stardist_pkg.matching import relabel_sequential
from stardist_pkg import dist_to_coord, non_maximum_suppression, polygons_to_label
# from stardist_pkg import dist_to_coord, polygons_to_label
from stardist_pkg import star_dist, edt_prob
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.transforms import Resize
from monai.utils import (
    BlendMode,
    PytorchPadMode,
    convert_data_type,
    convert_to_dst_type,
    ensure_tuple,
    fall_back_tuple,
    look_up_option,
    optional_import,
)
import colorsys  # stdlib; used by random_colors() below (missing from the original imports)
import random    # stdlib; used by random_colors() below (missing from the original imports)
import cv2
from scipy import ndimage
from scipy.ndimage.filters import gaussian_filter
from scipy.ndimage.interpolation import affine_transform, map_coordinates
from skimage import morphology as morph
from scipy.ndimage import filters, measurements
from scipy.ndimage.morphology import (
    binary_dilation,
    binary_fill_holes,
    distance_transform_cdt,
    distance_transform_edt,
)

from skimage.segmentation import watershed
tqdm, _ = optional_import("tqdm", name="tqdm")

__all__ = ["sliding_window_inference"]


####
def normalize(mask, dtype=np.uint8):
    return (255 * mask / np.amax(mask)).astype(dtype)

def fix_mirror_padding(ann):
    """Deal with duplicated instances due to mirroring in interpolation
    during shape augmentation (scale, rotation etc.).

    """
    current_max_id = np.amax(ann)
    inst_list = list(np.unique(ann))
    if 0 in inst_list:
        inst_list.remove(0)  # 0 is background
    for inst_id in inst_list:
        inst_map = np.array(ann == inst_id, np.uint8)
        remapped_ids = measurements.label(inst_map)[0]
        remapped_ids[remapped_ids > 1] += current_max_id
        ann[remapped_ids > 1] = remapped_ids[remapped_ids > 1]
        current_max_id = np.amax(ann)
    return ann

####
def get_bounding_box(img):
    """Get bounding box coordinate information."""
    rows = np.any(img, axis=1)
    cols = np.any(img, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # due to python indexing, need to add 1 to max
    # else accessing will be 1px in the box, not out
    rmax += 1
    cmax += 1
    return [rmin, rmax, cmin, cmax]


####
def cropping_center(x, crop_shape, batch=False):
    """Crop an input image at the centre.

    Args:
        x: input array
        crop_shape: dimensions of cropped array

    Returns:
        x: cropped array

    """
    orig_shape = x.shape
    if not batch:
        h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
        w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
        x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
    else:
        h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
        w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
        x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1]]
    return x

def gen_instance_hv_map(ann, crop_shape):
    """Input annotation must be of original shape.

    The map is calculated only for instances within the crop portion
    but based on the original shape in original image.

    Perform following operation:
    Obtain the horizontal and vertical distance maps for each
    nuclear instance.

    """
    orig_ann = ann.copy()  # instance ID map
    fixed_ann = fix_mirror_padding(orig_ann)
    # re-cropping with fixed instance id map
    crop_ann = cropping_center(fixed_ann, crop_shape)
    # TODO: deal with 1 label warning
    crop_ann = morph.remove_small_objects(crop_ann, min_size=30)

    x_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)
    y_map = np.zeros(orig_ann.shape[:2], dtype=np.float32)

    inst_list = list(np.unique(crop_ann))
    if 0 in inst_list:
        inst_list.remove(0)  # 0 is background
    for inst_id in inst_list:
        inst_map = np.array(fixed_ann == inst_id, np.uint8)
        inst_box = get_bounding_box(inst_map)  # rmin, rmax, cmin, cmax

        # expand the box by 2px
        # Because we first pad the ann at line 207, the bboxes
        # will remain valid after expansion
        inst_box[0] -= 2
        inst_box[2] -= 2
        inst_box[1] += 2
        inst_box[3] += 2

        # fix inst_box
        inst_box[0] = max(inst_box[0], 0)
        inst_box[2] = max(inst_box[2], 0)
        # inst_box[1] = min(inst_box[1], fixed_ann.shape[0])
        # inst_box[3] = min(inst_box[3], fixed_ann.shape[1])

        inst_map = inst_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]

        if inst_map.shape[0] < 2 or inst_map.shape[1] < 2:
            print(f'inst_map.shape < 2: {inst_map.shape}, {inst_box}, {get_bounding_box(np.array(fixed_ann == inst_id, np.uint8))}')
            continue

        # instance center of mass, rounded to nearest pixel
        inst_com = list(measurements.center_of_mass(inst_map))
        if np.isnan(measurements.center_of_mass(inst_map)).any():
            print(inst_id, fixed_ann.shape, np.array(fixed_ann == inst_id, np.uint8).shape)
            print(get_bounding_box(np.array(fixed_ann == inst_id, np.uint8)))
            print(inst_map)
            print(inst_list)
            print(inst_box)
            print(np.count_nonzero(np.array(fixed_ann == inst_id, np.uint8)))

        inst_com[0] = int(inst_com[0] + 0.5)
        inst_com[1] = int(inst_com[1] + 0.5)

        inst_x_range = np.arange(1, inst_map.shape[1] + 1)
        inst_y_range = np.arange(1, inst_map.shape[0] + 1)
        # shifting center of pixels grid to instance center of mass
        inst_x_range -= inst_com[1]
        inst_y_range -= inst_com[0]

        inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)

        # remove coord outside of instance
        inst_x[inst_map == 0] = 0
        inst_y[inst_map == 0] = 0
        inst_x = inst_x.astype("float32")
        inst_y = inst_y.astype("float32")

        # normalize min into -1 scale
        if np.min(inst_x) < 0:
            inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
        if np.min(inst_y) < 0:
            inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
        # normalize max into +1 scale
        if np.max(inst_x) > 0:
            inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
        if np.max(inst_y) > 0:
            inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])

        ####
        x_map_box = x_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
        x_map_box[inst_map > 0] = inst_x[inst_map > 0]

        y_map_box = y_map[inst_box[0] : inst_box[1], inst_box[2] : inst_box[3]]
        y_map_box[inst_map > 0] = inst_y[inst_map > 0]

    hv_map = np.dstack([x_map, y_map])
    return hv_map

def remove_small_objects(pred, min_size=64, connectivity=1):
    """Remove connected components smaller than the specified size.

    This function is taken from skimage.morphology.remove_small_objects, but the warning
    is removed when a single label is provided.

    Args:
        pred: input labelled array
        min_size: minimum size of instance in output array
        connectivity: The connectivity defining the neighborhood of a pixel.

    Returns:
        out: output array with instances removed under min_size

    """
    out = pred

    if min_size == 0:  # shortcut for efficiency
        return out

    if out.dtype == bool:
        selem = ndimage.generate_binary_structure(pred.ndim, connectivity)
        ccs = np.zeros_like(pred, dtype=np.int32)
        ndimage.label(pred, selem, output=ccs)
    else:
        ccs = out

    try:
        component_sizes = np.bincount(ccs.ravel())
    except ValueError:
        raise ValueError(
            "Negative value labels are not supported. Try "
            "relabeling the input with `scipy.ndimage.label` or "
            "`skimage.morphology.label`."
        )

    too_small = component_sizes < min_size
    too_small_mask = too_small[ccs]
    out[too_small_mask] = 0

    return out

####
def gen_targets(ann, crop_shape, **kwargs):
    """Generate the targets for the network."""
    hv_map = gen_instance_hv_map(ann, crop_shape)
    np_map = ann.copy()
    np_map[np_map > 0] = 1

    hv_map = cropping_center(hv_map, crop_shape)
    np_map = cropping_center(np_map, crop_shape)

    target_dict = {
        "hv_map": hv_map,
        "np_map": np_map,
    }

    return target_dict
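
To illustrate what `gen_targets` returns, a sketch on a hypothetical instance map (it assumes the repo's dependencies are installed so that `utils_modify` imports cleanly):

```python
import numpy as np
from utils_modify import gen_targets

ann = np.zeros((256, 256), np.int32)     # hypothetical instance ID map
ann[60:100, 60:100] = 1
ann[140:200, 120:190] = 2

targets = gen_targets(ann, crop_shape=(128, 128))
print(targets["np_map"].shape)           # (128, 128)    binary foreground map
print(targets["hv_map"].shape)           # (128, 128, 2) horizontal/vertical maps in [-1, 1]
```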
def __proc_np_hv(pred, np_thres=0.5, ksize=21, overall_thres=0.4, obj_size_thres=10):
    """Process Nuclei Prediction with XY Coordinate Map.

    Args:
        pred: prediction output, assuming
            channel 0 contains the probability map of nuclei
            channel 1 contains the regressed X-map
            channel 2 contains the regressed Y-map

    """
    pred = np.array(pred, dtype=np.float32)

    blb_raw = pred[..., 0]
    h_dir_raw = pred[..., 1]
    v_dir_raw = pred[..., 2]

    # processing
    blb = np.array(blb_raw >= np_thres, dtype=np.int32)

    blb = measurements.label(blb)[0]
    blb = remove_small_objects(blb, min_size=10)
    blb[blb > 0] = 1  # background is 0 already

    h_dir = cv2.normalize(
        h_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
    )
    v_dir = cv2.normalize(
        v_dir_raw, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
    )

    sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
    sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)

    sobelh = 1 - (
        cv2.normalize(
            sobelh, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
        )
    )
    sobelv = 1 - (
        cv2.normalize(
            sobelv, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F
        )
    )

    overall = np.maximum(sobelh, sobelv)
    overall = overall - (1 - blb)
    overall[overall < 0] = 0

    dist = (1.0 - overall) * blb
    ## nuclei values form mountains so inverse to get basins
    dist = -cv2.GaussianBlur(dist, (3, 3), 0)

    overall = np.array(overall >= overall_thres, dtype=np.int32)

    marker = blb - overall
    marker[marker < 0] = 0
    marker = binary_fill_holes(marker).astype("uint8")
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
    marker = measurements.label(marker)[0]
    marker = remove_small_objects(marker, min_size=obj_size_thres)

    proced_pred = watershed(dist, markers=marker, mask=blb)

    return proced_pred

####
def colorize(ch, vmin, vmax):
    """Will clamp values outside the provided range to vmax and vmin."""
    import matplotlib.pyplot as plt  # local import; plt is referenced here but was never imported in this file
    cmap = plt.get_cmap("jet")
    ch = np.squeeze(ch.astype("float32"))
    vmin = vmin if vmin is not None else ch.min()
    vmax = vmax if vmax is not None else ch.max()
    ch[ch > vmax] = vmax  # clamp value
    ch[ch < vmin] = vmin
    ch = (ch - vmin) / (vmax - vmin + 1.0e-16)
    # take RGB from RGBA heat map
    ch_cmap = (cmap(ch)[..., :3] * 255).astype("uint8")
    return ch_cmap


####
def random_colors(N, bright=True):
    """Generate random colors.

    To get visually distinct colors, generate them in HSV space then
    convert to RGB.
    """
    brightness = 1.0 if bright else 0.7
    hsv = [(i / N, 1, brightness) for i in range(N)]
    colors = list(map(lambda c: colorsys.hsv_to_rgb(*c), hsv))
    random.shuffle(colors)
    return colors


####
def visualize_instances_map(
    input_image, inst_map, type_map=None, type_colour=None, line_thickness=2
):
    """Overlays segmentation results on image as contours.

    Args:
        input_image: input image
        inst_map: instance mask with unique value for every object
        type_map: type mask with unique value for every class
        type_colour: a dict of {type : colour}, `type` is from 0-N
            and `colour` is a tuple of (R, G, B)
        line_thickness: line thickness of contours

    Returns:
        overlay: output image with segmentation overlay as contours
    """
    overlay = np.copy((input_image).astype(np.uint8))

    inst_list = list(np.unique(inst_map))  # get list of instances
    inst_list.remove(0)  # remove background

    inst_rng_colors = random_colors(len(inst_list))
    inst_rng_colors = np.array(inst_rng_colors) * 255
    inst_rng_colors = inst_rng_colors.astype(np.uint8)

    for inst_idx, inst_id in enumerate(inst_list):
        inst_map_mask = np.array(inst_map == inst_id, np.uint8)  # get single object
        y1, y2, x1, x2 = get_bounding_box(inst_map_mask)
        y1 = y1 - 2 if y1 - 2 >= 0 else y1
        x1 = x1 - 2 if x1 - 2 >= 0 else x1
        x2 = x2 + 2 if x2 + 2 <= inst_map.shape[1] - 1 else x2
        y2 = y2 + 2 if y2 + 2 <= inst_map.shape[0] - 1 else y2
        inst_map_crop = inst_map_mask[y1:y2, x1:x2]
        contours_crop = cv2.findContours(
            inst_map_crop, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        # only has 1 instance per map, no need to check #contour detected by opencv
        contours_crop = np.squeeze(
            contours_crop[0][0].astype("int32")
        )  # * opencv protocol format may break
        contours_crop += np.asarray([[x1, y1]])  # index correction
        if type_map is not None:
            type_map_crop = type_map[y1:y2, x1:x2]
            type_id = np.unique(type_map_crop).max()  # non-zero
            inst_colour = type_colour[type_id]
        else:
            inst_colour = (inst_rng_colors[inst_idx]).tolist()
        cv2.drawContours(overlay, [contours_crop], -1, inst_colour, line_thickness)
    return overlay
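
A sketch of the HoVer-Net-style post-processing above on a synthetic prediction (illustrative only; `__proc_np_hv` is module-private, so it is imported with an alias, and the two disks below stand in for real network output):

```python
import numpy as np
from utils_modify import __proc_np_hv as proc_np_hv  # module-private helper

# toy 3-channel prediction: two round "nuclei" with linear H/V maps
yy, xx = np.mgrid[0:128, 0:128]
prob = np.zeros((128, 128), np.float32)
hv_x = np.zeros_like(prob)
hv_y = np.zeros_like(prob)
for cy, cx in [(40, 40), (80, 90)]:
    m = (yy - cy) ** 2 + (xx - cx) ** 2 < 15 ** 2
    prob[m] = 1.0
    hv_x[m] = np.clip((xx[m] - cx) / 15.0, -1, 1)
    hv_y[m] = np.clip((yy[m] - cy) / 15.0, -1, 1)

pred = np.dstack([prob, hv_x, hv_y])   # channel 0: prob, 1: X-map, 2: Y-map
inst_map = proc_np_hv(pred)            # watershed-labeled instances (2 expected here)
print(inst_map.max())
```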
def sliding_window_inference_large(inputs, block_size, min_overlap, context, roi_size, sw_batch_size, predictor, device):

    h, w = inputs.shape[0], inputs.shape[1]
    if h < 5000 or w < 5000:
        test_tensor = torch.from_numpy(np.expand_dims(inputs, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
        output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
        prob = output_prob[0][0].cpu().numpy()
        dist = output_dist[0].cpu().numpy()
        dist = np.transpose(dist, (1, 2, 0))
        dist = np.maximum(1e-3, dist)
        if h * w < 1500 * 1500:
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.55, nms_thresh=0.4, cut=True)
        else:
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)

        labels_out = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
    else:
        n = inputs.ndim
        axes = 'YXC'
        grid = (1, 1, 1)
        if np.isscalar(block_size): block_size = n * [block_size]
        if np.isscalar(min_overlap): min_overlap = n * [min_overlap]
        if np.isscalar(context): context = n * [context]
        shape_out = (inputs.shape[0], inputs.shape[1])
        labels_out = np.zeros(shape_out, dtype=np.uint64)
        # print(inputs.dtype)
        block_size[2] = inputs.shape[2]
        min_overlap[2] = context[2] = 0
        block_size  = tuple(_grid_divisible(g, v, name='block_size',  verbose=False) for v, g, a in zip(block_size,  grid, axes))
        min_overlap = tuple(_grid_divisible(g, v, name='min_overlap', verbose=False) for v, g, a in zip(min_overlap, grid, axes))
        context     = tuple(_grid_divisible(g, v, name='context',     verbose=False) for v, g, a in zip(context,     grid, axes))
        print(f'effective: block_size={block_size}, min_overlap={min_overlap}, context={context}', flush=True)
        blocks = BlockND.cover(inputs.shape, axes, block_size, min_overlap, context)
        label_offset = 1
        blocks = tqdm(blocks)
        for block in blocks:
            image = block.read(inputs, axes=axes)
            test_tensor = torch.from_numpy(np.expand_dims(image, 0)).permute(0, 3, 1, 2).type(torch.FloatTensor).to(device)
            output_dist, output_prob = sliding_window_inference(test_tensor, roi_size, sw_batch_size, predictor)
            prob = output_prob[0][0].cpu().numpy()
            dist = output_dist[0].cpu().numpy()
            dist = np.transpose(dist, (1, 2, 0))
            dist = np.maximum(1e-3, dist)
            points, probi, disti = non_maximum_suppression(dist, prob, prob_thresh=0.5, nms_thresh=0.4)

            coord = dist_to_coord(disti, points)
            polys = dict(coord=coord, points=points, prob=probi)
            labels = polygons_to_label(disti, points, prob=probi, shape=prob.shape)
            labels = block.crop_context(labels, axes='YX')
            labels, polys = block.filter_objects(labels, polys, axes='YX')
            labels = relabel_sequential(labels, label_offset)[0]
            if labels_out is not None:
                block.write(labels_out, labels, axes='YX')
            # for k, v in polys.items():
            #     polys_all.setdefault(k, []).append(v)
            label_offset += len(polys['prob'])
            del labels
        # polys_all = {k: (np.concatenate(v) if k in OBJECT_KEYS else v[0]) for k, v in polys_all.items()}
    return labels_out
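
How an app would typically call the wrapper above (a call-pattern sketch only: `model` is a placeholder for the loaded network that maps an NCHW tensor to `(dist, prob)`, and the block/ROI numbers are illustrative):

```python
import numpy as np
import torch
from utils_modify import sliding_window_inference_large

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img = np.random.rand(1024, 1024, 3).astype(np.float32)  # H x W x 3 input

labels = sliding_window_inference_large(
    img,
    block_size=2048, min_overlap=128, context=128,  # only used for very large images
    roi_size=(512, 512), sw_batch_size=4,
    predictor=model,                                # placeholder: the loaded model
    device=device,
)
print(labels.shape, labels.max())  # (1024, 1024), number of instances
```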
def sliding_window_inference(
|
474 |
+
inputs: torch.Tensor,
|
475 |
+
roi_size: Union[Sequence[int], int],
|
476 |
+
sw_batch_size: int,
|
477 |
+
predictor: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]]],
|
478 |
+
overlap: float = 0.25,
|
479 |
+
mode: Union[BlendMode, str] = BlendMode.CONSTANT,
|
480 |
+
sigma_scale: Union[Sequence[float], float] = 0.125,
|
481 |
+
padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT,
|
482 |
+
cval: float = 0.0,
|
483 |
+
sw_device: Union[torch.device, str, None] = None,
|
484 |
+
device: Union[torch.device, str, None] = None,
|
485 |
+
progress: bool = False,
|
486 |
+
roi_weight_map: Union[torch.Tensor, None] = None,
|
487 |
+
*args: Any,
|
488 |
+
**kwargs: Any,
|
489 |
+
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[Any, torch.Tensor]]:
|
490 |
+
"""
|
491 |
+
Sliding window inference on `inputs` with `predictor`.
|
492 |
+
|
493 |
+
The outputs of `predictor` could be a tensor, a tuple, or a dictionary of tensors.
|
494 |
+
Each output in the tuple or dict value is allowed to have different resolutions with respect to the input.
|
495 |
+
e.g., the input patch spatial size is [128,128,128], the output (a tuple of two patches) patch sizes
|
496 |
+
could be ([128,64,256], [64,32,128]).
|
497 |
+
In this case, the parameter `overlap` and `roi_size` need to be carefully chosen to ensure the output ROI is still
|
498 |
+
an integer. If the predictor's input and output spatial sizes are not equal, we recommend choosing the parameters
|
499 |
+
so that `overlap*roi_size*output_size/input_size` is an integer (for each spatial dimension).
|
500 |
+

    When roi_size is larger than the inputs' spatial size, the input image is padded during inference.
    To maintain the same spatial sizes, the output image will be cropped to the original input size.

    Args:
        inputs: input image to be processed (assuming NCHW[D])
        roi_size: the spatial window size for inferences.
            When its components are None or non-positive, the corresponding dimension of the
            input image is used instead. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of the image is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: given input tensor ``patch_data`` in shape NCHW[D],
            the outputs of the function call ``predictor(patch_data)`` should be a tensor, a tuple, or a dictionary
            with Tensor values. Each output in the tuple or dict value should have the same batch_size, i.e. NM'H'W'[D'];
            where H'W'[D'] represents the output patch's spatial size, M is the number of output channels,
            N is `sw_batch_size`, e.g., the input shape is (7, 1, 128,128,128),
            the output could be a tuple of two tensors, with shapes: ((7, 5, 128, 64, 256), (7, 4, 64, 32, 128)).
            In this case, the parameters `overlap` and `roi_size` need to be carefully chosen
            to ensure the scaled output ROI sizes are still integers.
            If the `predictor`'s input and output spatial sizes are different,
            we recommend choosing the parameters so that ``overlap*roi_size*zoom_scale`` is an integer for each dimension.
        overlap: amount of overlap between scans.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant"``: gives equal weight to all predictions.
            - ``"gaussian"``: gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for ``"constant"`` padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If, for example,
            set to device=torch.device('cpu'), the GPU memory consumption is lower and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a `tqdm` progress bar.
        roi_weight_map: pre-computed (non-negative) weight map for each ROI.
            If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
        args: optional args to be passed to ``predictor``.
        kwargs: optional keyword args to be passed to ``predictor``.

    Note:
        - input must be channel-first and have a batch dim; supports N-D sliding window.

    """
    compute_dtype = inputs.dtype
    num_spatial_dims = len(inputs.shape) - 2
    if overlap < 0 or overlap >= 1:
        raise ValueError("overlap must be >= 0 and < 1.")

    # determine image spatial size and batch size
    # Note: all input images must have the same image size and batch size
    batch_size, _, *image_size_ = inputs.shape

    if device is None:
        device = inputs.device
    if sw_device is None:
        sw_device = inputs.device

    roi_size = fall_back_tuple(roi_size, image_size_)
    # in case that image size is smaller than roi size
    image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims))
    pad_size = []
    for k in range(len(inputs.shape) - 1, 1, -1):
        diff = max(roi_size[k - 2] - inputs.shape[k], 0)
        half = diff // 2
        pad_size.extend([half, diff - half])
    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
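    # NOTE: F.pad takes pad widths for the last dimension first, which is why the
    # loop above walks the spatial dims in reverse order (W, H[, D] for NCHW[D]).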
    scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
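    # e.g., image_size=(256, 256), roi_size=(128, 128), overlap=0.25
    # -> scan_interval=(96, 96), so consecutive windows share a 32-pixel band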

    # Store all slices in list
    slices = dense_patch_slices(image_size, roi_size, scan_interval)
    num_win = len(slices)  # number of windows per image
    total_slices = num_win * batch_size  # total number of windows
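    # e.g., a 300x300 input with roi_size=(128, 128) and scan_interval=(96, 96)
    # yields 3 x 3 = 9 windows per image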

    # Create window-level importance map
    valid_patch_size = get_valid_patch_size(image_size, roi_size)
    if valid_patch_size == roi_size and (roi_weight_map is not None):
        importance_map = roi_weight_map
    else:
        try:
            importance_map = compute_importance_map(valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device)
        except BaseException as e:
            raise RuntimeError(
                "Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'."
            ) from e
    importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0]  # type: ignore
    # handle non-positive weights
    min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
    importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(compute_dtype)
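    # the clamp keeps every weight strictly positive, so the final division by the
    # accumulated count map cannot introduce NaN/Inf inside covered regions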

    # Perform predictions
    dict_key, output_image_list, count_map_list = None, [], []
    _initialized_ss = -1
    is_tensor_output = True  # whether the predictor's output is a tensor (instead of dict/tuple)

    # for each patch
    for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
        slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
        unravel_slice = [
            [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
            for idx in slice_range
        ]
        window_data = torch.cat(
            [convert_data_type(inputs[win_slice], torch.Tensor)[0] for win_slice in unravel_slice]
        ).to(sw_device)
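        # window_data holds up to sw_batch_size patches, each of spatial size roi_size,
        # concatenated along the batch dimension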
        seg_prob_out = predictor(window_data, *args, **kwargs)  # batched patch segmentation

        # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
        seg_prob_tuple: Tuple[torch.Tensor, ...]
        if isinstance(seg_prob_out, torch.Tensor):
            seg_prob_tuple = (seg_prob_out,)
        elif isinstance(seg_prob_out, Mapping):
            if dict_key is None:
                dict_key = sorted(seg_prob_out.keys())  # track predictor's output keys
            seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
            is_tensor_output = False
        else:
            seg_prob_tuple = ensure_tuple(seg_prob_out)
            is_tensor_output = False

        # for each output in multi-output list
        for ss, seg_prob in enumerate(seg_prob_tuple):
            seg_prob = seg_prob.to(device)  # BxCxMxNxP or BxCxMxN

            # compute zoom scale: out_roi_size/in_roi_size
            zoom_scale = []
            for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
                zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
            ):
                _scale = out_w_i / float(in_w_i)
                if not (img_s_i * _scale).is_integer():
                    warnings.warn(
                        f"For spatial axis: {axis}, output[{ss}] will have non-integer shape. Spatial "
                        f"zoom_scale between output[{ss}] and input is {_scale}. Please pad inputs."
                    )
                zoom_scale.append(_scale)
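            # e.g., a predictor that upsamples 128 -> 256 along an axis gives
            # _scale = 2.0 for that axis; equal input/output sizes give 1.0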

            if _initialized_ss < ss:  # init. the ss-th buffer at the first iteration
                # construct multi-resolution outputs
                output_classes = seg_prob.shape[1]
                output_shape = [batch_size, output_classes] + [
                    int(image_size_d * zoom_scale_d) for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
                ]
                # allocate memory to store the full output and the count for overlapping parts
                output_image_list.append(torch.zeros(output_shape, dtype=compute_dtype, device=device))
                count_map_list.append(torch.zeros([1, 1] + output_shape[2:], dtype=compute_dtype, device=device))
                _initialized_ss += 1

            # resizing the importance_map
            resizer = Resize(spatial_size=seg_prob.shape[2:], mode="nearest", anti_aliasing=False)

            # store the result in the proper location of the full output. Apply weights from importance map.
            for idx, original_idx in zip(slice_range, unravel_slice):
                # zoom roi
                original_idx_zoom = list(original_idx)  # 4D for 2D image, 5D for 3D image
                for axis in range(2, len(original_idx_zoom)):
                    zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
                    zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
                    if not zoomed_start.is_integer() or (not zoomed_end.is_integer()):
                        warnings.warn(
                            f"For axis-{axis-2} of output[{ss}], the output roi range is not int. "
                            f"Input roi range is ({original_idx[axis].start}, {original_idx[axis].stop}). "
                            f"Spatial zoom_scale between output[{ss}] and input is {zoom_scale[axis - 2]}. "
                            f"Corresponding output roi range is ({zoomed_start}, {zoomed_end}).\n"
                            f"Please change overlap ({overlap}) or roi_size ({roi_size[axis-2]}) for axis-{axis-2}. "
                            "Tips: if overlap*roi_size*zoom_scale is an integer, it usually works."
                        )
                    original_idx_zoom[axis] = slice(int(zoomed_start), int(zoomed_end), None)
                importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(compute_dtype)
                # store results and weights
                output_image_list[ss][original_idx_zoom] += importance_map_zoom * seg_prob[idx - slice_g]
                count_map_list[ss][original_idx_zoom] += (
                    importance_map_zoom.unsqueeze(0).unsqueeze(0).expand(count_map_list[ss][original_idx_zoom].shape)
                )

    # account for any overlapping sections
    for ss in range(len(output_image_list)):
        output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(compute_dtype)
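    # every voxel now holds sum(w_i * pred_i) / sum(w_i): the importance-weighted
    # average over all windows that covered it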

    # remove padding if image_size smaller than roi_size
    for ss, output_i in enumerate(output_image_list):
        if torch.isnan(output_i).any() or torch.isinf(output_i).any():
            warnings.warn("Sliding window inference results contain NaN or Inf.")

        zoom_scale = [
            seg_prob_map_shape_d / roi_size_d for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
        ]

        final_slicing: List[slice] = []
        for sp in range(num_spatial_dims):
            slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2])
            slice_dim = slice(
                int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
                int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
            )
            final_slicing.insert(0, slice_dim)
        while len(final_slicing) < len(output_i.shape):
            final_slicing.insert(0, slice(None))
        output_image_list[ss] = output_i[final_slicing]

    if dict_key is not None:  # if output of predictor is a dict
        final_output = dict(zip(dict_key, output_image_list))
    else:
        final_output = tuple(output_image_list)  # type: ignore
    final_output = final_output[0] if is_tensor_output else final_output

    if isinstance(inputs, MetaTensor):
        final_output = convert_to_dst_type(final_output, inputs, device=device)[0]  # type: ignore
    return final_output


def _get_scan_interval(
    image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float
) -> Tuple[int, ...]:
    """
    Compute the scan interval from the image size, roi size, and overlap.
    The scan interval is `int((1 - overlap) * roi_size)`; if that interval is 0,
    1 is used instead so that the sliding window still advances.
    """
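    # e.g., image_size=(256, 256), roi_size=(256, 128), overlap=0.25 -> (256, 96):
    # the first axis fits in a single window, so its stride equals the roi size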
    if len(image_size) != num_spatial_dims:
        raise ValueError("image coord different from spatial dims.")
    if len(roi_size) != num_spatial_dims:
        raise ValueError("roi coord different from spatial dims.")

    scan_interval = []
    for i in range(num_spatial_dims):
        if roi_size[i] == image_size[i]:
            scan_interval.append(int(roi_size[i]))
        else:
            interval = int(roi_size[i] * (1 - overlap))
            scan_interval.append(interval if interval > 0 else 1)
    return tuple(scan_interval)
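

# Minimal usage sketch (illustrative only, not part of the original file): it assumes
# the imports at the top of this module are available and uses a toy `_toy_predictor`
# standing in for a real segmentation network.
if __name__ == "__main__":
    def _toy_predictor(patch: torch.Tensor) -> torch.Tensor:
        # pretend the network emits 3-channel logits at the input resolution
        return torch.ones(patch.shape[0], 3, *patch.shape[2:])

    _img = torch.rand(1, 1, 300, 300)  # NCHW; deliberately not a multiple of roi_size
    _out = sliding_window_inference(
        _img,
        roi_size=(128, 128),
        sw_batch_size=4,
        predictor=_toy_predictor,
        overlap=0.25,
        mode="gaussian",  # down-weight window borders when blending
    )
    print(_out.shape)  # expected: torch.Size([1, 3, 300, 300])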