|
import logging |
|
import warnings |
|
from functools import lru_cache |
|
from typing import Tuple, Optional, List |
|
|
|
import cv2 |
|
import numpy as np |
|
|
|
# Module-level logger named after this module; resize_image uses it to report
# invalid arguments before raising.
transforms_logger = logging.getLogger(__name__)
|
|
|
|
|
@lru_cache(maxsize=None)
def _patch_intensity_mask(patch_height: int = 224, patch_width: int = 224, sig: float = 7.5):
    """
    Build the 2-D blending weight map used when overlapping patches are averaged.

    A 1-D logistic profile ``1 / (1 + exp((d - (size / 2 - 20)) / sig))`` is evaluated on the
    distance ``d`` of every sample to the centre of an axis of length
    ``size = max(224, patch_height, patch_width)``. The 2-D map is the outer product of that
    profile with itself, cropped around its centre to ``(patch_height, patch_width)``.

    Along one axis the profile is close to 1 near the centre, about 0.5 roughly 20 samples
    away from the border of the ``size``-long axis, and keeps decreasing towards the border
    (about 0.07 with the default ``sig``; it never reaches exactly 0).

    The function is wrapped in ``lru_cache``: calls with equal arguments return the same
    cached array object, so callers should not modify the result in place.

    Args:
        patch_height: Input patch height
        patch_width: Input patch width
        sig: Sigma that divides the exponent; larger values give a softer fall-off.

    Returns:
        An intensity map as a numpy array, with shape == (patch_height, patch_width)
    """
    full_size = max(224, patch_height, patch_width)

    # Distance of every sample from the centre of the axis.
    positions = np.arange(full_size)
    centre_distance = np.abs(positions - positions.mean())

    # Logistic fall-off whose midpoint sits 20 samples inside the border.
    falloff_start = full_size / 2 - 20
    profile = 1 / (1 + np.exp((centre_distance - falloff_start) / sig))

    # Separable 2-D weights: weights[i, j] == profile[i] * profile[j].
    weights = profile[:, np.newaxis] * profile

    # Centre crop down to the requested patch size.
    half = full_size // 2
    top = half - patch_height // 2
    left = half - patch_width // 2
    return weights[top:top + patch_height, left:left + patch_width]
|
|
|
|
|
def average_patches(patches: np.ndarray, y_sub: List[Tuple[int, int]], x_sub: List[Tuple[int, int]],
                    height: int, width: int):
    """
    Blend overlapping patches into one image using a weighted average.

    Every patch is multiplied by the weight map returned by ``_patch_intensity_mask`` and
    accumulated at its position; the accumulated values are finally divided by the
    accumulated weights.

    Args:
        patches: numpy array indexed as (patch, class, patch_height, patch_width)
        y_sub: one (start, stop) pair per patch giving its extent along the y-axis.
        x_sub: one (start, stop) pair per patch giving its extent along the x-axis.
        height: output image height
        width: output image width

    Returns:
        A float32 numpy array of shape (patches.shape[1], height, width). Pixels that no
        patch covers have zero accumulated weight, so the final division is 0 / 0 there.
    """
    n_classes = patches.shape[1]
    weight_map = _patch_intensity_mask(patch_height=patches.shape[-2], patch_width=patches.shape[-1])

    weighted_sum = np.zeros((n_classes, height, width), dtype=np.float32)
    weight_total = np.zeros((height, width), dtype=np.float32)

    for index, y_range in enumerate(y_sub):
        x_range = x_sub[index]
        rows = slice(y_range[0], y_range[1])
        cols = slice(x_range[0], x_range[1])
        # The 2-D weight map broadcasts over the leading class axis of the patch.
        weighted_sum[:, rows, cols] += patches[index] * weight_map
        weight_total[rows, cols] += weight_map

    return weighted_sum / weight_total
