ghlee94 committed
Commit 2a13495
1 Parent(s): af96727
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +7 -0
  2. main_model.pt +3 -0
  3. predict.py +1256 -0
  4. predict.sh +1 -0
  5. requirements.txt +83 -0
  6. save_model.py +99 -0
  7. segmentation_models_pytorch/__init__.py +61 -0
  8. segmentation_models_pytorch/__pycache__/__init__.cpython-37.pyc +0 -0
  9. segmentation_models_pytorch/__pycache__/__init__.cpython-39.pyc +0 -0
  10. segmentation_models_pytorch/__pycache__/__version__.cpython-37.pyc +0 -0
  11. segmentation_models_pytorch/__pycache__/__version__.cpython-39.pyc +0 -0
  12. segmentation_models_pytorch/__version__.py +3 -0
  13. segmentation_models_pytorch/base/__init__.py +11 -0
  14. segmentation_models_pytorch/base/__pycache__/__init__.cpython-37.pyc +0 -0
  15. segmentation_models_pytorch/base/__pycache__/__init__.cpython-39.pyc +0 -0
  16. segmentation_models_pytorch/base/__pycache__/heads.cpython-37.pyc +0 -0
  17. segmentation_models_pytorch/base/__pycache__/heads.cpython-39.pyc +0 -0
  18. segmentation_models_pytorch/base/__pycache__/initialization.cpython-37.pyc +0 -0
  19. segmentation_models_pytorch/base/__pycache__/initialization.cpython-39.pyc +0 -0
  20. segmentation_models_pytorch/base/__pycache__/model.cpython-37.pyc +0 -0
  21. segmentation_models_pytorch/base/__pycache__/model.cpython-39.pyc +0 -0
  22. segmentation_models_pytorch/base/__pycache__/modules.cpython-37.pyc +0 -0
  23. segmentation_models_pytorch/base/__pycache__/modules.cpython-39.pyc +0 -0
  24. segmentation_models_pytorch/base/heads.py +34 -0
  25. segmentation_models_pytorch/base/initialization.py +27 -0
  26. segmentation_models_pytorch/base/model.py +64 -0
  27. segmentation_models_pytorch/base/modules.py +131 -0
  28. segmentation_models_pytorch/datasets/__init__.py +1 -0
  29. segmentation_models_pytorch/datasets/__pycache__/__init__.cpython-37.pyc +0 -0
  30. segmentation_models_pytorch/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  31. segmentation_models_pytorch/datasets/__pycache__/oxford_pet.cpython-37.pyc +0 -0
  32. segmentation_models_pytorch/datasets/__pycache__/oxford_pet.cpython-39.pyc +0 -0
  33. segmentation_models_pytorch/datasets/oxford_pet.py +136 -0
  34. segmentation_models_pytorch/decoders/__init__.py +0 -0
  35. segmentation_models_pytorch/decoders/__pycache__/__init__.cpython-37.pyc +0 -0
  36. segmentation_models_pytorch/decoders/__pycache__/__init__.cpython-39.pyc +0 -0
  37. segmentation_models_pytorch/decoders/deeplabv3/__init__.py +1 -0
  38. segmentation_models_pytorch/decoders/deeplabv3/__pycache__/__init__.cpython-37.pyc +0 -0
  39. segmentation_models_pytorch/decoders/deeplabv3/__pycache__/__init__.cpython-39.pyc +0 -0
  40. segmentation_models_pytorch/decoders/deeplabv3/__pycache__/decoder.cpython-37.pyc +0 -0
  41. segmentation_models_pytorch/decoders/deeplabv3/__pycache__/decoder.cpython-39.pyc +0 -0
  42. segmentation_models_pytorch/decoders/deeplabv3/__pycache__/model.cpython-37.pyc +0 -0
  43. segmentation_models_pytorch/decoders/deeplabv3/__pycache__/model.cpython-39.pyc +0 -0
  44. segmentation_models_pytorch/decoders/deeplabv3/decoder.py +220 -0
  45. segmentation_models_pytorch/decoders/deeplabv3/model.py +179 -0
  46. segmentation_models_pytorch/decoders/fpn/__init__.py +1 -0
  47. segmentation_models_pytorch/decoders/fpn/__pycache__/__init__.cpython-37.pyc +0 -0
  48. segmentation_models_pytorch/decoders/fpn/__pycache__/__init__.cpython-39.pyc +0 -0
  49. segmentation_models_pytorch/decoders/fpn/__pycache__/decoder.cpython-37.pyc +0 -0
  50. segmentation_models_pytorch/decoders/fpn/__pycache__/decoder.cpython-39.pyc +0 -0
app.py ADDED
@@ -0,0 +1,7 @@
+ import gradio as gr
+
+ def greet(name):
+     return "Hello " + name + "!!"
+
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+ iface.launch()
main_model.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6817c7bdd29a33ed9379f72d082390bb4052fb307744671834ff6c011cefd051
+ size 485832489
predict.py ADDED
@@ -0,0 +1,1256 @@
1
+ import torch
2
+ from torch.nn import (
3
+ Module,
4
+ Conv2d,
5
+ BatchNorm2d,
6
+ Identity,
7
+ UpsamplingBilinear2d,
8
+ Mish,
9
+ ReLU,
10
+ Sequential,
11
+ )
12
+ from torch.nn.functional import interpolate, grid_sample, pad
13
+ import numpy as np
14
+ from copy import deepcopy
15
+ import os, argparse, math
16
+ import tifffile as tif
17
+ from typing import Tuple, List, Mapping
18
+
19
+ from monai.utils import (
20
+ BlendMode,
21
+ PytorchPadMode,
22
+ convert_data_type,
23
+ ensure_tuple,
24
+ fall_back_tuple,
25
+ look_up_option,
26
+ convert_to_dst_type,
27
+ )
28
+ from monai.utils.misc import ensure_tuple_size, ensure_tuple_rep, issequenceiterable
29
+ from monai.networks.layers.convutils import gaussian_1d
30
+ from monai.networks.layers.simplelayers import separable_filtering
31
+
32
+ from segmentation_models_pytorch import MAnet
33
+
34
+ from skimage.io import imread as io_imread
35
+ from skimage.util.dtype import dtype_range
36
+ from skimage._shared.utils import _supported_float_type
37
+ from scipy.ndimage import find_objects, binary_fill_holes
38
+
39
+
40
+ ########################### Data Loading Modules #########################################################
41
+ DTYPE_RANGE = dtype_range.copy()
42
+ DTYPE_RANGE.update((d.__name__, limits) for d, limits in dtype_range.items())
43
+ DTYPE_RANGE.update(
44
+ {
45
+ "uint10": (0, 2 ** 10 - 1),
46
+ "uint12": (0, 2 ** 12 - 1),
47
+ "uint14": (0, 2 ** 14 - 1),
48
+ "bool": dtype_range[bool],
49
+ "float": dtype_range[np.float64],
50
+ }
51
+ )
52
+
53
+
54
+ def _output_dtype(dtype_or_range, image_dtype):
55
+ if type(dtype_or_range) in [list, tuple, np.ndarray]:
56
+ # pair of values: always return float.
57
+ return _supported_float_type(image_dtype)
58
+ if type(dtype_or_range) == type:
59
+ # already a type: return it
60
+ return dtype_or_range
61
+ if dtype_or_range in DTYPE_RANGE:
62
+ # string key in DTYPE_RANGE dictionary
63
+ try:
64
+ # if it's a canonical numpy dtype, convert
65
+ return np.dtype(dtype_or_range).type
66
+ except TypeError: # uint10, uint12, uint14
67
+ # otherwise, return uint16
68
+ return np.uint16
69
+ else:
70
+ raise ValueError(
71
+ "Incorrect value for out_range, should be a valid image data "
72
+ f"type or a pair of values, got {dtype_or_range}."
73
+ )
74
+
75
+
76
+ def intensity_range(image, range_values="image", clip_negative=False):
77
+ if range_values == "dtype":
78
+ range_values = image.dtype.type
79
+
80
+ if range_values == "image":
81
+ i_min = np.min(image)
82
+ i_max = np.max(image)
83
+ elif range_values in DTYPE_RANGE:
84
+ i_min, i_max = DTYPE_RANGE[range_values]
85
+ if clip_negative:
86
+ i_min = 0
87
+ else:
88
+ i_min, i_max = range_values
89
+ return i_min, i_max
90
+
91
+
92
+ def rescale_intensity(image, in_range="image", out_range="dtype"):
93
+ out_dtype = _output_dtype(out_range, image.dtype)
94
+
95
+ imin, imax = map(float, intensity_range(image, in_range))
96
+ omin, omax = map(
97
+ float, intensity_range(image, out_range, clip_negative=(imin >= 0))
98
+ )
99
+ image = np.clip(image, imin, imax)
100
+
101
+ if imin != imax:
102
+ image = (image - imin) / (imax - imin)
103
+ return np.asarray(image * (omax - omin) + omin, dtype=out_dtype)
104
+ else:
105
+ return np.clip(image, omin, omax).astype(out_dtype)
106
+
107
+
108
+ def _normalize(img):
109
+ non_zero_vals = img[np.nonzero(img)]
110
+ percentiles = np.percentile(non_zero_vals, [0, 99.5])
111
+ img_norm = rescale_intensity(
112
+ img, in_range=(percentiles[0], percentiles[1]), out_range="uint8"
113
+ )
114
+
115
+ return img_norm.astype(np.uint8)
116
+
117
+
118
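+ # Load an image (tifffile for .tif/.tiff, skimage otherwise), force 3 channels,
+ # percentile-normalize to uint8, min-max scale to [0, 1], and return a 1xCxHxW tensor.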
+ def pred_transforms(filename):
119
+ # LoadImage
120
+ img = (
121
+ tif.imread(filename)
122
+ if filename.endswith(".tif") or filename.endswith(".tiff")
123
+ else io_imread(filename)
124
+ )
125
+
126
+ if len(img.shape) == 2:
127
+ img = np.repeat(np.expand_dims(img, axis=-1), 3, axis=-1)
128
+ elif len(img.shape) == 3 and img.shape[-1] > 3:
129
+ img = img[:, :, :3]
130
+
131
+ img = img.astype(np.float32)
132
+ img = _normalize(img)
133
+ img = np.moveaxis(img, -1, 0)
134
+ img = (img - img.min()) / (img.max() - img.min())
135
+
136
+ return torch.FloatTensor(img).unsqueeze(0)
137
+
138
+
139
+ ################################################################################
140
+
141
+ ########################### MODEL Architecture #################################
142
+ class SegformerGH(MAnet):
143
+ def __init__(
144
+ self,
145
+ encoder_name: str = "mit_b5",
146
+ encoder_weights="imagenet",
147
+ decoder_channels=(256, 128, 64, 32, 32),
148
+ decoder_pab_channels=256,
149
+ in_channels: int = 3,
150
+ classes: int = 3,
151
+ ):
152
+ super(SegformerGH, self).__init__(
153
+ encoder_name=encoder_name,
154
+ encoder_weights=encoder_weights,
155
+ decoder_channels=decoder_channels,
156
+ decoder_pab_channels=decoder_pab_channels,
157
+ in_channels=in_channels,
158
+ classes=classes,
159
+ )
160
+
161
+ convert_relu_to_mish(self.encoder)
162
+ convert_relu_to_mish(self.decoder)
163
+
164
+ self.cellprob_head = DeepSegmantationHead(
165
+ in_channels=decoder_channels[-1], out_channels=1, kernel_size=3,
166
+ )
167
+ self.gradflow_head = DeepSegmantationHead(
168
+ in_channels=decoder_channels[-1], out_channels=2, kernel_size=3,
169
+ )
170
+
171
+ def forward(self, x):
172
+ """Sequentially pass `x` trough model`s encoder, decoder and heads"""
173
+ self.check_input_shape(x)
174
+
175
+ features = self.encoder(x)
176
+ decoder_output = self.decoder(*features)
177
+
178
+ gradflow_mask = self.gradflow_head(decoder_output)
179
+ cellprob_mask = self.cellprob_head(decoder_output)
180
+
181
+ masks = torch.cat([gradflow_mask, cellprob_mask], dim=1)
182
+
183
+ return masks
184
+
185
+
186
+ class DeepSegmantationHead(Sequential):
187
+ def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
188
+ conv2d_1 = Conv2d(
189
+ in_channels,
190
+ in_channels // 2,
191
+ kernel_size=kernel_size,
192
+ padding=kernel_size // 2,
193
+ )
194
+ bn = BatchNorm2d(in_channels // 2)
195
+ conv2d_2 = Conv2d(
196
+ in_channels // 2,
197
+ out_channels,
198
+ kernel_size=kernel_size,
199
+ padding=kernel_size // 2,
200
+ )
201
+ mish = Mish(inplace=True)
202
+
203
+ upsampling = (
204
+ UpsamplingBilinear2d(scale_factor=upsampling)
205
+ if upsampling > 1
206
+ else Identity()
207
+ )
208
+ activation = Identity()
209
+ super().__init__(conv2d_1, mish, bn, conv2d_2, upsampling, activation)
210
+
211
+
212
+ def convert_relu_to_mish(model):
213
+ for child_name, child in model.named_children():
214
+ if isinstance(child, ReLU):
215
+ setattr(model, child_name, Mish(inplace=True))
216
+ else:
217
+ convert_relu_to_mish(child)
218
+
219
+
220
+ #####################################################################################
221
+
222
+ ########################### Sliding Window Inference #################################
223
+ class GaussianFilter(Module):
224
+ def __init__(
225
+ self, spatial_dims, sigma, truncated=4.0, approx="erf", requires_grad=False,
226
+ ) -> None:
227
+ if issequenceiterable(sigma):
228
+ if len(sigma) != spatial_dims: # type: ignore
229
+ raise ValueError
230
+ else:
231
+ sigma = [deepcopy(sigma) for _ in range(spatial_dims)] # type: ignore
232
+ super().__init__()
233
+ self.sigma = [
234
+ torch.nn.Parameter(
235
+ torch.as_tensor(
236
+ s,
237
+ dtype=torch.float,
238
+ device=s.device if isinstance(s, torch.Tensor) else None,
239
+ ),
240
+ requires_grad=requires_grad,
241
+ )
242
+ for s in sigma # type: ignore
243
+ ]
244
+ self.truncated = truncated
245
+ self.approx = approx
246
+ for idx, param in enumerate(self.sigma):
247
+ self.register_parameter(f"kernel_sigma_{idx}", param)
248
+
249
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
250
+ _kernel = [
251
+ gaussian_1d(s, truncated=self.truncated, approx=self.approx)
252
+ for s in self.sigma
253
+ ]
254
+ return separable_filtering(x=x, kernels=_kernel)
255
+
256
+
257
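+ # Build a patch-sized Gaussian importance map (a delta at the patch centre blurred
+ # and normalised to a max of 1), used to weight overlapping sliding-window predictions.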
+ def compute_importance_map(
258
+ patch_size, mode=BlendMode.CONSTANT, sigma_scale=0.125, device="cpu"
259
+ ):
260
+ mode = look_up_option(mode, BlendMode)
261
+ device = torch.device(device)
262
+
263
+ center_coords = [i // 2 for i in patch_size]
264
+ sigma_scale = ensure_tuple_rep(sigma_scale, len(patch_size))
265
+ sigmas = [i * sigma_s for i, sigma_s in zip(patch_size, sigma_scale)]
266
+
267
+ importance_map = torch.zeros(patch_size, device=device)
268
+ importance_map[tuple(center_coords)] = 1
269
+ pt_gaussian = GaussianFilter(len(patch_size), sigmas).to(
270
+ device=device, dtype=torch.float
271
+ )
272
+ importance_map = pt_gaussian(importance_map.unsqueeze(0).unsqueeze(0))
273
+ importance_map = importance_map.squeeze(0).squeeze(0)
274
+ importance_map = importance_map / torch.max(importance_map)
275
+ importance_map = importance_map.float()
276
+
277
+ return importance_map
278
+
279
+
280
+ def first(iterable, default=None):
281
+ for i in iterable:
282
+ return i
283
+
284
+ return default
285
+
286
+
287
+ def dense_patch_slices(image_size, patch_size, scan_interval):
288
+ num_spatial_dims = len(image_size)
289
+ patch_size = get_valid_patch_size(image_size, patch_size)
290
+ scan_interval = ensure_tuple_size(scan_interval, num_spatial_dims)
291
+
292
+ scan_num = []
293
+ for i in range(num_spatial_dims):
294
+ if scan_interval[i] == 0:
295
+ scan_num.append(1)
296
+ else:
297
+ num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
298
+ scan_dim = first(
299
+ d
300
+ for d in range(num)
301
+ if d * scan_interval[i] + patch_size[i] >= image_size[i]
302
+ )
303
+ scan_num.append(scan_dim + 1 if scan_dim is not None else 1)
304
+
305
+ starts = []
306
+ for dim in range(num_spatial_dims):
307
+ dim_starts = []
308
+ for idx in range(scan_num[dim]):
309
+ start_idx = idx * scan_interval[dim]
310
+ start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
311
+ dim_starts.append(start_idx)
312
+ starts.append(dim_starts)
313
+ out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
314
+ return [tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]
315
+
316
+
317
+ def get_valid_patch_size(image_size, patch_size):
318
+ ndim = len(image_size)
319
+ patch_size_ = ensure_tuple_size(patch_size, ndim)
320
+
321
+ # ensure patch size dimensions are not larger than image dimension, if a dimension is None or 0 use whole dimension
322
+ return tuple(min(ms, ps or ms) for ms, ps in zip(image_size, patch_size_))
323
+
324
+
325
+ class Resize:
326
+ def __init__(self, spatial_size):
327
+ self.size_mode = "all"
328
+ self.spatial_size = spatial_size
329
+
330
+ def __call__(self, img):
331
+ input_ndim = img.ndim - 1 # spatial ndim
332
+ output_ndim = len(ensure_tuple(self.spatial_size))
333
+
334
+ if output_ndim > input_ndim:
335
+ input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)
336
+ img = img.reshape(input_shape)
337
+
338
+ spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:])
339
+
340
+ if (
341
+ tuple(img.shape[1:]) == spatial_size_
342
+ ): # spatial shape is already the desired
343
+ return img
344
+
345
+ img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
346
+
347
+ resized = interpolate(
348
+ input=img_.unsqueeze(0), size=spatial_size_, mode="nearest",
349
+ )
350
+ out, *_ = convert_to_dst_type(resized.squeeze(0), img)
351
+ return out
352
+
353
+
354
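+ # Run `predictor` over overlapping roi_size windows, blend each window's output with
+ # the importance map, and stitch the weighted patches back into a full-size prediction.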
+ def sliding_window_inference(
355
+ inputs,
356
+ roi_size,
357
+ sw_batch_size,
358
+ predictor,
359
+ overlap,
360
+ mode=BlendMode.CONSTANT,
361
+ sigma_scale=0.125,
362
+ padding_mode=PytorchPadMode.CONSTANT,
363
+ cval=0.0,
364
+ sw_device=None,
365
+ device=None,
366
+ roi_weight_map=None,
367
+ ):
368
+ compute_dtype = inputs.dtype
369
+ num_spatial_dims = len(inputs.shape) - 2
370
+ batch_size, _, *image_size_ = inputs.shape
371
+
372
+ roi_size = fall_back_tuple(roi_size, image_size_)
373
+ # in case that image size is smaller than roi size
374
+ image_size = tuple(
375
+ max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)
376
+ )
377
+ pad_size = []
378
+
379
+ for k in range(len(inputs.shape) - 1, 1, -1):
380
+ diff = max(roi_size[k - 2] - inputs.shape[k], 0)
381
+ half = diff // 2
382
+ pad_size.extend([half, diff - half])
383
+
384
+ inputs = pad(
385
+ inputs,
386
+ pad=pad_size,
387
+ mode=look_up_option(padding_mode, PytorchPadMode).value,
388
+ value=cval,
389
+ )
390
+
391
+ scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
392
+
393
+ # Store all slices in list
394
+ slices = dense_patch_slices(image_size, roi_size, scan_interval)
395
+ num_win = len(slices) # number of windows per image
396
+ total_slices = num_win * batch_size # total number of windows
397
+
398
+ # Create window-level importance map
399
+ valid_patch_size = get_valid_patch_size(image_size, roi_size)
400
+ if valid_patch_size == roi_size and (roi_weight_map is not None):
401
+ importance_map = roi_weight_map
402
+ else:
403
+ importance_map = compute_importance_map(
404
+ valid_patch_size, mode=mode, sigma_scale=sigma_scale, device=device
405
+ )
406
+
407
+ importance_map = convert_data_type(importance_map, torch.Tensor, device, compute_dtype)[0] # type: ignore
408
+ # handle non-positive weights
409
+ min_non_zero = max(importance_map[importance_map != 0].min().item(), 1e-3)
410
+ importance_map = torch.clamp(importance_map.to(torch.float32), min=min_non_zero).to(
411
+ compute_dtype
412
+ )
413
+
414
+ # Perform predictions
415
+ dict_key, output_image_list, count_map_list = None, [], []
416
+ _initialized_ss = -1
417
+ is_tensor_output = (
418
+ True # whether the predictor's output is a tensor (instead of dict/tuple)
419
+ )
420
+
421
+ # for each patch
422
+ for slice_g in range(0, total_slices, sw_batch_size):
423
+ slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
424
+ unravel_slice = [
425
+ [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)]
426
+ + list(slices[idx % num_win])
427
+ for idx in slice_range
428
+ ]
429
+ window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(
430
+ sw_device
431
+ )
432
+ seg_prob_out = predictor(window_data) # batched patch segmentation
433
+
434
+ # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory.
435
+ seg_prob_tuple: Tuple[torch.Tensor, ...]
436
+ if isinstance(seg_prob_out, torch.Tensor):
437
+ seg_prob_tuple = (seg_prob_out,)
438
+ elif isinstance(seg_prob_out, Mapping):
439
+ if dict_key is None:
440
+ dict_key = sorted(seg_prob_out.keys()) # track predictor's output keys
441
+ seg_prob_tuple = tuple(seg_prob_out[k] for k in dict_key)
442
+ is_tensor_output = False
443
+ else:
444
+ seg_prob_tuple = ensure_tuple(seg_prob_out)
445
+ is_tensor_output = False
446
+
447
+ # for each output in multi-output list
448
+ for ss, seg_prob in enumerate(seg_prob_tuple):
449
+ seg_prob = seg_prob.to(device) # BxCxMxNxP or BxCxMxN
450
+
451
+ # compute zoom scale: out_roi_size/in_roi_size
452
+ zoom_scale = []
453
+ for axis, (img_s_i, out_w_i, in_w_i) in enumerate(
454
+ zip(image_size, seg_prob.shape[2:], window_data.shape[2:])
455
+ ):
456
+ _scale = out_w_i / float(in_w_i)
457
+
458
+ zoom_scale.append(_scale)
459
+
460
+ if _initialized_ss < ss: # init. the ss-th buffer at the first iteration
461
+ # construct multi-resolution outputs
462
+ output_classes = seg_prob.shape[1]
463
+ output_shape = [batch_size, output_classes] + [
464
+ int(image_size_d * zoom_scale_d)
465
+ for image_size_d, zoom_scale_d in zip(image_size, zoom_scale)
466
+ ]
467
+ # allocate memory to store the full output and the count for overlapping parts
468
+ output_image_list.append(
469
+ torch.zeros(output_shape, dtype=compute_dtype, device=device)
470
+ )
471
+ count_map_list.append(
472
+ torch.zeros(
473
+ [1, 1] + output_shape[2:], dtype=compute_dtype, device=device
474
+ )
475
+ )
476
+ _initialized_ss += 1
477
+
478
+ # resizing the importance_map
479
+ resizer = Resize(spatial_size=seg_prob.shape[2:])
480
+
481
+ # store the result in the proper location of the full output. Apply weights from importance map.
482
+ for idx, original_idx in zip(slice_range, unravel_slice):
483
+ # zoom roi
484
+ original_idx_zoom = list(
485
+ original_idx
486
+ ) # 4D for 2D image, 5D for 3D image
487
+ for axis in range(2, len(original_idx_zoom)):
488
+ zoomed_start = original_idx[axis].start * zoom_scale[axis - 2]
489
+ zoomed_end = original_idx[axis].stop * zoom_scale[axis - 2]
490
+
491
+ original_idx_zoom[axis] = slice(
492
+ int(zoomed_start), int(zoomed_end), None
493
+ )
494
+ importance_map_zoom = resizer(importance_map.unsqueeze(0))[0].to(
495
+ compute_dtype
496
+ )
497
+ # store results and weights
498
+ output_image_list[ss][original_idx_zoom] += (
499
+ importance_map_zoom * seg_prob[idx - slice_g]
500
+ )
501
+ count_map_list[ss][original_idx_zoom] += (
502
+ importance_map_zoom.unsqueeze(0)
503
+ .unsqueeze(0)
504
+ .expand(count_map_list[ss][original_idx_zoom].shape)
505
+ )
506
+
507
+ # account for any overlapping sections
508
+ for ss in range(len(output_image_list)):
509
+ output_image_list[ss] = (output_image_list[ss] / count_map_list.pop(0)).to(
510
+ compute_dtype
511
+ )
512
+
513
+ # remove padding if image_size smaller than roi_size
514
+ for ss, output_i in enumerate(output_image_list):
515
+ zoom_scale = [
516
+ seg_prob_map_shape_d / roi_size_d
517
+ for seg_prob_map_shape_d, roi_size_d in zip(output_i.shape[2:], roi_size)
518
+ ]
519
+
520
+ final_slicing: List[slice] = []
521
+ for sp in range(num_spatial_dims):
522
+ slice_dim = slice(
523
+ pad_size[sp * 2],
524
+ image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2],
525
+ )
526
+ slice_dim = slice(
527
+ int(round(slice_dim.start * zoom_scale[num_spatial_dims - sp - 1])),
528
+ int(round(slice_dim.stop * zoom_scale[num_spatial_dims - sp - 1])),
529
+ )
530
+ final_slicing.insert(0, slice_dim)
531
+ while len(final_slicing) < len(output_i.shape):
532
+ final_slicing.insert(0, slice(None))
533
+ output_image_list[ss] = output_i[final_slicing]
534
+
535
+ if dict_key is not None: # if output of predictor is a dict
536
+ final_output = dict(zip(dict_key, output_image_list))
537
+ else:
538
+ final_output = tuple(output_image_list) # type: ignore
539
+
540
+ return final_output[0] if is_tensor_output else final_output # type: ignore
541
+
542
+
543
+ def _get_scan_interval(
544
+ image_size, roi_size, num_spatial_dims: int, overlap: float
545
+ ) -> Tuple[int, ...]:
546
+ scan_interval = []
547
+
548
+ for i in range(num_spatial_dims):
549
+ if roi_size[i] == image_size[i]:
550
+ scan_interval.append(int(roi_size[i]))
551
+ else:
552
+ interval = int(roi_size[i] * (1 - overlap))
553
+ scan_interval.append(interval if interval > 0 else 1)
554
+
555
+ return tuple(scan_interval)
556
+
557
+
558
+ #####################################################################################
559
+
560
+ ########################### Main Inference Functions #################################
561
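+ # Turn the 2-channel flow field (dP) and the cell-probability logit into an instance
+ # label map. Images with more than 5000x5000 pixels are processed on a 2000-px grid,
+ # and unusually small instances (below mean - 2.7*std of instance sizes) are dropped.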
+ def post_process(pred_mask, device):
562
+ dP, cellprob = pred_mask[:2], 1 / (1 + np.exp(-pred_mask[-1]))
563
+ H, W = pred_mask.shape[-2], pred_mask.shape[-1]
564
+
565
+ if np.prod(H * W) < (5000 * 5000):
566
+ pred_mask = compute_masks(
567
+ dP,
568
+ cellprob,
569
+ use_gpu=True,
570
+ flow_threshold=0.4,
571
+ device=device,
572
+ cellprob_threshold=0.4,
573
+ )[0]
574
+
575
+ else:
576
+ print("\n[Whole Slide] Grid Prediction starting...")
577
+ roi_size = 2000
578
+
579
+ # Get patch grid by roi_size
580
+ if H % roi_size != 0:
581
+ n_H = H // roi_size + 1
582
+ new_H = roi_size * n_H
583
+ else:
584
+ n_H = H // roi_size
585
+ new_H = H
586
+
587
+ if W % roi_size != 0:
588
+ n_W = W // roi_size + 1
589
+ new_W = roi_size * n_W
590
+ else:
591
+ n_W = W // roi_size
592
+ new_W = W
593
+
594
+ # Allocate values on the grid
595
+ pred_pad = np.zeros((new_H, new_W), dtype=np.uint32)
596
+ dP_pad = np.zeros((2, new_H, new_W), dtype=np.float32)
597
+ cellprob_pad = np.zeros((new_H, new_W), dtype=np.float32)
598
+
599
+ dP_pad[:, :H, :W], cellprob_pad[:H, :W] = dP, cellprob
600
+
601
+ for i in range(n_H):
602
+ for j in range(n_W):
603
+ print("Pred on Grid (%d, %d) processing..." % (i, j))
604
+ dP_roi = dP_pad[
605
+ :,
606
+ roi_size * i : roi_size * (i + 1),
607
+ roi_size * j : roi_size * (j + 1),
608
+ ]
609
+ cellprob_roi = cellprob_pad[
610
+ roi_size * i : roi_size * (i + 1),
611
+ roi_size * j : roi_size * (j + 1),
612
+ ]
613
+
614
+ pred_mask = compute_masks(
615
+ dP_roi,
616
+ cellprob_roi,
617
+ use_gpu=True,
618
+ flow_threshold=0.4,
619
+ device=device,
620
+ cellprob_threshold=0.4,
621
+ )[0]
622
+
623
+ pred_pad[
624
+ roi_size * i : roi_size * (i + 1),
625
+ roi_size * j : roi_size * (j + 1),
626
+ ] = pred_mask
627
+
628
+ pred_mask = pred_pad[:H, :W]
629
+
630
+ cell_idx, cell_sizes = np.unique(pred_mask, return_counts=True)
631
+ cell_idx, cell_sizes = cell_idx[1:], cell_sizes[1:]
632
+ cell_drop = np.where(cell_sizes < np.mean(cell_sizes) - 2.7 * np.std(cell_sizes))
633
+
634
+ for drop_cell in cell_idx[cell_drop]:
635
+ pred_mask[pred_mask == drop_cell] = 0
636
+
637
+ return pred_mask
638
+
639
+
640
+ def hflip(x):
641
+ """flip batch of images horizontally"""
642
+ return x.flip(3)
643
+
644
+
645
+ def vflip(x):
646
+ """flip batch of images vertically"""
647
+ return x.flip(2)
648
+
649
+
650
+ class DualTransform:
651
+ identity_param = None
652
+
653
+ def __init__(
654
+ self, name: str, params,
655
+ ):
656
+ self.params = params
657
+ self.pname = name
658
+
659
+ def apply_aug_image(self, image, *args, **params):
660
+ raise NotImplementedError
661
+
662
+ def apply_deaug_mask(self, mask, *args, **params):
663
+ raise NotImplementedError
664
+
665
+
666
+ class HorizontalFlip(DualTransform):
667
+ """Flip images horizontally (left->right)"""
668
+
669
+ identity_param = False
670
+
671
+ def __init__(self):
672
+ super().__init__("apply", [False, True])
673
+
674
+ def apply_aug_image(self, image, apply=False, **kwargs):
675
+ if apply:
676
+ image = hflip(image)
677
+ return image
678
+
679
+ def apply_deaug_mask(self, mask, apply=False, **kwargs):
680
+ if apply:
681
+ mask = hflip(mask)
682
+ return mask
683
+
684
+
685
+ class VerticalFlip(DualTransform):
686
+ """Flip images vertically (up->down)"""
687
+
688
+ identity_param = False
689
+
690
+ def __init__(self):
691
+ super().__init__("apply", [False, True])
692
+
693
+ def apply_aug_image(self, image, apply=False, **kwargs):
694
+ if apply:
695
+ image = vflip(image)
696
+ return image
697
+
698
+ def apply_deaug_mask(self, mask, apply=False, **kwargs):
699
+ if apply:
700
+ mask = vflip(mask)
701
+ return mask
702
+
703
+
704
+ #################### GradFlow Modules ##################################################
705
+ from scipy.ndimage.filters import maximum_filter1d
706
+ import scipy.ndimage
707
+ import fastremap
708
+ from skimage import morphology
709
+
710
+ from scipy.ndimage import mean
711
+
712
+ torch_GPU = torch.device("cuda")
713
+ torch_CPU = torch.device("cpu")
714
+
715
+
716
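+ # GPU diffusion from mask centres: heat injected at each centre is repeatedly averaged
+ # over pixel neighbourhoods; the spatial gradient of log(1 + T) gives the flow field.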
+ def _extend_centers_gpu(
717
+ neighbors, centers, isneighbor, Ly, Lx, n_iter=200, device=torch.device("cuda")
718
+ ):
719
+ if device is not None:
720
+ device = device
721
+ nimg = neighbors.shape[0] // 9
722
+ pt = torch.from_numpy(neighbors).to(device)
723
+
724
+ T = torch.zeros((nimg, Ly, Lx), dtype=torch.double, device=device)
725
+ meds = torch.from_numpy(centers.astype(int)).to(device).long()
726
+ isneigh = torch.from_numpy(isneighbor).to(device)
727
+ for i in range(n_iter):
728
+ T[:, meds[:, 0], meds[:, 1]] += 1
729
+ Tneigh = T[:, pt[:, :, 0], pt[:, :, 1]]
730
+ Tneigh *= isneigh
731
+ T[:, pt[0, :, 0], pt[0, :, 1]] = Tneigh.mean(axis=1)
732
+ del meds, isneigh, Tneigh
733
+ T = torch.log(1.0 + T)
734
+ # gradient positions
735
+ grads = T[:, pt[[2, 1, 4, 3], :, 0], pt[[2, 1, 4, 3], :, 1]]
736
+ del pt
737
+ dy = grads[:, 0] - grads[:, 1]
738
+ dx = grads[:, 2] - grads[:, 3]
739
+ del grads
740
+ mu_torch = np.stack((dy.cpu().squeeze(), dx.cpu().squeeze()), axis=-2)
741
+ return mu_torch
742
+
743
+
744
+ def diameters(masks):
745
+ _, counts = np.unique(np.int32(masks), return_counts=True)
746
+ counts = counts[1:]
747
+ md = np.median(counts ** 0.5)
748
+ if np.isnan(md):
749
+ md = 0
750
+ md /= (np.pi ** 0.5) / 2
751
+ return md, counts ** 0.5
752
+
753
+
754
+ def masks_to_flows_gpu(masks, device=None):
755
+ if device is None:
756
+ device = torch.device("cuda")
757
+
758
+ Ly0, Lx0 = masks.shape
759
+ Ly, Lx = Ly0 + 2, Lx0 + 2
760
+
761
+ masks_padded = np.zeros((Ly, Lx), np.int64)
762
+ masks_padded[1:-1, 1:-1] = masks
763
+
764
+ # get mask pixel neighbors
765
+ y, x = np.nonzero(masks_padded)
766
+ neighborsY = np.stack((y, y - 1, y + 1, y, y, y - 1, y - 1, y + 1, y + 1), axis=0)
767
+ neighborsX = np.stack((x, x, x, x - 1, x + 1, x - 1, x + 1, x - 1, x + 1), axis=0)
768
+ neighbors = np.stack((neighborsY, neighborsX), axis=-1)
769
+
770
+ # get mask centers
771
+ slices = scipy.ndimage.find_objects(masks)
772
+
773
+ centers = np.zeros((masks.max(), 2), "int")
774
+ for i, si in enumerate(slices):
775
+ if si is not None:
776
+ sr, sc = si
777
+
778
+ ly, lx = sr.stop - sr.start + 1, sc.stop - sc.start + 1
779
+ yi, xi = np.nonzero(masks[sr, sc] == (i + 1))
780
+ yi = yi.astype(np.int32) + 1 # add padding
781
+ xi = xi.astype(np.int32) + 1 # add padding
782
+ ymed = np.median(yi)
783
+ xmed = np.median(xi)
784
+ imin = np.argmin((xi - xmed) ** 2 + (yi - ymed) ** 2)
785
+ xmed = xi[imin]
786
+ ymed = yi[imin]
787
+ centers[i, 0] = ymed + sr.start
788
+ centers[i, 1] = xmed + sc.start
789
+
790
+ # get neighbor validator (not all neighbors are in same mask)
791
+ neighbor_masks = masks_padded[neighbors[:, :, 0], neighbors[:, :, 1]]
792
+ isneighbor = neighbor_masks == neighbor_masks[0]
793
+ ext = np.array(
794
+ [[sr.stop - sr.start + 1, sc.stop - sc.start + 1] for sr, sc in slices]
795
+ )
796
+ n_iter = 2 * (ext.sum(axis=1)).max()
797
+ # run diffusion
798
+ mu = _extend_centers_gpu(
799
+ neighbors, centers, isneighbor, Ly, Lx, n_iter=n_iter, device=device
800
+ )
801
+
802
+ # normalize
803
+ mu /= 1e-20 + (mu ** 2).sum(axis=0) ** 0.5
804
+
805
+ # put into original image
806
+ mu0 = np.zeros((2, Ly0, Lx0))
807
+ mu0[:, y - 1, x - 1] = mu
808
+ mu_c = np.zeros_like(mu0)
809
+ return mu0, mu_c
810
+
811
+
812
+ def masks_to_flows(masks, use_gpu=False, device=None):
813
+ if masks.max() == 0 or (masks != 0).sum() == 1:
814
+ # dynamics_logger.warning('empty masks!')
815
+ return np.zeros((2, *masks.shape), "float32")
816
+
817
+ if use_gpu:
818
+ if use_gpu and device is None:
819
+ device = torch_GPU
820
+ elif device is None:
821
+ device = torch_CPU
822
+ masks_to_flows_device = masks_to_flows_gpu
823
+
824
+ if masks.ndim == 3:
825
+ Lz, Ly, Lx = masks.shape
826
+ mu = np.zeros((3, Lz, Ly, Lx), np.float32)
827
+ for z in range(Lz):
828
+ mu0 = masks_to_flows_device(masks[z], device=device)[0]
829
+ mu[[1, 2], z] += mu0
830
+ for y in range(Ly):
831
+ mu0 = masks_to_flows_device(masks[:, y], device=device)[0]
832
+ mu[[0, 2], :, y] += mu0
833
+ for x in range(Lx):
834
+ mu0 = masks_to_flows_device(masks[:, :, x], device=device)[0]
835
+ mu[[0, 1], :, :, x] += mu0
836
+ return mu
837
+ elif masks.ndim == 2:
838
+ mu, mu_c = masks_to_flows_device(masks, device=device)
839
+ return mu
840
+
841
+ else:
842
+ raise ValueError("masks_to_flows only takes 2D or 3D arrays")
843
+
844
+
845
+ def steps2D_interp(p, dP, niter, use_gpu=False, device=None):
846
+ shape = dP.shape[1:]
847
+ if use_gpu:
848
+ if device is None:
849
+ device = torch_GPU
850
+ shape = (
851
+ np.array(shape)[[1, 0]].astype("float") - 1
852
+ ) # Y and X dimensions (dP is 2.Ly.Lx), flipped X-1, Y-1
853
+ pt = (
854
+ torch.from_numpy(p[[1, 0]].T).float().to(device).unsqueeze(0).unsqueeze(0)
855
+ ) # p is n_points by 2, so pt is [1 1 2 n_points]
856
+ im = (
857
+ torch.from_numpy(dP[[1, 0]]).float().to(device).unsqueeze(0)
858
+ ) # convert flow numpy array to tensor on GPU, add dimension
859
+ # normalize pt between 0 and 1, normalize the flow
860
+ for k in range(2):
861
+ im[:, k, :, :] *= 2.0 / shape[k]
862
+ pt[:, :, :, k] /= shape[k]
863
+
864
+ # normalize to between -1 and 1
865
+ pt = pt * 2 - 1
866
+
867
+ # here is where the stepping happens
868
+ for t in range(niter):
869
+ # align_corners default is False, just added to suppress warning
870
+ dPt = grid_sample(im, pt, align_corners=False)
871
+
872
+ for k in range(2): # clamp the final pixel locations
873
+ pt[:, :, :, k] = torch.clamp(
874
+ pt[:, :, :, k] + dPt[:, k, :, :], -1.0, 1.0
875
+ )
876
+
877
+ # undo the normalization from before, reverse order of operations
878
+ pt = (pt + 1) * 0.5
879
+ for k in range(2):
880
+ pt[:, :, :, k] *= shape[k]
881
+
882
+ p = pt[:, :, :, [1, 0]].cpu().numpy().squeeze().T
883
+ return p
884
+
885
+ else:
886
+ raise NotImplementedError("steps2D_interp: CPU fallback is not implemented")
887
+
888
+
889
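+ # Integrate pixel positions along the flow field for `niter` steps; only pixels with
+ # non-negligible flow are moved.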
+ def follow_flows(dP, mask=None, niter=200, interp=True, use_gpu=True, device=None):
890
+ shape = np.array(dP.shape[1:]).astype(np.int32)
891
+ niter = np.uint32(niter)
892
+
893
+ p = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing="ij")
894
+ p = np.array(p).astype(np.float32)
895
+
896
+ inds = np.array(np.nonzero(np.abs(dP[0]) > 1e-3)).astype(np.int32).T
897
+
898
+ if inds.ndim < 2 or inds.shape[0] < 5:
899
+ return p, None
900
+
901
+ if not interp:
902
+ raise NotImplementedError("follow_flows: non-interpolated stepping is not implemented")
903
+
904
+ else:
905
+ p_interp = steps2D_interp(
906
+ p[:, inds[:, 0], inds[:, 1]], dP, niter, use_gpu=use_gpu, device=device
907
+ )
908
+ p[:, inds[:, 0], inds[:, 1]] = p_interp
909
+
910
+ return p, inds
911
+
912
+
913
+ def flow_error(maski, dP_net, use_gpu=False, device=None):
914
+ if dP_net.shape[1:] != maski.shape:
915
+ print("ERROR: net flow is not same size as predicted masks")
916
+ return
917
+
918
+ # flows predicted from estimated masks
919
+ dP_masks = masks_to_flows(maski, use_gpu=use_gpu, device=device)
920
+ # difference between predicted flows vs mask flows
921
+ flow_errors = np.zeros(maski.max())
922
+ for i in range(dP_masks.shape[0]):
923
+ flow_errors += mean(
924
+ (dP_masks[i] - dP_net[i] / 5.0) ** 2,
925
+ maski,
926
+ index=np.arange(1, maski.max() + 1),
927
+ )
928
+
929
+ return flow_errors, dP_masks
930
+
931
+
932
+ def remove_bad_flow_masks(masks, flows, threshold=0.4, use_gpu=False, device=None):
933
+ merrors, _ = flow_error(masks, flows, use_gpu, device)
934
+ badi = 1 + (merrors > threshold).nonzero()[0]
935
+ masks[np.isin(masks, badi)] = 0
936
+ return masks
937
+
938
+
939
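+ # Cluster the converged pixel positions into instance labels: histogram the final
+ # positions, grow local-maximum seeds over a few iterations, and drop any label that
+ # covers more than 90% of the image.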
+ def get_masks(p, iscell=None, rpad=20):
940
+ pflows = []
941
+ edges = []
942
+ shape0 = p.shape[1:]
943
+ dims = len(p)
944
+
945
+ for i in range(dims):
946
+ pflows.append(p[i].flatten().astype("int32"))
947
+ edges.append(np.arange(-0.5 - rpad, shape0[i] + 0.5 + rpad, 1))
948
+
949
+ h, _ = np.histogramdd(tuple(pflows), bins=edges)
950
+ hmax = h.copy()
951
+ for i in range(dims):
952
+ hmax = maximum_filter1d(hmax, 5, axis=i)
953
+
954
+ seeds = np.nonzero(np.logical_and(h - hmax > -1e-6, h > 10))
955
+ Nmax = h[seeds]
956
+ isort = np.argsort(Nmax)[::-1]
957
+ for s in seeds:
958
+ s = s[isort]
959
+
960
+ pix = list(np.array(seeds).T)
961
+
962
+ shape = h.shape
963
+ if dims == 3:
964
+ expand = np.nonzero(np.ones((3, 3, 3)))
965
+ else:
966
+ expand = np.nonzero(np.ones((3, 3)))
967
+ for e in expand:
968
+ e = np.expand_dims(e, 1)
969
+
970
+ for iter in range(5):
971
+ for k in range(len(pix)):
972
+ if iter == 0:
973
+ pix[k] = list(pix[k])
974
+ newpix = []
975
+ iin = []
976
+ for i, e in enumerate(expand):
977
+ epix = e[:, np.newaxis] + np.expand_dims(pix[k][i], 0) - 1
978
+ epix = epix.flatten()
979
+ iin.append(np.logical_and(epix >= 0, epix < shape[i]))
980
+ newpix.append(epix)
981
+ iin = np.all(tuple(iin), axis=0)
982
+ for p in newpix:
983
+ p = p[iin]
984
+ newpix = tuple(newpix)
985
+ igood = h[newpix] > 2
986
+ for i in range(dims):
987
+ pix[k][i] = newpix[i][igood]
988
+ if iter == 4:
989
+ pix[k] = tuple(pix[k])
990
+
991
+ M = np.zeros(h.shape, np.uint32)
992
+ for k in range(len(pix)):
993
+ M[pix[k]] = 1 + k
994
+
995
+ for i in range(dims):
996
+ pflows[i] = pflows[i] + rpad
997
+ M0 = M[tuple(pflows)]
998
+
999
+ # remove big masks
1000
+ uniq, counts = fastremap.unique(M0, return_counts=True)
1001
+ big = np.prod(shape0) * 0.9
1002
+ bigc = uniq[counts > big]
1003
+ if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
1004
+ M0 = fastremap.mask(M0, bigc)
1005
+ fastremap.renumber(M0, in_place=True) # convenient to guarantee non-skipped labels
1006
+ M0 = np.reshape(M0, shape0)
1007
+ return M0
1008
+
1009
+ def fill_holes_and_remove_small_masks(masks, min_size=15):
1010
+ """ fill holes in masks (2D/3D) and discard masks smaller than min_size (2D)
1011
+
1012
+ fill holes in each mask using scipy.ndimage.morphology.binary_fill_holes
1013
+ (might have issues at borders between cells, todo: check and fix)
1014
+
1015
+ Parameters
1016
+ ----------------
1017
+ masks: int, 2D or 3D array
1018
+ labelled masks, 0=NO masks; 1,2,...=mask labels,
1019
+ size [Ly x Lx] or [Lz x Ly x Lx]
1020
+ min_size: int (optional, default 15)
1021
+ minimum number of pixels per mask, can turn off with -1
1022
+ Returns
1023
+ ---------------
1024
+ masks: int, 2D or 3D array
1025
+ masks with holes filled and masks smaller than min_size removed,
1026
+ 0=NO masks; 1,2,...=mask labels,
1027
+ size [Ly x Lx] or [Lz x Ly x Lx]
1028
+
1029
+ """
1030
+
1031
+ slices = find_objects(masks)
1032
+ j = 0
1033
+ for i,slc in enumerate(slices):
1034
+ if slc is not None:
1035
+ msk = masks[slc] == (i+1)
1036
+ npix = msk.sum()
1037
+ if min_size > 0 and npix < min_size:
1038
+ masks[slc][msk] = 0
1039
+ elif npix > 0:
1040
+ if msk.ndim==3:
1041
+ for k in range(msk.shape[0]):
1042
+ msk[k] = binary_fill_holes(msk[k])
1043
+ else:
1044
+ msk = binary_fill_holes(msk)
1045
+ masks[slc][msk] = (j+1)
1046
+ j+=1
1047
+ return masks
1048
+
1049
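+ # Full mask reconstruction: threshold the cell probability, follow the flows to group
+ # pixels into instances, remove masks whose flows disagree with the prediction, then
+ # fill holes and discard tiny masks.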
+ def compute_masks(
1050
+ dP,
1051
+ cellprob,
1052
+ p=None,
1053
+ niter=200,
1054
+ cellprob_threshold=0.4,
1055
+ flow_threshold=0.4,
1056
+ interp=True,
1057
+ resize=None,
1058
+ use_gpu=False,
1059
+ device=None,
1060
+ ):
1061
+ """compute masks using dynamics from dP, cellprob, and boundary"""
1062
+
1063
+ cp_mask = cellprob > cellprob_threshold
1064
+ cp_mask = morphology.remove_small_holes(cp_mask, area_threshold=16)
1065
+ cp_mask = morphology.remove_small_objects(cp_mask, min_size=16)
1066
+
1067
+ if np.any(cp_mask): # mask at this point is a cell cluster binary map, not labels
1068
+ # follow flows
1069
+ if p is None:
1070
+ p, inds = follow_flows(
1071
+ dP * cp_mask / 5.0,
1072
+ niter=niter,
1073
+ interp=interp,
1074
+ use_gpu=use_gpu,
1075
+ device=device,
1076
+ )
1077
+ if inds is None:
1078
+ shape = resize if resize is not None else cellprob.shape
1079
+ mask = np.zeros(shape, np.uint16)
1080
+ p = np.zeros((len(shape), *shape), np.uint16)
1081
+ return mask, p
1082
+
1083
+ # calculate masks
1084
+ mask = get_masks(p, iscell=cp_mask)
1085
+
1086
+ # flow thresholding factored out of get_masks
1087
+ shape0 = p.shape[1:]
1088
+ if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0:
1089
+ # make sure labels are unique at output of get_masks
1090
+ mask = remove_bad_flow_masks(
1091
+ mask, dP, threshold=flow_threshold, use_gpu=use_gpu, device=device
1092
+ )
1093
+
1094
+ mask = fill_holes_and_remove_small_masks(mask, min_size=15)
1095
+
1096
+ else: # nothing to compute, just make it compatible
1097
+ shape = resize if resize is not None else cellprob.shape
1098
+ mask = np.zeros(shape, np.uint16)
1099
+ p = np.zeros((len(shape), *shape), np.uint16)
1100
+ return mask, p
1101
+
1102
+ return mask, p
1103
+
1104
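+ # For each input image: normalize, run sliding-window inference (adding flip TTA and a
+ # second-checkpoint ensemble for small/medium images), post-process into an instance
+ # mask, and write it as <name>_label.tiff.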
+ def main(args):
1105
+ model = torch.load(args.model_path, map_location=args.device)
1106
+ model.eval()
1107
+ hflip_tta = HorizontalFlip()
1108
+ vflip_tta = VerticalFlip()
1109
+
1110
+ img_names = sorted(os.listdir(args.input_path))
1111
+ os.makedirs(args.output_path, exist_ok=True)
1112
+
1113
+ for img_name in img_names:
1114
+ print(f"Segmenting {img_name}")
1115
+ img_path = os.path.join(args.input_path, img_name)
1116
+ img_data = pred_transforms(img_path)
1117
+ img_data = img_data.to(args.device)
1118
+ img_size = img_data.shape[-1] * img_data.shape[-2]
1119
+
1120
+ if img_size < 1150000 and 900000 < img_size:
1121
+ overlap = 0.5
1122
+ else:
1123
+ overlap = 0.6
1124
+
1125
+ with torch.no_grad():
1126
+ img0 = img_data
1127
+ outputs0 = sliding_window_inference(
1128
+ img0,
1129
+ 512,
1130
+ 4,
1131
+ model,
1132
+ padding_mode="reflect",
1133
+ mode="gaussian",
1134
+ overlap=overlap,
1135
+ device="cpu",
1136
+ )
1137
+ outputs0 = outputs0.cpu().squeeze()
1138
+
1139
+ if img_size < 2000 * 2000:
1140
+
1141
+ model.load_state_dict(torch.load(args.model_path2, map_location=args.device))
1142
+ model.eval()
1143
+
1144
+ img2 = hflip_tta.apply_aug_image(img_data, apply=True)
1145
+ outputs2 = sliding_window_inference(
1146
+ img2,
1147
+ 512,
1148
+ 4,
1149
+ model,
1150
+ padding_mode="reflect",
1151
+ mode="gauusian",
1152
+ overlap=overlap,
1153
+ device="cpu",
1154
+ )
1155
+ outputs2 = hflip_tta.apply_deaug_mask(outputs2, apply=True)
1156
+ outputs2 = outputs2.cpu().squeeze()
1157
+
1158
+ outputs = torch.zeros_like(outputs0)
1159
+ outputs[0] = (outputs0[0] + outputs2[0]) / 2
1160
+ outputs[1] = (outputs0[1] - outputs2[1]) / 2
1161
+ outputs[2] = (outputs0[2] + outputs2[2]) / 2
1162
+
1163
+ elif img_size < 5000*5000:
1164
+ # Hflip TTA
1165
+ img2 = hflip_tta.apply_aug_image(img_data, apply=True)
1166
+ outputs2 = sliding_window_inference(
1167
+ img2,
1168
+ 512,
1169
+ 4,
1170
+ model,
1171
+ padding_mode="reflect",
1172
+ mode="gaussian",
1173
+ overlap=overlap,
1174
+ device="cpu",
1175
+ )
1176
+ outputs2 = hflip_tta.apply_deaug_mask(outputs2, apply=True)
1177
+ outputs2 = outputs2.cpu().squeeze()
1178
+ img2 = img2.cpu()
1179
+
1180
+ ##################
1181
+ # #
1182
+ # ensemble #
1183
+ # #
1184
+ ##################
1185
+
1186
+ model.load_state_dict(torch.load(args.model_path2, map_location=args.device))
1187
+ model.eval()
1188
+
1189
+ img1 = img_data
1190
+ outputs1 = sliding_window_inference(
1191
+ img1,
1192
+ 512,
1193
+ 4,
1194
+ model,
1195
+ padding_mode="reflect",
1196
+ mode="gaussian",
1197
+ overlap=overlap,
1198
+ device="cpu",
1199
+ )
1200
+ outputs1 = outputs1.cpu().squeeze()
1201
+
1202
+ # Vflip TTA
1203
+ img3 = vflip_tta.apply_aug_image(img_data, apply=True)
1204
+ outputs3 = sliding_window_inference(
1205
+ img3,
1206
+ 512,
1207
+ 4,
1208
+ model,
1209
+ padding_mode="reflect",
1210
+ mode="gaussian",
1211
+ overlap=overlap,
1212
+ device="cpu",
1213
+ )
1214
+ outputs3 = vflip_tta.apply_deaug_mask(outputs3, apply=True)
1215
+ outputs3 = outputs3.cpu().squeeze()
1216
+ img3 = img3.cpu()
1217
+
1218
+ # Merge Results
1219
+ outputs = torch.zeros_like(outputs0)
1220
+ outputs[0] = (outputs0[0] + outputs1[0] + outputs2[0] - outputs3[0]) / 4
1221
+ outputs[1] = (outputs0[1] + outputs1[1] - outputs2[1] + outputs3[1]) / 4
1222
+ outputs[2] = (outputs0[2] + outputs1[2] + outputs2[2] + outputs3[2]) / 4
1223
+ else:
1224
+ outputs = outputs0
1225
+
1226
+ pred_mask = post_process(outputs.squeeze(0).cpu().numpy(), args.device)
1227
+
1228
+ file_path = os.path.join(
1229
+ args.output_path, img_name.split(".")[0] + "_label.tiff"
1230
+ )
1231
+
1232
+ tif.imwrite(file_path, pred_mask, compression="zlib")
1233
+
1234
+
1235
+ parser = argparse.ArgumentParser("Submission for Challenge", add_help=False)
1236
+ parser.add_argument("--model_path", default="./model.pt", type=str)
1237
+ parser.add_argument("--model_path2", default="./model_sec.pth", type=str)
1238
+
1239
+ # Dataset parameters
1240
+ parser.add_argument(
1241
+ "-i",
1242
+ "--input_path",
1243
+ default="/workspace/inputs/",
1244
+ type=str,
1245
+ help="training data path; subfolders: images, labels",
1246
+ )
1247
+ parser.add_argument(
1248
+ "-o", "--output_path", default="/workspace/outputs/", type=str, help="output path",
1249
+ )
1250
+ parser.add_argument("--device", default="cuda:0", type=str)
1251
+
1252
+ args = parser.parse_args()
1253
+
1254
+ if __name__ == "__main__":
1255
+ print("Starting")
1256
+ main(args)
predict.sh ADDED
@@ -0,0 +1 @@
+ python predict.py -i "./inputs" -o "./outputs" --device "cuda:0" --model_path="./main_model.pt" --model_path2="./sub_model.pth"
requirements.txt ADDED
@@ -0,0 +1,83 @@
1
+ backcall @ file:///home/ktietz/src/ci/backcall_1611930011877/work
2
+ beautifulsoup4 @ file:///opt/conda/conda-bld/beautifulsoup4_1650462163268/work
3
+ brotlipy==0.7.0
4
+ certifi @ file:///opt/conda/conda-bld/certifi_1655968806487/work/certifi
5
+ cffi @ file:///opt/conda/conda-bld/cffi_1642701102775/work
6
+ chardet @ file:///tmp/build/80754af9/chardet_1607706768982/work
7
+ charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
8
+ colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work
9
+ coloredlogs==15.0.1
10
+ conda==4.13.0
11
+ conda-build==3.21.9
12
+ conda-content-trust @ file:///tmp/build/80754af9/conda-content-trust_1617045594566/work
13
+ conda-package-handling @ file:///tmp/build/80754af9/conda-package-handling_1649105789509/work
14
+ cryptography @ file:///tmp/build/80754af9/cryptography_1652083456434/work
15
+ decorator @ file:///opt/conda/conda-bld/decorator_1643638310831/work
16
+ fastremap==1.13.3
17
+ filelock @ file:///opt/conda/conda-bld/filelock_1647002191454/work
18
+ flatbuffers==22.9.24
19
+ glob2 @ file:///home/linux1/recipes/ci/glob2_1610991677669/work
20
+ huggingface-hub==0.10.1
21
+ humanfriendly==10.0
22
+ idna @ file:///tmp/build/80754af9/idna_1637925883363/work
23
+ imagecodecs==2021.11.20
24
+ imageio==2.22.2
25
+ importlib-metadata==5.0.0
26
+ itk==5.2.1.post1
27
+ itk-core==5.2.1.post1
28
+ itk-filtering==5.2.1.post1
29
+ itk-io==5.2.1.post1
30
+ itk-numerics==5.2.1.post1
31
+ itk-registration==5.2.1.post1
32
+ itk-segmentation==5.2.1.post1
33
+ jedi @ file:///tmp/build/80754af9/jedi_1644299024593/work
34
+ Jinja2==2.10.1
35
+ libarchive-c @ file:///tmp/build/80754af9/python-libarchive-c_1617780486945/work
36
+ MarkupSafe @ file:///tmp/build/80754af9/markupsafe_1621528142364/work
37
+ matplotlib-inline @ file:///tmp/build/80754af9/matplotlib-inline_1628242447089/work
38
+ mkl-fft==1.3.1
39
+ mkl-random @ file:///tmp/build/80754af9/mkl_random_1626179032232/work
40
+ mkl-service==2.4.0
41
+ monai==0.9.0
42
+ mpmath==1.2.1
43
+ networkx==2.6.3
44
+ numpy @ file:///opt/conda/conda-bld/numpy_and_numpy_base_1651563629415/work
45
+ onnxruntime-gpu==1.12.1
46
+ opencv-python==4.6.0.66
47
+ packaging==21.3
48
+ parso @ file:///opt/conda/conda-bld/parso_1641458642106/work
49
+ pexpect @ file:///tmp/build/80754af9/pexpect_1605563209008/work
50
+ pickleshare @ file:///tmp/build/80754af9/pickleshare_1606932040724/work
51
+ Pillow==9.0.1
52
+ pkginfo @ file:///tmp/build/80754af9/pkginfo_1643162084911/work
53
+ prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1633440160888/work
54
+ protobuf==4.21.8
55
+ psutil @ file:///tmp/build/80754af9/psutil_1612298016854/work
56
+ ptyprocess @ file:///tmp/build/80754af9/ptyprocess_1609355006118/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
57
+ pycosat==0.6.3
58
+ pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
59
+ Pygments @ file:///opt/conda/conda-bld/pygments_1644249106324/work
60
+ pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work
61
+ pyparsing==3.0.9
62
+ PySocks @ file:///tmp/build/80754af9/pysocks_1594394576006/work
63
+ pytz==2022.2.1
64
+ PyWavelets==1.3.0
65
+ PyYAML==6.0
66
+ requests @ file:///opt/conda/conda-bld/requests_1641824580448/work
67
+ ruamel-yaml-conda @ file:///tmp/build/80754af9/ruamel_yaml_1616016701961/work
68
+ scikit-image==0.19.3
69
+ scipy==1.7.2
70
+ six @ file:///tmp/build/80754af9/six_1644875935023/work
71
+ soupsieve @ file:///tmp/build/80754af9/soupsieve_1636706018808/work
72
+ sympy==1.10.1
73
+ tifffile==2021.11.2
74
+ timm==0.6.11
75
+ torch==1.12.1
76
+ torchtext==0.13.1
77
+ torchvision==0.13.1
78
+ tqdm==4.64.1
79
+ traitlets @ file:///tmp/build/80754af9/traitlets_1636710298902/work
80
+ typing_extensions @ file:///tmp/abs_ben9emwtky/croots/recipe/typing_extensions_1659638822008/work
81
+ urllib3 @ file:///opt/conda/conda-bld/urllib3_1643638302206/work
82
+ wcwidth @ file:///Users/ktietz/demo/mc3/conda-bld/wcwidth_1629357192024/work
83
+ zipp==3.9.0
save_model.py ADDED
@@ -0,0 +1,99 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from segmentation_models_pytorch import MAnet
5
+ from segmentation_models_pytorch.base.modules import Activation
6
+
7
+
8
+ class SegformerGH(MAnet):
9
+ def __init__(
10
+ self,
11
+ encoder_name: str = "mit_b5",
12
+ encoder_weights="imagenet",
13
+ decoder_channels=(256, 128, 64, 32, 32),
14
+ decoder_pab_channels=256,
15
+ in_channels: int = 3,
16
+ classes: int = 3,
17
+ ):
18
+ super(SegformerGH, self).__init__(
19
+ encoder_name=encoder_name,
20
+ encoder_weights=encoder_weights,
21
+ decoder_channels=decoder_channels,
22
+ decoder_pab_channels=decoder_pab_channels,
23
+ in_channels=in_channels,
24
+ classes=classes,
25
+ )
26
+
27
+ convert_relu_to_mish(self.encoder)
28
+ convert_relu_to_mish(self.decoder)
29
+
30
+ self.cellprob_head = DeepSegmantationHead(
31
+ in_channels=decoder_channels[-1], out_channels=1, kernel_size=3,
32
+ )
33
+ self.gradflow_head = DeepSegmantationHead(
34
+ in_channels=decoder_channels[-1], out_channels=2, kernel_size=3,
35
+ )
36
+
37
+ def forward(self, x):
38
+ """Sequentially pass `x` trough model`s encoder, decoder and heads"""
39
+ self.check_input_shape(x)
40
+
41
+ features = self.encoder(x)
42
+ decoder_output = self.decoder(*features)
43
+
44
+ gradflow_mask = self.gradflow_head(decoder_output)
45
+ cellprob_mask = self.cellprob_head(decoder_output)
46
+
47
+ masks = torch.cat([gradflow_mask, cellprob_mask], dim=1)
48
+
49
+ return masks
50
+
51
+
52
+ class DeepSegmantationHead(nn.Sequential):
53
+ def __init__(
54
+ self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
55
+ ):
56
+ conv2d_1 = nn.Conv2d(
57
+ in_channels,
58
+ in_channels // 2,
59
+ kernel_size=kernel_size,
60
+ padding=kernel_size // 2,
61
+ )
62
+ bn = nn.BatchNorm2d(in_channels // 2)
63
+ conv2d_2 = nn.Conv2d(
64
+ in_channels // 2,
65
+ out_channels,
66
+ kernel_size=kernel_size,
67
+ padding=kernel_size // 2,
68
+ )
69
+ mish = nn.Mish(inplace=True)
70
+
71
+ upsampling = (
72
+ nn.UpsamplingBilinear2d(scale_factor=upsampling)
73
+ if upsampling > 1
74
+ else nn.Identity()
75
+ )
76
+ activation = Activation(activation)
77
+ super().__init__(conv2d_1, mish, bn, conv2d_2, upsampling, activation)
78
+
79
+
80
+ def convert_relu_to_mish(model):
81
+ for child_name, child in model.named_children():
82
+ if isinstance(child, nn.ReLU):
83
+ setattr(model, child_name, nn.Mish(inplace=True))
84
+ else:
85
+ convert_relu_to_mish(child)
86
+
87
+
88
+ if __name__ == "__main__":
89
+ model = SegformerGH(
90
+ encoder_name="mit_b5",
91
+ encoder_weights=None,
92
+ decoder_channels=(1024, 512, 256, 128, 64),
93
+ decoder_pab_channels=256,
94
+ in_channels=3,
95
+ classes=3,
96
+ )
97
+
98
+ model.load_state_dict(torch.load("./main_model.pth", map_location="cpu"))
99
+ torch.save(model, "main_model.pt")
segmentation_models_pytorch/__init__.py ADDED
@@ -0,0 +1,61 @@
1
+ from . import datasets
2
+ from . import encoders
3
+ from . import decoders
4
+ from . import losses
5
+ from . import metrics
6
+
7
+ from .decoders.unet import Unet
8
+ from .decoders.unetplusplus import UnetPlusPlus
9
+ from .decoders.manet import MAnet
10
+ from .decoders.linknet import Linknet
11
+ from .decoders.fpn import FPN
12
+ from .decoders.pspnet import PSPNet
13
+ from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
14
+ from .decoders.pan import PAN
15
+
16
+ from .__version__ import __version__
17
+
18
+ # some private imports for create_model function
19
+ from typing import Optional as _Optional
20
+ import torch as _torch
21
+
22
+
23
+ def create_model(
24
+ arch: str,
25
+ encoder_name: str = "resnet34",
26
+ encoder_weights: _Optional[str] = "imagenet",
27
+ in_channels: int = 3,
28
+ classes: int = 1,
29
+ **kwargs,
30
+ ) -> _torch.nn.Module:
31
+ """Models entrypoint, allows to create any model architecture just with
32
+ parameters, without using its class
33
+ """
34
+
35
+ archs = [
36
+ Unet,
37
+ UnetPlusPlus,
38
+ MAnet,
39
+ Linknet,
40
+ FPN,
41
+ PSPNet,
42
+ DeepLabV3,
43
+ DeepLabV3Plus,
44
+ PAN,
45
+ ]
46
+ archs_dict = {a.__name__.lower(): a for a in archs}
47
+ try:
48
+ model_class = archs_dict[arch.lower()]
49
+ except KeyError:
50
+ raise KeyError(
51
+ "Wrong architecture type `{}`. Available options are: {}".format(
52
+ arch, list(archs_dict.keys()),
53
+ )
54
+ )
55
+ return model_class(
56
+ encoder_name=encoder_name,
57
+ encoder_weights=encoder_weights,
58
+ in_channels=in_channels,
59
+ classes=classes,
60
+ **kwargs,
61
+ )
segmentation_models_pytorch/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.72 kB)
segmentation_models_pytorch/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.76 kB)
segmentation_models_pytorch/__pycache__/__version__.cpython-37.pyc ADDED
Binary file (217 Bytes)
segmentation_models_pytorch/__pycache__/__version__.cpython-39.pyc ADDED
Binary file (230 Bytes)
segmentation_models_pytorch/__version__.py ADDED
@@ -0,0 +1,3 @@
+ VERSION = (0, 3, 0)
+
+ __version__ = ".".join(map(str, VERSION))
segmentation_models_pytorch/base/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .model import SegmentationModel
+
+ from .modules import (
+     Conv2dReLU,
+     Attention,
+ )
+
+ from .heads import (
+     SegmentationHead,
+     ClassificationHead,
+ )
segmentation_models_pytorch/base/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (326 Bytes). View file
 
segmentation_models_pytorch/base/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (339 Bytes). View file
 
segmentation_models_pytorch/base/__pycache__/heads.cpython-37.pyc ADDED
Binary file (1.54 kB). View file
 
segmentation_models_pytorch/base/__pycache__/heads.cpython-39.pyc ADDED
Binary file (1.55 kB). View file
 
segmentation_models_pytorch/base/__pycache__/initialization.cpython-37.pyc ADDED
Binary file (904 Bytes). View file
 
segmentation_models_pytorch/base/__pycache__/initialization.cpython-39.pyc ADDED
Binary file (910 Bytes). View file
 
segmentation_models_pytorch/base/__pycache__/model.cpython-37.pyc ADDED
Binary file (2.03 kB). View file
 
segmentation_models_pytorch/base/__pycache__/model.cpython-39.pyc ADDED
Binary file (2.08 kB). View file
 
segmentation_models_pytorch/base/__pycache__/modules.cpython-37.pyc ADDED
Binary file (4.3 kB). View file
 
segmentation_models_pytorch/base/__pycache__/modules.cpython-39.pyc ADDED
Binary file (4.27 kB). View file
 
segmentation_models_pytorch/base/heads.py ADDED
@@ -0,0 +1,34 @@
+ import torch.nn as nn
+ from .modules import Activation
+
+
+ class SegmentationHead(nn.Sequential):
+     def __init__(
+         self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1
+     ):
+         conv2d = nn.Conv2d(
+             in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2
+         )
+         upsampling = (
+             nn.UpsamplingBilinear2d(scale_factor=upsampling)
+             if upsampling > 1
+             else nn.Identity()
+         )
+         activation = Activation(activation)
+         super().__init__(conv2d, upsampling, activation)
+
+
+ class ClassificationHead(nn.Sequential):
+     def __init__(
+         self, in_channels, classes, pooling="avg", dropout=0.2, activation=None
+     ):
+         if pooling not in ("max", "avg"):
+             raise ValueError(
+                 "Pooling should be one of ('max', 'avg'), got {}.".format(pooling)
+             )
+         pool = nn.AdaptiveAvgPool2d(1) if pooling == "avg" else nn.AdaptiveMaxPool2d(1)
+         flatten = nn.Flatten()
+         dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
+         linear = nn.Linear(in_channels, classes, bias=True)
+         activation = Activation(activation)
+         super().__init__(pool, flatten, dropout, linear, activation)
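A short shape check for the two heads above (a sketch; the channel counts are arbitrary):

```python
import torch
from segmentation_models_pytorch.base import SegmentationHead, ClassificationHead

seg_head = SegmentationHead(in_channels=16, out_channels=3, kernel_size=3, upsampling=4)
cls_head = ClassificationHead(in_channels=512, classes=2, pooling="avg", dropout=0.2)

decoder_out = torch.randn(1, 16, 64, 64)   # decoder feature map at 1/4 resolution
encoder_out = torch.randn(1, 512, 8, 8)    # deepest encoder feature map

print(seg_head(decoder_out).shape)  # -> torch.Size([1, 3, 256, 256])
print(cls_head(encoder_out).shape)  # -> torch.Size([1, 2])
```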
segmentation_models_pytorch/base/initialization.py ADDED
@@ -0,0 +1,27 @@
+ import torch.nn as nn
+
+
+ def initialize_decoder(module):
+     for m in module.modules():
+
+         if isinstance(m, nn.Conv2d):
+             nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu")
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+         elif isinstance(m, nn.BatchNorm2d):
+             nn.init.constant_(m.weight, 1)
+             nn.init.constant_(m.bias, 0)
+
+         elif isinstance(m, nn.Linear):
+             nn.init.xavier_uniform_(m.weight)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+
+ def initialize_head(module):
+     for m in module.modules():
+         if isinstance(m, (nn.Linear, nn.Conv2d)):
+             nn.init.xavier_uniform_(m.weight)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
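As a quick illustration of these initializers, a sketch on a hypothetical decoder block and head (module shapes are arbitrary):

```python
import torch.nn as nn
from segmentation_models_pytorch.base.initialization import initialize_decoder, initialize_head

decoder_block = nn.Sequential(
    nn.Conv2d(64, 32, 3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(inplace=True),
)
head = nn.Conv2d(32, 1, kernel_size=1)

initialize_decoder(decoder_block)  # Kaiming-uniform conv weights, BN weight=1 / bias=0
initialize_head(head)              # Xavier-uniform weights, zero bias
```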
segmentation_models_pytorch/base/model.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ from . import initialization as init
+
+
+ class SegmentationModel(torch.nn.Module):
+     def initialize(self):
+         init.initialize_decoder(self.decoder)
+         init.initialize_head(self.segmentation_head)
+         if self.classification_head is not None:
+             init.initialize_head(self.classification_head)
+
+     def check_input_shape(self, x):
+
+         h, w = x.shape[-2:]
+         output_stride = self.encoder.output_stride
+         if h % output_stride != 0 or w % output_stride != 0:
+             new_h = (
+                 (h // output_stride + 1) * output_stride
+                 if h % output_stride != 0
+                 else h
+             )
+             new_w = (
+                 (w // output_stride + 1) * output_stride
+                 if w % output_stride != 0
+                 else w
+             )
+             raise RuntimeError(
+                 f"Wrong input shape height={h}, width={w}. Expected image height and width "
+                 f"divisible by {output_stride}. Consider padding your images to shape ({new_h}, {new_w})."
+             )
+
+     def forward(self, x):
+         """Sequentially pass `x` through the model's encoder, decoder and heads"""
+
+         self.check_input_shape(x)
+
+         features = self.encoder(x)
+         decoder_output = self.decoder(*features)
+
+         masks = self.segmentation_head(decoder_output)
+
+         if self.classification_head is not None:
+             labels = self.classification_head(features[-1])
+             return masks, labels
+
+         return masks
+
+     @torch.no_grad()
+     def predict(self, x):
+         """Inference method. Switches the model to `eval` mode and calls `.forward(x)` under `torch.no_grad()`.
+
+         Args:
+             x: 4D torch tensor with shape (batch_size, channels, height, width)
+
+         Return:
+             prediction: 4D torch tensor with shape (batch_size, classes, height, width)
+
+         """
+         if self.training:
+             self.eval()
+
+         x = self.forward(x)
+
+         return x
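Because `check_input_shape` rejects sizes that are not multiples of the encoder stride, a caller typically pads first. A minimal sketch (the model choice and input size are arbitrary):

```python
import torch
import torch.nn.functional as F
import segmentation_models_pytorch as smp

model = smp.Unet("resnet34", encoder_weights=None, classes=1)

# Pad a 250x250 image up to the next multiple of the encoder output stride (32 here).
x = torch.randn(1, 3, 250, 250)
stride = model.encoder.output_stride
pad_h = (stride - x.shape[-2] % stride) % stride
pad_w = (stride - x.shape[-1] % stride) % stride
x = F.pad(x, (0, pad_w, 0, pad_h))

mask = model.predict(x)  # eval mode and no_grad are handled internally
print(mask.shape)        # -> torch.Size([1, 1, 256, 256])
```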
segmentation_models_pytorch/base/modules.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ import torch.nn as nn
+
+ try:
+     from inplace_abn import InPlaceABN
+ except ImportError:
+     InPlaceABN = None
+
+
+ class Conv2dReLU(nn.Sequential):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         padding=0,
+         stride=1,
+         use_batchnorm=True,
+     ):
+
+         if use_batchnorm == "inplace" and InPlaceABN is None:
+             raise RuntimeError(
+                 "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
+                 + "To install see: https://github.com/mapillary/inplace_abn"
+             )
+
+         conv = nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride=stride,
+             padding=padding,
+             bias=not (use_batchnorm),
+         )
+         relu = nn.ReLU(inplace=True)
+
+         if use_batchnorm == "inplace":
+             bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
+             relu = nn.Identity()
+
+         elif use_batchnorm and use_batchnorm != "inplace":
+             bn = nn.BatchNorm2d(out_channels)
+
+         else:
+             bn = nn.Identity()
+
+         super(Conv2dReLU, self).__init__(conv, bn, relu)
+
+
+ class SCSEModule(nn.Module):
+     def __init__(self, in_channels, reduction=16):
+         super().__init__()
+         self.cSE = nn.Sequential(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels, in_channels // reduction, 1),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(in_channels // reduction, in_channels, 1),
+             nn.Sigmoid(),
+         )
+         self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())
+
+     def forward(self, x):
+         return x * self.cSE(x) + x * self.sSE(x)
+
+
+ class ArgMax(nn.Module):
+     def __init__(self, dim=None):
+         super().__init__()
+         self.dim = dim
+
+     def forward(self, x):
+         return torch.argmax(x, dim=self.dim)
+
+
+ class Clamp(nn.Module):
+     def __init__(self, min=0, max=1):
+         super().__init__()
+         self.min, self.max = min, max
+
+     def forward(self, x):
+         return torch.clamp(x, self.min, self.max)
+
+
+ class Activation(nn.Module):
+     def __init__(self, name, **params):
+
+         super().__init__()
+
+         if name is None or name == "identity":
+             self.activation = nn.Identity(**params)
+         elif name == "sigmoid":
+             self.activation = nn.Sigmoid()
+         elif name == "softmax2d":
+             self.activation = nn.Softmax(dim=1, **params)
+         elif name == "softmax":
+             self.activation = nn.Softmax(**params)
+         elif name == "logsoftmax":
+             self.activation = nn.LogSoftmax(**params)
+         elif name == "tanh":
+             self.activation = nn.Tanh()
+         elif name == "argmax":
+             self.activation = ArgMax(**params)
+         elif name == "argmax2d":
+             self.activation = ArgMax(dim=1, **params)
+         elif name == "clamp":
+             self.activation = Clamp(**params)
+         elif callable(name):
+             self.activation = name(**params)
+         else:
+             raise ValueError(
+                 f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/"
+                 f"argmax/argmax2d/clamp/None; got {name}"
+             )
+
+     def forward(self, x):
+         return self.activation(x)
+
+
+ class Attention(nn.Module):
+     def __init__(self, name, **params):
+         super().__init__()
+
+         if name is None:
+             self.attention = nn.Identity(**params)
+         elif name == "scse":
+             self.attention = SCSEModule(**params)
+         else:
+             raise ValueError("Attention {} is not implemented".format(name))
+
+     def forward(self, x):
+         return self.attention(x)
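A small end-to-end check of these building blocks (illustrative only; channel counts and input size are arbitrary):

```python
import torch
from segmentation_models_pytorch.base.modules import Conv2dReLU, Attention, Activation

block = Conv2dReLU(3, 16, kernel_size=3, padding=1, use_batchnorm=True)
attn = Attention("scse", in_channels=16)
act = Activation("sigmoid")

x = torch.randn(2, 3, 64, 64)
y = act(attn(block(x)))  # conv -> BN -> ReLU -> scSE attention -> sigmoid
print(y.shape)           # -> torch.Size([2, 16, 64, 64])
```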
segmentation_models_pytorch/datasets/__init__.py ADDED
@@ -0,0 +1 @@
+ from .oxford_pet import OxfordPetDataset, SimpleOxfordPetDataset
segmentation_models_pytorch/datasets/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (230 Bytes). View file
 
segmentation_models_pytorch/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (243 Bytes). View file
 
segmentation_models_pytorch/datasets/__pycache__/oxford_pet.cpython-37.pyc ADDED
Binary file (4.72 kB). View file
 
segmentation_models_pytorch/datasets/__pycache__/oxford_pet.cpython-39.pyc ADDED
Binary file (4.82 kB). View file
 
segmentation_models_pytorch/datasets/oxford_pet.py ADDED
@@ -0,0 +1,136 @@
+ import os
+ import torch
+ import shutil
+ import numpy as np
+
+ from PIL import Image
+ from tqdm import tqdm
+ from urllib.request import urlretrieve
+
+
+ class OxfordPetDataset(torch.utils.data.Dataset):
+     def __init__(self, root, mode="train", transform=None):
+
+         assert mode in {"train", "valid", "test"}
+
+         self.root = root
+         self.mode = mode
+         self.transform = transform
+
+         self.images_directory = os.path.join(self.root, "images")
+         self.masks_directory = os.path.join(self.root, "annotations", "trimaps")
+
+         self.filenames = self._read_split()  # read train/valid/test splits
+
+     def __len__(self):
+         return len(self.filenames)
+
+     def __getitem__(self, idx):
+
+         filename = self.filenames[idx]
+         image_path = os.path.join(self.images_directory, filename + ".jpg")
+         mask_path = os.path.join(self.masks_directory, filename + ".png")
+
+         image = np.array(Image.open(image_path).convert("RGB"))
+
+         trimap = np.array(Image.open(mask_path))
+         mask = self._preprocess_mask(trimap)
+
+         sample = dict(image=image, mask=mask, trimap=trimap)
+         if self.transform is not None:
+             sample = self.transform(**sample)
+
+         return sample
+
+     @staticmethod
+     def _preprocess_mask(mask):
+         mask = mask.astype(np.float32)
+         mask[mask == 2.0] = 0.0
+         mask[(mask == 1.0) | (mask == 3.0)] = 1.0
+         return mask
+
+     def _read_split(self):
+         split_filename = "test.txt" if self.mode == "test" else "trainval.txt"
+         split_filepath = os.path.join(self.root, "annotations", split_filename)
+         with open(split_filepath) as f:
+             split_data = f.read().strip("\n").split("\n")
+         filenames = [x.split(" ")[0] for x in split_data]
+         if self.mode == "train":  # 90% for train
+             filenames = [x for i, x in enumerate(filenames) if i % 10 != 0]
+         elif self.mode == "valid":  # 10% for validation
+             filenames = [x for i, x in enumerate(filenames) if i % 10 == 0]
+         return filenames
+
+     @staticmethod
+     def download(root):
+
+         # load images
+         filepath = os.path.join(root, "images.tar.gz")
+         download_url(
+             url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
+             filepath=filepath,
+         )
+         extract_archive(filepath)
+
+         # load annotations
+         filepath = os.path.join(root, "annotations.tar.gz")
+         download_url(
+             url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
+             filepath=filepath,
+         )
+         extract_archive(filepath)
+
+
+ class SimpleOxfordPetDataset(OxfordPetDataset):
+     def __getitem__(self, *args, **kwargs):
+
+         sample = super().__getitem__(*args, **kwargs)
+
+         # resize images
+         image = np.array(
+             Image.fromarray(sample["image"]).resize((256, 256), Image.BILINEAR)
+         )
+         mask = np.array(
+             Image.fromarray(sample["mask"]).resize((256, 256), Image.NEAREST)
+         )
+         trimap = np.array(
+             Image.fromarray(sample["trimap"]).resize((256, 256), Image.NEAREST)
+         )
+
+         # convert to other format HWC -> CHW
+         sample["image"] = np.moveaxis(image, -1, 0)
+         sample["mask"] = np.expand_dims(mask, 0)
+         sample["trimap"] = np.expand_dims(trimap, 0)
+
+         return sample
+
+
+ class TqdmUpTo(tqdm):
+     def update_to(self, b=1, bsize=1, tsize=None):
+         if tsize is not None:
+             self.total = tsize
+         self.update(b * bsize - self.n)
+
+
+ def download_url(url, filepath):
+     directory = os.path.dirname(os.path.abspath(filepath))
+     os.makedirs(directory, exist_ok=True)
+     if os.path.exists(filepath):
+         return
+
+     with TqdmUpTo(
+         unit="B",
+         unit_scale=True,
+         unit_divisor=1024,
+         miniters=1,
+         desc=os.path.basename(filepath),
+     ) as t:
+         urlretrieve(url, filename=filepath, reporthook=t.update_to, data=None)
+         t.total = t.n
+
+
+ def extract_archive(filepath):
+     extract_dir = os.path.dirname(os.path.abspath(filepath))
+     dst_dir = os.path.splitext(filepath)[0]
+     if not os.path.exists(dst_dir):
+         shutil.unpack_archive(filepath, extract_dir)
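A minimal sketch of how this dataset is typically consumed; the root directory is a hypothetical local path, and `download` fetches the archives on first use:

```python
from torch.utils.data import DataLoader
from segmentation_models_pytorch.datasets import SimpleOxfordPetDataset

root = "./oxford_pet"                  # hypothetical data directory
SimpleOxfordPetDataset.download(root)  # downloads images.tar.gz and annotations.tar.gz

train_set = SimpleOxfordPetDataset(root, mode="train")
loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch["image"].shape, batch["mask"].shape)  # (8, 3, 256, 256), (8, 1, 256, 256)
```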
segmentation_models_pytorch/decoders/__init__.py ADDED
File without changes
segmentation_models_pytorch/decoders/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (146 Bytes). View file
 
segmentation_models_pytorch/decoders/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (159 Bytes). View file
 
segmentation_models_pytorch/decoders/deeplabv3/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import DeepLabV3, DeepLabV3Plus
segmentation_models_pytorch/decoders/deeplabv3/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (219 Bytes). View file
 
segmentation_models_pytorch/decoders/deeplabv3/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (232 Bytes). View file
 
segmentation_models_pytorch/decoders/deeplabv3/__pycache__/decoder.cpython-37.pyc ADDED
Binary file (6.74 kB). View file
 
segmentation_models_pytorch/decoders/deeplabv3/__pycache__/decoder.cpython-39.pyc ADDED
Binary file (6.62 kB). View file
 
segmentation_models_pytorch/decoders/deeplabv3/__pycache__/model.cpython-37.pyc ADDED
Binary file (7.13 kB). View file
 
segmentation_models_pytorch/decoders/deeplabv3/__pycache__/model.cpython-39.pyc ADDED
Binary file (7.19 kB). View file
 
segmentation_models_pytorch/decoders/deeplabv3/decoder.py ADDED
@@ -0,0 +1,220 @@
+ """
+ BSD 3-Clause License
+
+ Copyright (c) Soumith Chintala 2016,
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+ * Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ """
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ __all__ = ["DeepLabV3Decoder"]
+
+
+ class DeepLabV3Decoder(nn.Sequential):
+     def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
+         super().__init__(
+             ASPP(in_channels, out_channels, atrous_rates),
+             nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(),
+         )
+         self.out_channels = out_channels
+
+     def forward(self, *features):
+         return super().forward(features[-1])
+
+
+ class DeepLabV3PlusDecoder(nn.Module):
+     def __init__(
+         self,
+         encoder_channels,
+         out_channels=256,
+         atrous_rates=(12, 24, 36),
+         output_stride=16,
+     ):
+         super().__init__()
+         if output_stride not in {8, 16}:
+             raise ValueError(
+                 "Output stride should be 8 or 16, got {}.".format(output_stride)
+             )
+
+         self.out_channels = out_channels
+         self.output_stride = output_stride
+
+         self.aspp = nn.Sequential(
+             ASPP(encoder_channels[-1], out_channels, atrous_rates, separable=True),
+             SeparableConv2d(
+                 out_channels, out_channels, kernel_size=3, padding=1, bias=False
+             ),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(),
+         )
+
+         scale_factor = 2 if output_stride == 8 else 4
+         self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
+
+         highres_in_channels = encoder_channels[-4]
+         highres_out_channels = 48  # proposed by authors of paper
+         self.block1 = nn.Sequential(
+             nn.Conv2d(
+                 highres_in_channels, highres_out_channels, kernel_size=1, bias=False
+             ),
+             nn.BatchNorm2d(highres_out_channels),
+             nn.ReLU(),
+         )
+         self.block2 = nn.Sequential(
+             SeparableConv2d(
+                 highres_out_channels + out_channels,
+                 out_channels,
+                 kernel_size=3,
+                 padding=1,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(),
+         )
+
+     def forward(self, *features):
+         aspp_features = self.aspp(features[-1])
+         aspp_features = self.up(aspp_features)
+         high_res_features = self.block1(features[-4])
+         concat_features = torch.cat([aspp_features, high_res_features], dim=1)
+         fused_features = self.block2(concat_features)
+         return fused_features
+
+
+ class ASPPConv(nn.Sequential):
+     def __init__(self, in_channels, out_channels, dilation):
+         super().__init__(
+             nn.Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 padding=dilation,
+                 dilation=dilation,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(),
+         )
+
+
+ class ASPPSeparableConv(nn.Sequential):
+     def __init__(self, in_channels, out_channels, dilation):
+         super().__init__(
+             SeparableConv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 padding=dilation,
+                 dilation=dilation,
+                 bias=False,
+             ),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(),
+         )
+
+
+ class ASPPPooling(nn.Sequential):
+     def __init__(self, in_channels, out_channels):
+         super().__init__(
+             nn.AdaptiveAvgPool2d(1),
+             nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(),
+         )
+
+     def forward(self, x):
+         size = x.shape[-2:]
+         for mod in self:
+             x = mod(x)
+         return F.interpolate(x, size=size, mode="bilinear", align_corners=False)
+
+
+ class ASPP(nn.Module):
+     def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
+         super(ASPP, self).__init__()
+         modules = []
+         modules.append(
+             nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, 1, bias=False),
+                 nn.BatchNorm2d(out_channels),
+                 nn.ReLU(),
+             )
+         )
+
+         rate1, rate2, rate3 = tuple(atrous_rates)
+         ASPPConvModule = ASPPConv if not separable else ASPPSeparableConv
+
+         modules.append(ASPPConvModule(in_channels, out_channels, rate1))
+         modules.append(ASPPConvModule(in_channels, out_channels, rate2))
+         modules.append(ASPPConvModule(in_channels, out_channels, rate3))
+         modules.append(ASPPPooling(in_channels, out_channels))
+
+         self.convs = nn.ModuleList(modules)
+
+         self.project = nn.Sequential(
+             nn.Conv2d(5 * out_channels, out_channels, kernel_size=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(),
+             nn.Dropout(0.5),
+         )
+
+     def forward(self, x):
+         res = []
+         for conv in self.convs:
+             res.append(conv(x))
+         res = torch.cat(res, dim=1)
+         return self.project(res)
+
+
+ class SeparableConv2d(nn.Sequential):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride=1,
+         padding=0,
+         dilation=1,
+         bias=True,
+     ):
+         depthwise_conv = nn.Conv2d(
+             in_channels,
+             in_channels,
+             kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=in_channels,
+             bias=False,
+         )
+         pointwise_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
+         super().__init__(depthwise_conv, pointwise_conv)
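To make the ASPP wiring concrete, a small shape check (channel sizes and input resolution are arbitrary):

```python
import torch
from segmentation_models_pytorch.decoders.deeplabv3.decoder import ASPP

# Five parallel branches (1x1 conv, three atrous convs, global pooling) are concatenated
# and projected back to out_channels; the spatial size is preserved.
aspp = ASPP(in_channels=512, out_channels=256, atrous_rates=(12, 24, 36), separable=True)
x = torch.randn(1, 512, 32, 32)
print(aspp(x).shape)  # -> torch.Size([1, 256, 32, 32])
```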
segmentation_models_pytorch/decoders/deeplabv3/model.py ADDED
@@ -0,0 +1,179 @@
+ from torch import nn
+ from typing import Optional
+
+ from segmentation_models_pytorch.base import (
+     SegmentationModel,
+     SegmentationHead,
+     ClassificationHead,
+ )
+ from segmentation_models_pytorch.encoders import get_encoder
+ from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
+
+
+ class DeepLabV3(SegmentationModel):
+     """DeepLabV3_ implementation from "Rethinking Atrous Convolution for Semantic Image Segmentation"
+
+     Args:
+         encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
+             to extract features of different spatial resolution
+         encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
+             two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+             with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+             Default is 5
+         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+             other pretrained weights (see table with available weights for each encoder_name)
+         decoder_channels: A number of convolution filters in ASPP module. Default is 256
+         in_channels: A number of input channels for the model, default is 3 (RGB images)
+         classes: A number of classes for output mask (or you can think of it as a number of channels of output mask)
+         activation: An activation function to apply after the final convolution layer.
+             Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
+             **callable** and **None**.
+             Default is **None**
+         upsampling: Final upsampling factor. Default is 8 to preserve input-output spatial shape identity
+         aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+             on top of encoder if **aux_params** is not **None** (default). Supported params:
+                 - classes (int): A number of classes
+                 - pooling (str): One of "max", "avg". Default is "avg"
+                 - dropout (float): Dropout factor in [0, 1)
+                 - activation (str): An activation function to apply "sigmoid"/"softmax"
+                   (could be **None** to return logits)
+     Returns:
+         ``torch.nn.Module``: **DeepLabV3**
+
+     .. _DeeplabV3:
+         https://arxiv.org/abs/1706.05587
+
+     """
+
+     def __init__(
+         self,
+         encoder_name: str = "resnet34",
+         encoder_depth: int = 5,
+         encoder_weights: Optional[str] = "imagenet",
+         decoder_channels: int = 256,
+         in_channels: int = 3,
+         classes: int = 1,
+         activation: Optional[str] = None,
+         upsampling: int = 8,
+         aux_params: Optional[dict] = None,
+     ):
+         super().__init__()
+
+         self.encoder = get_encoder(
+             encoder_name,
+             in_channels=in_channels,
+             depth=encoder_depth,
+             weights=encoder_weights,
+             output_stride=8,
+         )
+
+         self.decoder = DeepLabV3Decoder(
+             in_channels=self.encoder.out_channels[-1], out_channels=decoder_channels,
+         )
+
+         self.segmentation_head = SegmentationHead(
+             in_channels=self.decoder.out_channels,
+             out_channels=classes,
+             activation=activation,
+             kernel_size=1,
+             upsampling=upsampling,
+         )
+
+         if aux_params is not None:
+             self.classification_head = ClassificationHead(
+                 in_channels=self.encoder.out_channels[-1], **aux_params
+             )
+         else:
+             self.classification_head = None
+
+
+ class DeepLabV3Plus(SegmentationModel):
+     """DeepLabV3+ implementation from "Encoder-Decoder with Atrous Separable
+     Convolution for Semantic Image Segmentation"
+
+     Args:
+         encoder_name: Name of the classification model that will be used as an encoder (a.k.a. backbone)
+             to extract features of different spatial resolution
+         encoder_depth: A number of stages used in encoder in range [3, 5]. Each stage generates features
+             two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+             with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+             Default is 5
+         encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+             other pretrained weights (see table with available weights for each encoder_name)
+         encoder_output_stride: Downsampling factor for last encoder features (see original paper for explanation)
+         decoder_atrous_rates: Dilation rates for ASPP module (should be a tuple of 3 integer values)
+         decoder_channels: A number of convolution filters in ASPP module. Default is 256
+         in_channels: A number of input channels for the model, default is 3 (RGB images)
+         classes: A number of classes for output mask (or you can think of it as a number of channels of output mask)
+         activation: An activation function to apply after the final convolution layer.
+             Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
+             **callable** and **None**.
+             Default is **None**
+         upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
+         aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+             on top of encoder if **aux_params** is not **None** (default). Supported params:
+                 - classes (int): A number of classes
+                 - pooling (str): One of "max", "avg". Default is "avg"
+                 - dropout (float): Dropout factor in [0, 1)
+                 - activation (str): An activation function to apply "sigmoid"/"softmax"
+                   (could be **None** to return logits)
+     Returns:
+         ``torch.nn.Module``: **DeepLabV3Plus**
+
+     Reference:
+         https://arxiv.org/abs/1802.02611v3
+
+     """
+
+     def __init__(
+         self,
+         encoder_name: str = "resnet34",
+         encoder_depth: int = 5,
+         encoder_weights: Optional[str] = "imagenet",
+         encoder_output_stride: int = 16,
+         decoder_channels: int = 256,
+         decoder_atrous_rates: tuple = (12, 24, 36),
+         in_channels: int = 3,
+         classes: int = 1,
+         activation: Optional[str] = None,
+         upsampling: int = 4,
+         aux_params: Optional[dict] = None,
+     ):
+         super().__init__()
+
+         if encoder_output_stride not in [8, 16]:
+             raise ValueError(
+                 "Encoder output stride should be 8 or 16, got {}".format(
+                     encoder_output_stride
+                 )
+             )
+
+         self.encoder = get_encoder(
+             encoder_name,
+             in_channels=in_channels,
+             depth=encoder_depth,
+             weights=encoder_weights,
+             output_stride=encoder_output_stride,
+         )
+
+         self.decoder = DeepLabV3PlusDecoder(
+             encoder_channels=self.encoder.out_channels,
+             out_channels=decoder_channels,
+             atrous_rates=decoder_atrous_rates,
+             output_stride=encoder_output_stride,
+         )
+
+         self.segmentation_head = SegmentationHead(
+             in_channels=self.decoder.out_channels,
+             out_channels=classes,
+             activation=activation,
+             kernel_size=1,
+             upsampling=upsampling,
+         )
+
+         if aux_params is not None:
+             self.classification_head = ClassificationHead(
+                 in_channels=self.encoder.out_channels[-1], **aux_params
+             )
+         else:
+             self.classification_head = None
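A brief usage sketch of the DeepLabV3+ model above, including the optional auxiliary classification head (weights disabled so the example runs offline):

```python
import torch
import segmentation_models_pytorch as smp

model = smp.DeepLabV3Plus(
    encoder_name="resnet34",
    encoder_weights=None,                        # or "imagenet"
    encoder_output_stride=16,
    classes=3,
    aux_params={"classes": 2, "dropout": 0.2},   # adds a classification head
)

x = torch.randn(1, 3, 256, 256)
masks, labels = model(x)            # a tuple is returned because aux_params is set
print(masks.shape, labels.shape)    # (1, 3, 256, 256), (1, 2)
```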
segmentation_models_pytorch/decoders/fpn/__init__.py ADDED
@@ -0,0 +1 @@
+ from .model import FPN
segmentation_models_pytorch/decoders/fpn/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (183 Bytes). View file
 
segmentation_models_pytorch/decoders/fpn/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (196 Bytes). View file
 
segmentation_models_pytorch/decoders/fpn/__pycache__/decoder.cpython-37.pyc ADDED
Binary file (4.52 kB). View file
 
segmentation_models_pytorch/decoders/fpn/__pycache__/decoder.cpython-39.pyc ADDED
Binary file (4.43 kB). View file