cell-seg / transforms.py
fbeckk's picture
feat(transforms): added transforms for post-processing
f5fff27 unverified
raw
history blame contribute delete
No virus
8.98 kB
import logging
import warnings
from functools import lru_cache
from typing import Tuple, Optional, List
import cv2
import numpy as np
# Module-level logger, named after this module via ``__name__``.
transforms_logger = logging.getLogger(__name__)
@lru_cache(maxsize=None)
def _patch_intensity_mask(patch_height: int = 224, patch_width: int = 224, sig: float = 7.5):
    """
    Build a 2-D blending weight map for a patch from a separable logistic profile.

    A 1-D profile is evaluated on a window of ``max(224, patch_height, patch_width)``
    samples. It is close to 1 at the window centre, equals roughly 0.5 about
    20 samples inside the window border and keeps falling towards the border.
    The 2-D map is the outer product of that profile with itself, cropped
    centrally to the requested patch size. Patches smaller than the window
    therefore only see the flat central part of the profile.

    The result is memoised per argument combination, so repeated calls return the
    very same array object; callers should treat it as read-only.

    Args:
        patch_height: Input patch height
        patch_width: Input patch width
        sig: Divisor of the exponent; larger values give a softer fall-off.

    Returns:
        A numpy array of weights with shape == (patch_height, patch_width)
    """
    window = max(224, patch_height, patch_width)
    half_window = window // 2

    # Distance of every sample from the window centre.
    positions = np.arange(window)
    distance = np.abs(positions - positions.mean())

    # Logistic fall-off whose midpoint sits 20 samples inside the window border.
    profile = 1 / (1 + np.exp((distance - (window / 2 - 20)) / sig))
    weights = profile * profile[:, np.newaxis]

    # Central crop; the ``% 2`` term keeps odd patch sizes at their full length.
    row_begin = half_window - patch_height // 2
    row_end = half_window + patch_height // 2 + patch_height % 2
    col_begin = half_window - patch_width // 2
    col_end = half_window + patch_width // 2 + patch_width % 2
    return weights[row_begin:row_end, col_begin:col_end]
def average_patches(patches: np.ndarray, y_sub: List[Tuple[int, int]], x_sub: List[Tuple[int, int]],
                    height: int, width: int):
    """
    Blend overlapping patches back into one image using a weighted average.

    Every patch is multiplied by the weight map from ``_patch_intensity_mask`` and
    accumulated at its position; the accumulated values are then divided by the
    accumulated weights.

    Args:
        patches: numpy array of (# patches, # classes, patch_height, patch_width)
        y_sub: list of (start, end) pairs giving the position of each patch on the y-axis.
        x_sub: list of (start, end) pairs giving the position of each patch on the x-axis.
        height: output image height
        width: output image width

    Returns:
        A float numpy array of (# classes, height, width) holding the weighted average
        of the patches. Pixels not covered by any patch are divided by a zero weight.
    """
    n_classes = patches.shape[1]
    tile_weight = _patch_intensity_mask(patch_height=patches.shape[-2], patch_width=patches.shape[-1])

    weighted_sum = np.zeros((n_classes, height, width), dtype=np.float32)
    weight_total = np.zeros((height, width), dtype=np.float32)

    for idx, row_bounds in enumerate(y_sub):
        col_bounds = x_sub[idx]
        row_span = slice(row_bounds[0], row_bounds[1])
        col_span = slice(col_bounds[0], col_bounds[1])
        weighted_sum[:, row_span, col_span] += patches[idx] * tile_weight
        weight_total[row_span, col_span] += tile_weight

    return weighted_sum / weight_total
