feat(transforms): added transforms for post-processing
Browse files- transforms.py +265 -0
transforms.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import warnings
|
3 |
+
from functools import lru_cache
|
4 |
+
from typing import Tuple, Optional, List
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
transforms_logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
|
12 |
+
@lru_cache(maxsize=None)
|
13 |
+
def _patch_intensity_mask(patch_height: int = 224, patch_width: int = 224, sig: float = 7.5):
|
14 |
+
"""
|
15 |
+
Provides an intensity mask, given a patch size, based on an exponential function.
|
16 |
+
Values close to the center of the patch are close to 1.
|
17 |
+
When we are 20 pixels from the edges, we are ~0.88. Then, we have a drastic drop
|
18 |
+
to 0 at the edges.
|
19 |
+
Args:
|
20 |
+
patch_height: Input patch height
|
21 |
+
patch_width: Input patch width
|
22 |
+
sig: Sigma that divides the exponential.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
An intensity map as a numpy array, with shape == (patch_height, patch_width)
|
26 |
+
"""
|
27 |
+
max_size = max(224, patch_height, patch_width)
|
28 |
+
xm = np.arange(max_size)
|
29 |
+
xm = np.abs(xm - xm.mean())
|
30 |
+
mask = 1 / (1 + np.exp((xm - (max_size / 2 - 20)) / sig))
|
31 |
+
mask = mask * mask[:, np.newaxis]
|
32 |
+
mask = mask[
|
33 |
+
max_size // 2 - patch_height // 2: max_size // 2 + patch_height // 2 + patch_height % 2,
|
34 |
+
max_size // 2 - patch_width // 2: max_size // 2 + patch_width // 2 + patch_width % 2]
|
35 |
+
return mask
|
36 |
+
|
37 |
+
|
38 |
+
def average_patches(patches: np.ndarray, y_sub: List[Tuple[int, int]], x_sub: List[Tuple[int, int]],
|
39 |
+
height: int, width: int):
|
40 |
+
"""
|
41 |
+
Average the patch values over an image of (height, width).
|
42 |
+
Args:
|
43 |
+
patches: numpy array of (# patches, # classes == 3, patch_height, patch_width)
|
44 |
+
y_sub: list of integer tuples. Each tuple contains the start and ending position of
|
45 |
+
the patch in the y-axis.
|
46 |
+
x_sub: list of integer tuples. Each tuple contains the start and ending position of
|
47 |
+
the patch in the x-axis.
|
48 |
+
height: output image height
|
49 |
+
width: output image width
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
A numpy array of (height, width) with the average of the patches with appropriate overlap interpolation.
|
53 |
+
"""
|
54 |
+
intensity_mask = np.zeros((height, width), dtype=np.float32)
|
55 |
+
mean_output = np.zeros((patches.shape[1], height, width), dtype=np.float32)
|
56 |
+
patch_intensity_mask = _patch_intensity_mask(patch_height=patches.shape[-2], patch_width=patches.shape[-1])
|
57 |
+
|
58 |
+
for i in range(len(y_sub)):
|
59 |
+
mean_output[:, y_sub[i][0]:y_sub[i][1], x_sub[i][0]:x_sub[i][1]] += patches[i] * patch_intensity_mask
|
60 |
+
intensity_mask[y_sub[i][0]:y_sub[i][1], x_sub[i][0]:x_sub[i][1]] += patch_intensity_mask
|
61 |
+
|
62 |
+
return mean_output / intensity_mask
|
63 |
+
|
64 |
+
|
65 |
+
def split_in_patches(x: np.ndarray, patch_size: int = 224, tile_overlap: float = 0.1):
|
66 |
+
""" make tiles of image to run at test-time
|
67 |
+
|
68 |
+
Parameters
|
69 |
+
----------
|
70 |
+
x : float32
|
71 |
+
array that's n_channels x height x width
|
72 |
+
|
73 |
+
patch_size : int (optional, default 224)
|
74 |
+
size of tiles
|
75 |
+
|
76 |
+
|
77 |
+
tile_overlap: float (optional, default 0.1)
|
78 |
+
fraction of overlap of tiles
|
79 |
+
|
80 |
+
Returns
|
81 |
+
-------
|
82 |
+
patches : float32
|
83 |
+
array that's ntiles x n_channels x bsize x bsize
|
84 |
+
|
85 |
+
y_sub : list
|
86 |
+
list of arrays with start and end of tiles in Y of length ntiles
|
87 |
+
|
88 |
+
x_sub : list
|
89 |
+
list of arrays with start and end of tiles in X of length ntiles
|
90 |
+
|
91 |
+
|
92 |
+
"""
|
93 |
+
|
94 |
+
n_channels, height, width = x.shape
|
95 |
+
|
96 |
+
tile_overlap = min(0.5, max(0.05, tile_overlap))
|
97 |
+
patch_height = np.int32(min(patch_size, height))
|
98 |
+
patch_width = np.int32(min(patch_size, width))
|
99 |
+
|
100 |
+
# tiles overlap by 10% tile size
|
101 |
+
ny = 1 if height <= patch_size else int(np.ceil((1. + 2 * tile_overlap) * height / patch_size))
|
102 |
+
nx = 1 if width <= patch_size else int(np.ceil((1. + 2 * tile_overlap) * width / patch_size))
|
103 |
+
|
104 |
+
y_start = np.linspace(0, height - patch_height, ny).astype(np.int32)
|
105 |
+
x_start = np.linspace(0, width - patch_width, nx).astype(np.int32)
|
106 |
+
|
107 |
+
y_sub, x_sub = [], []
|
108 |
+
patches = np.zeros((len(y_start), len(x_start), n_channels, patch_height, patch_width), np.float32)
|
109 |
+
for j in range(len(y_start)):
|
110 |
+
for i in range(len(x_start)):
|
111 |
+
y_sub.append([y_start[j], y_start[j] + patch_height])
|
112 |
+
x_sub.append([x_start[i], x_start[i] + patch_width])
|
113 |
+
patches[j, i] = x[:, y_sub[-1][0]:y_sub[-1][1], x_sub[-1][0]:x_sub[-1][1]]
|
114 |
+
|
115 |
+
return patches, y_sub, x_sub
|
116 |
+
|
117 |
+
|
118 |
+
def convert_image_grayscale(x: np.ndarray):
|
119 |
+
assert x.ndim == 2
|
120 |
+
x = x.astype(np.float32)
|
121 |
+
x = x[:, :, np.newaxis]
|
122 |
+
x = np.concatenate((x, np.zeros_like(x)), axis=-1)
|
123 |
+
return x
|
124 |
+
|
125 |
+
|
126 |
+
def convert_image(x, channels: Tuple[int, int]):
|
127 |
+
assert len(channels) == 2
|
128 |
+
|
129 |
+
return reshape(x, channels=channels)
|
130 |
+
|
131 |
+
|
132 |
+
def reshape(x: np.ndarray, channels=(0, 0)):
|
133 |
+
""" reshape data using channels
|
134 |
+
|
135 |
+
Parameters
|
136 |
+
----------
|
137 |
+
x : Numpy array, channel last.
|
138 |
+
|
139 |
+
channels : list of int of length 2 (optional, default [0,0])
|
140 |
+
First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
|
141 |
+
Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
|
142 |
+
For instance, to train on grayscale images, input [0,0]. To train on images with cells
|
143 |
+
in green and nuclei in blue, input [2,3].
|
144 |
+
|
145 |
+
|
146 |
+
Returns
|
147 |
+
-------
|
148 |
+
data : numpy array that's (Z x ) Ly x Lx x nchan (if chan_first==False)
|
149 |
+
|
150 |
+
"""
|
151 |
+
x = x.astype(np.float32)
|
152 |
+
if x.ndim < 3:
|
153 |
+
x = x[:, :, np.newaxis]
|
154 |
+
|
155 |
+
if x.shape[-1] == 1:
|
156 |
+
x = np.concatenate((x, np.zeros_like(x)), axis=-1)
|
157 |
+
else:
|
158 |
+
if channels[0] == 0:
|
159 |
+
x = x.mean(axis=-1, keepdims=True)
|
160 |
+
x = np.concatenate((x, np.zeros_like(x)), axis=-1)
|
161 |
+
else:
|
162 |
+
channels_index = [channels[0] - 1]
|
163 |
+
if channels[1] > 0:
|
164 |
+
channels_index.append(channels[1] - 1)
|
165 |
+
x = x[..., channels_index]
|
166 |
+
for i in range(x.shape[-1]):
|
167 |
+
if np.ptp(x[..., i]) == 0.0:
|
168 |
+
if i == 0:
|
169 |
+
warnings.warn("chan to seg' has value range of ZERO")
|
170 |
+
else:
|
171 |
+
warnings.warn("'chan2 (opt)' has value range of ZERO, can instead set chan2 to 0")
|
172 |
+
if x.shape[-1] == 1:
|
173 |
+
x = np.concatenate((x, np.zeros_like(x)), axis=-1)
|
174 |
+
|
175 |
+
return np.transpose(x, (2, 0, 1))
|
176 |
+
|
177 |
+
|
178 |
+
def resize_image(image, height: Optional[int] = None, width: Optional[int] = None, resize: Optional[float] = None,
|
179 |
+
interpolation=cv2.INTER_LINEAR, no_channels=False):
|
180 |
+
""" resize image for computing flows / unresize for computing dynamics
|
181 |
+
|
182 |
+
Parameters
|
183 |
+
-------------
|
184 |
+
|
185 |
+
image: ND-array
|
186 |
+
image of size [Y x X x nchan] or [Lz x Y x X x nchan] or [Lz x Y x X]
|
187 |
+
|
188 |
+
height: int, optional
|
189 |
+
|
190 |
+
width: int, optional
|
191 |
+
|
192 |
+
resize: float, optional
|
193 |
+
resize coefficient(s) for image; if Ly is None then rsz is used
|
194 |
+
|
195 |
+
interpolation: cv2 interp method (optional, default cv2.INTER_LINEAR)
|
196 |
+
|
197 |
+
Returns
|
198 |
+
--------------
|
199 |
+
|
200 |
+
imgs: ND-array
|
201 |
+
image of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan]
|
202 |
+
|
203 |
+
"""
|
204 |
+
if height is None and resize is None:
|
205 |
+
error_message = 'must give size to resize to or factor to use for resizing'
|
206 |
+
transforms_logger.critical(error_message)
|
207 |
+
raise ValueError(error_message)
|
208 |
+
|
209 |
+
if height is None:
|
210 |
+
# determine Ly and Lx using rsz
|
211 |
+
if not isinstance(resize, list) and not isinstance(resize, np.ndarray):
|
212 |
+
resize = [resize, resize]
|
213 |
+
if no_channels:
|
214 |
+
height = int(image.shape[-2] * resize[-2])
|
215 |
+
width = int(image.shape[-1] * resize[-1])
|
216 |
+
else:
|
217 |
+
height = int(image.shape[-3] * resize[-2])
|
218 |
+
width = int(image.shape[-2] * resize[-1])
|
219 |
+
|
220 |
+
return cv2.resize(image, (width, height), interpolation=interpolation)
|
221 |
+
|
222 |
+
|
223 |
+
def pad_image(x: np.ndarray, div: int = 16):
|
224 |
+
""" pad image for test-time so that its dimensions are a multiple of 16 (2D or 3D)
|
225 |
+
|
226 |
+
Parameters
|
227 |
+
-------------
|
228 |
+
|
229 |
+
x: ND-array
|
230 |
+
image of size [nchan (x Lz) x height x width]
|
231 |
+
|
232 |
+
div: int (optional, default 16)
|
233 |
+
|
234 |
+
Returns
|
235 |
+
--------------
|
236 |
+
|
237 |
+
output: ND-array
|
238 |
+
padded image
|
239 |
+
|
240 |
+
y_sub: array, int
|
241 |
+
yrange of pixels in output corresponding to img0
|
242 |
+
|
243 |
+
x_sub: array, int
|
244 |
+
xrange of pixels in output corresponding to img0
|
245 |
+
|
246 |
+
"""
|
247 |
+
x_pad = int(div * np.ceil(x.shape[-2] / div) - x.shape[-2])
|
248 |
+
x_pad_left = div // 2 + x_pad // 2
|
249 |
+
x_pad_right = div // 2 + x_pad - x_pad // 2
|
250 |
+
|
251 |
+
y_pad = int(div * np.ceil(x.shape[-1] / div) - x.shape[-1])
|
252 |
+
y_pad_left = div // 2 + y_pad // 2
|
253 |
+
y_pad_right = div // 2 + y_pad - y_pad // 2
|
254 |
+
|
255 |
+
if x.ndim > 3:
|
256 |
+
pads = np.array([[0, 0], [0, 0], [x_pad_left, x_pad_right], [y_pad_left, y_pad_right]])
|
257 |
+
else:
|
258 |
+
pads = np.array([[0, 0], [x_pad_left, x_pad_right], [y_pad_left, y_pad_right]])
|
259 |
+
|
260 |
+
output = np.pad(x, pads, mode='constant')
|
261 |
+
|
262 |
+
height, width = x.shape[-2:]
|
263 |
+
y_sub = np.arange(x_pad_left, x_pad_left + height)
|
264 |
+
x_sub = np.arange(y_pad_left, y_pad_left + width)
|
265 |
+
return output, y_sub, x_sub
|