|
|
|
|
|
def split_in_patches(x: np.ndarray, patch_size: int = 224, tile_overlap: float = 0.1):
    """ cut a channel-first image into a regular grid of overlapping tiles

    Parameters
    ----------
    x : array
        image of shape (n_channels, height, width)

    patch_size : int (optional, default 224)
        largest tile edge; an axis shorter than this is covered by a single
        tile spanning the whole axis

    tile_overlap : float (optional, default 0.1)
        requested fraction of overlap between tiles, clamped to [0.05, 0.5]

    Returns
    -------
    patches : float32
        array of shape (ny, nx, n_channels, patch_height, patch_width), where
        ny and nx are the number of tiles along Y and X

    y_sub : list
        [start, stop] of every tile along Y, one entry per tile (ny * nx in
        total), ordered row by row

    x_sub : list
        [start, stop] of every tile along X, same ordering as y_sub

    """
    n_channels, height, width = x.shape

    overlap = min(0.5, max(0.05, tile_overlap))
    patch_height = np.int32(min(patch_size, height))
    patch_width = np.int32(min(patch_size, width))

    def _tile_count(extent):
        # One tile is enough when the axis fits in a patch; otherwise
        # oversample the axis so neighbouring tiles overlap.
        if extent <= patch_size:
            return 1
        return int(np.ceil((1. + 2 * overlap) * extent / patch_size))

    # Evenly spaced tile origins from the first to the last valid position.
    y_start = np.linspace(0, height - patch_height, _tile_count(height)).astype(np.int32)
    x_start = np.linspace(0, width - patch_width, _tile_count(width)).astype(np.int32)

    patches = np.zeros((len(y_start), len(x_start), n_channels, patch_height, patch_width), np.float32)
    y_sub, x_sub = [], []
    for row, top in enumerate(y_start):
        for col, left in enumerate(x_start):
            y_span = [top, top + patch_height]
            x_span = [left, left + patch_width]
            y_sub.append(y_span)
            x_sub.append(x_span)
            patches[row, col] = x[:, y_span[0]:y_span[1], x_span[0]:x_span[1]]

    return patches, y_sub, x_sub
|
|
|
|
|
def convert_image_grayscale(x: np.ndarray):
    """ turn a 2-D grayscale image into a two-channel, channel-last float32 array

    Parameters
    ----------
    x : 2-D numpy array (asserted)

    Returns
    -------
    float32 array of shape x.shape + (2,): channel 0 holds the input values,
    channel 1 is filled with zeros.

    """
    assert x.ndim == 2
    gray = x.astype(np.float32)
    return np.stack((gray, np.zeros_like(gray)), axis=-1)
|
|
|
|
|
def convert_image(x, channels: Tuple[int, int]):
    """ convert an image with ``reshape`` after checking the channel selection

    Parameters
    ----------
    x : image array forwarded unchanged to ``reshape``

    channels : sequence of exactly two ints (asserted), forwarded to ``reshape``

    Returns
    -------
    whatever ``reshape(x, channels=channels)`` returns

    """
    assert len(channels) == 2

    converted = reshape(x, channels=channels)
    return converted
|
|
|
|
|
def reshape(x: np.ndarray, channels=(0, 0)):
    """ select / build the two working channels of an image and move them first

    Parameters
    ----------
    x : Numpy array, channel last (Ly x Lx x nchan) or 2-D (Ly x Lx).

    channels : sequence of int of length 2 (optional, default (0, 0))
        First element is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
        Second element is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
        For instance, to train on grayscale images, input [0,0]. To train on images with cells
        in green and nuclei in blue, input [2,3].
        Ignored when the image has a single channel. With a multi-channel image and
        channels[0] == 0 the channels are averaged into one grayscale channel.

    Returns
    -------
    data : float32 numpy array, channel first, of shape 2 x Ly x Lx. When only one
        channel is selected or available, the second channel is all zeros.

    """
    x = x.astype(np.float32)
    if x.ndim < 3:
        x = x[:, :, np.newaxis]

    def _with_empty_second_channel(single):
        # Append an all-zero channel so the result always carries two channels.
        return np.concatenate((single, np.zeros_like(single)), axis=-1)

    if x.shape[-1] == 1:
        x = _with_empty_second_channel(x)
    elif channels[0] == 0:
        x = _with_empty_second_channel(x.mean(axis=-1, keepdims=True))
    else:
        # Channel ids are 1-based in the public interface.
        selected = [channels[0] - 1]
        if channels[1] > 0:
            selected.append(channels[1] - 1)
        x = x[..., selected]

        # Warn about selected channels that carry no signal at all.
        for position in range(x.shape[-1]):
            if np.ptp(x[..., position]) != 0.0:
                continue
            if position == 0:
                warnings.warn("chan to seg' has value range of ZERO")
            else:
                warnings.warn("'chan2 (opt)' has value range of ZERO, can instead set chan2 to 0")

        if x.shape[-1] == 1:
            x = _with_empty_second_channel(x)

    return x.transpose(2, 0, 1)