def split_in_patches(x: np.ndarray, patch_size: int = 224, tile_overlap: float = 0.1):
    """ cut an image into a regular grid of overlapping tiles for test-time inference

    Parameters
    ----------
    x : float32
        array that's n_channels x height x width
    patch_size : int (optional, default 224)
        maximum side length of a tile; a tile never exceeds the image itself
    tile_overlap: float (optional, default 0.1)
        fraction of overlap of tiles, clamped to the range [0.05, 0.5]

    Returns
    -------
    patches : float32
        array that's n_tiles_y x n_tiles_x x n_channels x tile_height x tile_width
    y_sub : list
        one [start, end] pair per tile (row-major order) with its extent in Y
    x_sub : list
        one [start, end] pair per tile (row-major order) with its extent in X
    """
    n_channels, height, width = x.shape
    tile_overlap = min(0.5, max(0.05, tile_overlap))
    patch_height = np.int32(min(patch_size, height))
    patch_width = np.int32(min(patch_size, width))

    def _tile_count(extent):
        # One tile is enough when the image fits; otherwise the count is inflated
        # by (1 + 2 * overlap) so that neighbouring tiles overlap.
        if extent <= patch_size:
            return 1
        return int(np.ceil((1. + 2 * tile_overlap) * extent / patch_size))

    # Tile origins are spread evenly so the last tile ends exactly at the border.
    row_starts = np.linspace(0, height - patch_height, _tile_count(height)).astype(np.int32)
    col_starts = np.linspace(0, width - patch_width, _tile_count(width)).astype(np.int32)

    patches = np.zeros((len(row_starts), len(col_starts), n_channels, patch_height, patch_width), np.float32)
    y_sub, x_sub = [], []
    for row, top in enumerate(row_starts):
        for col, left in enumerate(col_starts):
            bottom = top + patch_height
            right = left + patch_width
            y_sub.append([top, bottom])
            x_sub.append([left, right])
            patches[row, col] = x[:, top:bottom, left:right]
    return patches, y_sub, x_sub
def convert_image_grayscale(x: np.ndarray):
    """
    Turn a 2-D grayscale image into a two-channel, channel-last float32 array.

    Args:
        x: 2-D numpy array (height, width).

    Returns:
        A float32 array of shape (height, width, 2): channel 0 holds the image,
        channel 1 is all zeros.
    """
    assert x.ndim == 2
    gray = x.astype(np.float32)
    empty = np.zeros_like(gray)
    return np.stack((gray, empty), axis=-1)
def convert_image(x, channels: Tuple[int, int]):
    """
    Convert an image to the two-channel layout produced by ``reshape``.

    Args:
        x: input image, forwarded unchanged to ``reshape``.
        channels: exactly two channel selectors, forwarded to ``reshape``.

    Returns:
        Whatever ``reshape(x, channels=channels)`` returns.
    """
    n_selectors = len(channels)
    assert n_selectors == 2
    converted = reshape(x, channels=channels)
    return converted
def reshape(x: np.ndarray, channels=(0, 0)):
    """ select / build two channels from an image and return them channel-first

    Parameters
    ----------
    x : Numpy array, channel last (Ly x Lx x nchan) or a plain 2-D image (Ly x Lx).
    channels : sequence of int of length 2 (optional, default (0, 0))
        First element is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
        Second element is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
        Non-zero values are 1-based indices into the last axis of ``x``.
        For instance, to use grayscale images, input (0, 0). To use images with cells
        in green and nuclei in blue, input (2, 3).
        Ignored when ``x`` has a single channel.

    Returns
    -------
    data : float32 numpy array that's 2 x Ly x Lx; the second channel is all
        zeros when no second channel was requested or available.
    """
    x = x.astype(np.float32)
    if x.ndim < 3:
        x = x[:, :, np.newaxis]

    def _append_empty_channel(arr):
        # Add an all-zero channel after the existing one(s).
        return np.concatenate((arr, np.zeros_like(arr)), axis=-1)

    if x.shape[-1] == 1:
        # Already a single channel: nothing to select.
        x = _append_empty_channel(x)
    elif channels[0] == 0:
        # Grayscale requested: collapse all channels to their mean.
        x = _append_empty_channel(x.mean(axis=-1, keepdims=True))
    else:
        selected = [channels[0] - 1]
        if channels[1] > 0:
            selected.append(channels[1] - 1)
        x = x[..., selected]
        # Warn about channels that carry no information (constant value).
        for position in range(x.shape[-1]):
            if np.ptp(x[..., position]) != 0.0:
                continue
            if position == 0:
                warnings.warn("chan to seg' has value range of ZERO")
            else:
                warnings.warn("'chan2 (opt)' has value range of ZERO, can instead set chan2 to 0")
        if x.shape[-1] == 1:
            x = _append_empty_channel(x)

    # Channel-last -> channel-first.
    return np.transpose(x, (2, 0, 1))
