fbeckk commited on
Commit
f5fff27
1 Parent(s): abed53e

feat(transforms): added transforms for post-processing

Browse files
Files changed (1) hide show
  1. transforms.py +265 -0
transforms.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import warnings
3
+ from functools import lru_cache
4
+ from typing import Tuple, Optional, List
5
+
6
+ import cv2
7
+ import numpy as np
8
+
9
+ transforms_logger = logging.getLogger(__name__)
10
+
11
+
12
+ @lru_cache(maxsize=None)
13
+ def _patch_intensity_mask(patch_height: int = 224, patch_width: int = 224, sig: float = 7.5):
14
+ """
15
+ Provides an intensity mask, given a patch size, based on an exponential function.
16
+ Values close to the center of the patch are close to 1.
17
+ When we are 20 pixels from the edges, we are ~0.88. Then, we have a drastic drop
18
+ to 0 at the edges.
19
+ Args:
20
+ patch_height: Input patch height
21
+ patch_width: Input patch width
22
+ sig: Sigma that divides the exponential.
23
+
24
+ Returns:
25
+ An intensity map as a numpy array, with shape == (patch_height, patch_width)
26
+ """
27
+ max_size = max(224, patch_height, patch_width)
28
+ xm = np.arange(max_size)
29
+ xm = np.abs(xm - xm.mean())
30
+ mask = 1 / (1 + np.exp((xm - (max_size / 2 - 20)) / sig))
31
+ mask = mask * mask[:, np.newaxis]
32
+ mask = mask[
33
+ max_size // 2 - patch_height // 2: max_size // 2 + patch_height // 2 + patch_height % 2,
34
+ max_size // 2 - patch_width // 2: max_size // 2 + patch_width // 2 + patch_width % 2]
35
+ return mask
36
+
37
+
38
+ def average_patches(patches: np.ndarray, y_sub: List[Tuple[int, int]], x_sub: List[Tuple[int, int]],
39
+ height: int, width: int):
40
+ """
41
+ Average the patch values over an image of (height, width).
42
+ Args:
43
+ patches: numpy array of (# patches, # classes == 3, patch_height, patch_width)
44
+ y_sub: list of integer tuples. Each tuple contains the start and ending position of
45
+ the patch in the y-axis.
46
+ x_sub: list of integer tuples. Each tuple contains the start and ending position of
47
+ the patch in the x-axis.
48
+ height: output image height
49
+ width: output image width
50
+
51
+ Returns:
52
+ A numpy array of (height, width) with the average of the patches with appropriate overlap interpolation.
53
+ """
54
+ intensity_mask = np.zeros((height, width), dtype=np.float32)
55
+ mean_output = np.zeros((patches.shape[1], height, width), dtype=np.float32)
56
+ patch_intensity_mask = _patch_intensity_mask(patch_height=patches.shape[-2], patch_width=patches.shape[-1])
57
+
58
+ for i in range(len(y_sub)):
59
+ mean_output[:, y_sub[i][0]:y_sub[i][1], x_sub[i][0]:x_sub[i][1]] += patches[i] * patch_intensity_mask
60
+ intensity_mask[y_sub[i][0]:y_sub[i][1], x_sub[i][0]:x_sub[i][1]] += patch_intensity_mask
61
+
62
+ return mean_output / intensity_mask
63
+
64
+
65
+ def split_in_patches(x: np.ndarray, patch_size: int = 224, tile_overlap: float = 0.1):
66
+ """ make tiles of image to run at test-time
67
+
68
+ Parameters
69
+ ----------
70
+ x : float32
71
+ array that's n_channels x height x width
72
+
73
+ patch_size : int (optional, default 224)
74
+ size of tiles
75
+
76
+
77
+ tile_overlap: float (optional, default 0.1)
78
+ fraction of overlap of tiles
79
+
80
+ Returns
81
+ -------
82
+ patches : float32
83
+ array that's ntiles x n_channels x bsize x bsize
84
+
85
+ y_sub : list
86
+ list of arrays with start and end of tiles in Y of length ntiles
87
+
88
+ x_sub : list
89
+ list of arrays with start and end of tiles in X of length ntiles
90
+
91
+
92
+ """
93
+
94
+ n_channels, height, width = x.shape
95
+
96
+ tile_overlap = min(0.5, max(0.05, tile_overlap))
97
+ patch_height = np.int32(min(patch_size, height))
98
+ patch_width = np.int32(min(patch_size, width))
99
+
100
+ # tiles overlap by 10% tile size
101
+ ny = 1 if height <= patch_size else int(np.ceil((1. + 2 * tile_overlap) * height / patch_size))
102
+ nx = 1 if width <= patch_size else int(np.ceil((1. + 2 * tile_overlap) * width / patch_size))
103
+
104
+ y_start = np.linspace(0, height - patch_height, ny).astype(np.int32)
105
+ x_start = np.linspace(0, width - patch_width, nx).astype(np.int32)
106
+
107
+ y_sub, x_sub = [], []
108
+ patches = np.zeros((len(y_start), len(x_start), n_channels, patch_height, patch_width), np.float32)
109
+ for j in range(len(y_start)):
110
+ for i in range(len(x_start)):
111
+ y_sub.append([y_start[j], y_start[j] + patch_height])
112
+ x_sub.append([x_start[i], x_start[i] + patch_width])
113
+ patches[j, i] = x[:, y_sub[-1][0]:y_sub[-1][1], x_sub[-1][0]:x_sub[-1][1]]
114
+
115
+ return patches, y_sub, x_sub
116
+
117
+
118
+ def convert_image_grayscale(x: np.ndarray):
119
+ assert x.ndim == 2
120
+ x = x.astype(np.float32)
121
+ x = x[:, :, np.newaxis]
122
+ x = np.concatenate((x, np.zeros_like(x)), axis=-1)
123
+ return x
124
+
125
+
126
+ def convert_image(x, channels: Tuple[int, int]):
127
+ assert len(channels) == 2
128
+
129
+ return reshape(x, channels=channels)
130
+
131
+
132
+ def reshape(x: np.ndarray, channels=(0, 0)):
133
+ """ reshape data using channels
134
+
135
+ Parameters
136
+ ----------
137
+ x : Numpy array, channel last.
138
+
139
+ channels : list of int of length 2 (optional, default [0,0])
140
+ First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
141
+ Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
142
+ For instance, to train on grayscale images, input [0,0]. To train on images with cells
143
+ in green and nuclei in blue, input [2,3].
144
+
145
+
146
+ Returns
147
+ -------
148
+ data : numpy array that's (Z x ) Ly x Lx x nchan (if chan_first==False)
149
+
150
+ """
151
+ x = x.astype(np.float32)
152
+ if x.ndim < 3:
153
+ x = x[:, :, np.newaxis]
154
+
155
+ if x.shape[-1] == 1:
156
+ x = np.concatenate((x, np.zeros_like(x)), axis=-1)
157
+ else:
158
+ if channels[0] == 0:
159
+ x = x.mean(axis=-1, keepdims=True)
160
+ x = np.concatenate((x, np.zeros_like(x)), axis=-1)
161
+ else:
162
+ channels_index = [channels[0] - 1]
163
+ if channels[1] > 0:
164
+ channels_index.append(channels[1] - 1)
165
+ x = x[..., channels_index]
166
+ for i in range(x.shape[-1]):
167
+ if np.ptp(x[..., i]) == 0.0:
168
+ if i == 0:
169
+ warnings.warn("chan to seg' has value range of ZERO")
170
+ else:
171
+ warnings.warn("'chan2 (opt)' has value range of ZERO, can instead set chan2 to 0")
172
+ if x.shape[-1] == 1:
173
+ x = np.concatenate((x, np.zeros_like(x)), axis=-1)
174
+
175
+ return np.transpose(x, (2, 0, 1))
176
+
177
+
178
+ def resize_image(image, height: Optional[int] = None, width: Optional[int] = None, resize: Optional[float] = None,
179
+ interpolation=cv2.INTER_LINEAR, no_channels=False):
180
+ """ resize image for computing flows / unresize for computing dynamics
181
+
182
+ Parameters
183
+ -------------
184
+
185
+ image: ND-array
186
+ image of size [Y x X x nchan] or [Lz x Y x X x nchan] or [Lz x Y x X]
187
+
188
+ height: int, optional
189
+
190
+ width: int, optional
191
+
192
+ resize: float, optional
193
+ resize coefficient(s) for image; if Ly is None then rsz is used
194
+
195
+ interpolation: cv2 interp method (optional, default cv2.INTER_LINEAR)
196
+
197
+ Returns
198
+ --------------
199
+
200
+ imgs: ND-array
201
+ image of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan]
202
+
203
+ """
204
+ if height is None and resize is None:
205
+ error_message = 'must give size to resize to or factor to use for resizing'
206
+ transforms_logger.critical(error_message)
207
+ raise ValueError(error_message)
208
+
209
+ if height is None:
210
+ # determine Ly and Lx using rsz
211
+ if not isinstance(resize, list) and not isinstance(resize, np.ndarray):
212
+ resize = [resize, resize]
213
+ if no_channels:
214
+ height = int(image.shape[-2] * resize[-2])
215
+ width = int(image.shape[-1] * resize[-1])
216
+ else:
217
+ height = int(image.shape[-3] * resize[-2])
218
+ width = int(image.shape[-2] * resize[-1])
219
+
220
+ return cv2.resize(image, (width, height), interpolation=interpolation)
221
+
222
+
223
+ def pad_image(x: np.ndarray, div: int = 16):
224
+ """ pad image for test-time so that its dimensions are a multiple of 16 (2D or 3D)
225
+
226
+ Parameters
227
+ -------------
228
+
229
+ x: ND-array
230
+ image of size [nchan (x Lz) x height x width]
231
+
232
+ div: int (optional, default 16)
233
+
234
+ Returns
235
+ --------------
236
+
237
+ output: ND-array
238
+ padded image
239
+
240
+ y_sub: array, int
241
+ yrange of pixels in output corresponding to img0
242
+
243
+ x_sub: array, int
244
+ xrange of pixels in output corresponding to img0
245
+
246
+ """
247
+ x_pad = int(div * np.ceil(x.shape[-2] / div) - x.shape[-2])
248
+ x_pad_left = div // 2 + x_pad // 2
249
+ x_pad_right = div // 2 + x_pad - x_pad // 2
250
+
251
+ y_pad = int(div * np.ceil(x.shape[-1] / div) - x.shape[-1])
252
+ y_pad_left = div // 2 + y_pad // 2
253
+ y_pad_right = div // 2 + y_pad - y_pad // 2
254
+
255
+ if x.ndim > 3:
256
+ pads = np.array([[0, 0], [0, 0], [x_pad_left, x_pad_right], [y_pad_left, y_pad_right]])
257
+ else:
258
+ pads = np.array([[0, 0], [x_pad_left, x_pad_right], [y_pad_left, y_pad_right]])
259
+
260
+ output = np.pad(x, pads, mode='constant')
261
+
262
+ height, width = x.shape[-2:]
263
+ y_sub = np.arange(x_pad_left, x_pad_left + height)
264
+ x_sub = np.arange(y_pad_left, y_pad_left + width)
265
+ return output, y_sub, x_sub