|
|
|
|
|
def resize_image(image, height: Optional[int] = None, width: Optional[int] = None, resize: Optional[float] = None,
                 interpolation=cv2.INTER_LINEAR, no_channels=False):
    """ resize an image with ``cv2.resize`` to an explicit size or by a scale factor

    Parameters
    -------------

    image: ND-array
        image laid out as [Y x X x nchan], or [Y x X] when ``no_channels`` is True
        (the input size is read from the axes implied by ``no_channels``).
        NOTE(review): the previous docstring also listed [Lz x Y x X (x nchan)]
        stacks; whether ``cv2.resize`` accepts those is not visible here - confirm.

    height: int, optional
        target height; must be given together with ``width``

    width: int, optional
        target width; ignored (recomputed) when ``height`` is None

    resize: float or sequence of two floats, optional
        scale factor(s), only used when ``height`` is None. A scalar is applied
        to both axes; for a list / tuple / array the last two entries are the
        Y and X factors.

    interpolation: cv2 interp method (optional, default cv2.INTER_LINEAR)

    no_channels: bool (optional, default False)
        True when the last two axes of ``image`` are Y and X (no channel axis)

    Returns
    --------------

    imgs: ND-array
        the output of ``cv2.resize`` for the target size (width, height)

    Raises
    --------------

    ValueError
        if neither ``height`` nor ``resize`` is given, or if ``height`` is given
        without ``width``

    """
    if height is None and resize is None:
        error_message = 'must give size to resize to or factor to use for resizing'
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    if height is None:
        # Derive the target size from the scale factor(s).
        if not isinstance(resize, (list, tuple, np.ndarray)):
            resize = [resize, resize]
        y_axis, x_axis = (-2, -1) if no_channels else (-3, -2)
        height = int(image.shape[y_axis] * resize[-2])
        width = int(image.shape[x_axis] * resize[-1])
    elif width is None:
        # Without this guard cv2.resize would receive (None, height) and fail
        # with an unrelated argument-parsing error.
        error_message = 'must give width together with height when resizing to an explicit size'
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    return cv2.resize(image, (width, height), interpolation=interpolation)
|
|
|
|
|
def pad_image(x: np.ndarray, div: int = 16):
    """ zero-pad the last two axes of an image for test-time

    Each of the last two axes is first rounded up to the next multiple of
    ``div`` and then extended by a further ``div // 2`` on both sides. For an
    even ``div`` (such as the default 16) the padded size is therefore a
    multiple of ``div``. Any leading axes (channels, Z, ...) are left untouched.

    Parameters
    -------------

    x: ND-array
        image of size [... x height x width], with at least 2 dimensions

    div: int (optional, default 16)

    Returns
    --------------

    output: ND-array
        padded image (constant padding with zeros)

    y_sub: array, int
        indices along the height axis of ``output`` that hold the original rows

    x_sub: array, int
        indices along the width axis of ``output`` that hold the original columns

    """
    height, width = x.shape[-2:]

    def _split_padding(length):
        # Amount needed to reach the next multiple of div, shared between both
        # sides (the extra sample of an odd amount goes after), plus div // 2
        # of margin on each side.
        remainder = int(div * np.ceil(length / div) - length)
        before = div // 2 + remainder // 2
        after = div // 2 + remainder - remainder // 2
        return before, after

    pad_top, pad_bottom = _split_padding(height)
    pad_left, pad_right = _split_padding(width)

    # Only the last two axes are padded, whatever the number of leading axes.
    pads = [(0, 0)] * (x.ndim - 2) + [(pad_top, pad_bottom), (pad_left, pad_right)]
    output = np.pad(x, pads, mode='constant')

    y_sub = np.arange(pad_top, pad_top + height)
    x_sub = np.arange(pad_left, pad_left + width)
    return output, y_sub, x_sub