def resize_image(image, height: Optional[int] = None, width: Optional[int] = None, resize: Optional[float] = None,
                 interpolation=cv2.INTER_LINEAR, no_channels=False):
    """ resize image for computing flows / unresize for computing dynamics

    Parameters
    -------------
    image: ND-array
        image of size [Y x X x nchan], or [Y x X] when ``no_channels`` is True.
        NOTE(review): the array is handed to a single ``cv2.resize`` call, so
        stacks ([Lz x ...]) are not resized slice by slice here - confirm with
        callers before passing volumetric data.
    height: int, optional
        target height; when None it is derived from ``resize``
    width: int, optional
        target width; when None it is derived from ``resize``
    resize: float or sequence of two floats, optional
        resize coefficient(s) for image, used when ``height`` (or ``width``) is
        missing. A scalar applies to both axes; a list, tuple or array is read
        as (..., y_factor, x_factor).
    interpolation: cv2 interp method (optional, default cv2.INTER_LINEAR)
    no_channels: bool (optional, default False)
        when True the last two axes of ``image`` are taken as (Y, X); otherwise
        the last axis is assumed to hold channels and (Y, X) are axes -3 and -2.

    Returns
    --------------
    imgs: ND-array
        resized image as returned by ``cv2.resize``

    Raises
    --------------
    ValueError
        if neither a complete (height, width) pair nor ``resize`` is given
    """
    if resize is None and (height is None or width is None):
        # Without a factor both target dimensions are required; previously a
        # missing width slipped through and failed inside OpenCV.
        error_message = 'must give size to resize to or factor to use for resizing'
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    if height is None or width is None:
        # Scalars are broadcast to both axes; tuples count as sequences too.
        if isinstance(resize, (list, tuple, np.ndarray)):
            factors = resize
        else:
            factors = [resize, resize]
        y_axis, x_axis = (-2, -1) if no_channels else (-3, -2)
        scaled_height = int(image.shape[y_axis] * factors[-2])
        scaled_width = int(image.shape[x_axis] * factors[-1])
        if height is None:
            # Historical behaviour: a missing height means both dimensions come
            # from the factor, even if a width was supplied.
            height, width = scaled_height, scaled_width
        else:
            width = scaled_width

    # OpenCV expects the target size as (width, height).
    return cv2.resize(image, (width, height), interpolation=interpolation)
def pad_image(x: np.ndarray, div: int = 16):
    """ zero-pad the last two axes of an image so that both become a multiple of ``div``

    Each side always receives at least ``div // 2`` pixels of padding; the extra
    pixels needed to reach a multiple of ``div`` are split between both sides
    (the trailing side gets the odd one). Leading axes are never padded, and any
    number of them is accepted (including none, i.e. a plain 2-D image).

    Parameters
    -------------
    x: ND-array
        image of size [... x height x width], e.g. [nchan (x Lz) x height x width]
    div: int (optional, default 16)

    Returns
    --------------
    output: ND-array
        padded image
    y_sub: array, int
        yrange of pixels in output corresponding to x
    x_sub: array, int
        xrange of pixels in output corresponding to x
    """
    height = x.shape[-2]
    width = x.shape[-1]

    # Pixels missing to reach the next multiple of ``div`` along each axis.
    height_extra = int(div * np.ceil(height / div) - height)
    width_extra = int(div * np.ceil(width / div) - width)

    top = div // 2 + height_extra // 2
    bottom = div // 2 + height_extra - height_extra // 2
    left = div // 2 + width_extra // 2
    right = div // 2 + width_extra - width_extra // 2

    # Only the two spatial axes are padded, whatever the number of leading axes.
    pads = [(0, 0)] * (x.ndim - 2) + [(top, bottom), (left, right)]
    output = np.pad(x, pads, mode='constant')

    y_sub = np.arange(top, top + height)
    x_sub = np.arange(left, left + width)
    return output, y_sub, x_sub