import logging
import warnings
from functools import lru_cache
from typing import Tuple, Optional, List

import cv2
import numpy as np

transforms_logger = logging.getLogger(__name__)


@lru_cache(maxsize=None)
def _patch_intensity_mask(patch_height: int = 224, patch_width: int = 224, sig: float = 7.5):
    """
    Provide a 2D blending mask for a patch, built from a logistic (sigmoid) taper.

    A 1D profile is computed over ``max_size = max(224, patch_height, patch_width)``
    samples as ``1 / (1 + exp((d - (max_size / 2 - 20)) / sig))``, where ``d`` is the
    distance from the profile centre. With the default ``sig=7.5`` the profile is ~1
    in the middle, ~0.88 about 35 samples from either end, ~0.5 about 20 samples from
    either end and ~0.07 at the very ends. The 2D mask is the outer product of the
    profile with itself, centre-cropped to (patch_height, patch_width). Because the
    taper is defined on ``max_size``, a patch dimension smaller than ``max_size`` is
    cropped from the flatter centre and therefore tapers less at its own edges.

    Args:
        patch_height: Patch height in pixels.
        patch_width: Patch width in pixels.
        sig: Divisor of the sigmoid argument; larger values give a softer taper.

    Returns:
        A float array of shape (patch_height, patch_width). The result is cached and
        shared between calls, so it is marked read-only to stop accidental in-place
        edits from corrupting the cache.
    """
    max_size = max(224, patch_height, patch_width)
    distance = np.arange(max_size)
    distance = np.abs(distance - distance.mean())  # distance from the centre of the 1D profile
    profile = 1 / (1 + np.exp((distance - (max_size / 2 - 20)) / sig))
    mask = profile * profile[:, np.newaxis]
    centre = max_size // 2
    mask = mask[centre - patch_height // 2: centre + patch_height // 2 + patch_height % 2,
                centre - patch_width // 2: centre + patch_width // 2 + patch_width % 2]
    # The same array object is handed out on every cache hit: forbid mutation.
    mask.flags.writeable = False
    return mask


def average_patches(patches: np.ndarray, y_sub: List[Tuple[int, int]], x_sub: List[Tuple[int, int]], height: int,
                    width: int):
    """
    Blend overlapping patch predictions back into one full-size image.

    Each patch is weighted by ``_patch_intensity_mask`` so that pixels near a patch's
    border contribute less than pixels near its centre. Overlapping contributions are
    summed and divided by the summed weights.

    Args:
        patches: Array of shape (n_patches, n_classes, patch_height, patch_width). The
            5D (ny, nx, ...) output of ``split_in_patches`` must be reshaped to this
            layout first.
        y_sub: Per-patch (start, end) row range in the output image.
        x_sub: Per-patch (start, end) column range in the output image.
        height: Output image height.
        width: Output image width.

    Returns:
        float32 array of shape (n_classes, height, width) holding the weighted average.
        Pixels not covered by any patch are 0 (instead of NaN from a 0/0 division).
    """
    intensity_mask = np.zeros((height, width), dtype=np.float32)
    mean_output = np.zeros((patches.shape[1], height, width), dtype=np.float32)
    patch_mask = _patch_intensity_mask(patch_height=patches.shape[-2], patch_width=patches.shape[-1])
    # Index-based loop so mismatched list lengths still fail loudly (IndexError) as before.
    for i in range(len(y_sub)):
        y0, y1 = y_sub[i][0], y_sub[i][1]
        x0, x1 = x_sub[i][0], x_sub[i][1]
        mean_output[:, y0:y1, x0:x1] += patches[i] * patch_mask
        intensity_mask[y0:y1, x0:x1] += patch_mask
    # Only divide where some patch contributed; uncovered pixels stay at zero.
    return np.divide(mean_output, intensity_mask, out=np.zeros_like(mean_output), where=intensity_mask > 0)


def split_in_patches(x: np.ndarray, patch_size: int = 224, tile_overlap: float = 0.1):
    """
    Split an image into overlapping tiles for test-time inference.

    Args:
        x: Array of shape (n_channels, height, width).
        patch_size: Nominal tile side. A tile dimension is clipped to the image
            dimension when the image is smaller than ``patch_size``.
        tile_overlap: Overlap fraction, clamped to [0.05, 0.5]. Larger values create
            more tiles along each axis and therefore more overlap.

    Returns:
        patches: float32 array of shape (ny, nx, n_channels, patch_height, patch_width),
            where ny / nx are the number of tiles along y / x. Reshape it to
            (ny * nx, n_channels, patch_height, patch_width) to get one tile per
            entry of ``y_sub`` / ``x_sub`` (row-major order).
        y_sub: List of ``[start, end]`` row ranges, one per tile (length ny * nx).
        x_sub: List of ``[start, end]`` column ranges, one per tile (length ny * nx).
    """
    n_channels, height, width = x.shape
    tile_overlap = min(0.5, max(0.05, tile_overlap))
    patch_height = np.int32(min(patch_size, height))
    patch_width = np.int32(min(patch_size, width))
    # Use more tiles than strictly needed to cover the image so that neighbours overlap.
    ny = 1 if height <= patch_size else int(np.ceil((1. + 2 * tile_overlap) * height / patch_size))
    nx = 1 if width <= patch_size else int(np.ceil((1. + 2 * tile_overlap) * width / patch_size))
    # Evenly spaced start positions; the last tile ends exactly on the image border.
    y_start = np.linspace(0, height - patch_height, ny).astype(np.int32)
    x_start = np.linspace(0, width - patch_width, nx).astype(np.int32)

    y_sub, x_sub = [], []
    patches = np.zeros((len(y_start), len(x_start), n_channels, patch_height, patch_width), np.float32)
    for j, y0 in enumerate(y_start):
        for i, x0 in enumerate(x_start):
            y_sub.append([y0, y0 + patch_height])
            x_sub.append([x0, x0 + patch_width])
            patches[j, i] = x[:, y0:y0 + patch_height, x0:x0 + patch_width]
    return patches, y_sub, x_sub


def convert_image_grayscale(x: np.ndarray):
    """
    Convert a 2D grayscale image into a channel-last, two-channel float32 image.

    Args:
        x: 2D array of shape (height, width).

    Returns:
        float32 array of shape (height, width, 2): channel 0 is the image, channel 1
        is all zeros. Note this is channel-LAST, unlike ``reshape`` which returns
        channel-first output.
    """
    assert x.ndim == 2
    x = x.astype(np.float32)[:, :, np.newaxis]
    return np.concatenate((x, np.zeros_like(x)), axis=-1)


def convert_image(x, channels: Tuple[int, int]):
    """
    Check that ``channels`` has exactly two entries, then delegate to ``reshape``.

    Args:
        x: Channel-last image (see ``reshape``).
        channels: (chan_to_seg, chan2) pair (see ``reshape``).

    Returns:
        The channel-first, two-channel float32 array produced by ``reshape``.
    """
    assert len(channels) == 2
    return reshape(x, channels=channels)


def reshape(x: np.ndarray, channels=(0, 0)):
    """
    Select / arrange image channels into a two-channel, channel-first float32 array.

    Args:
        x: Image with channels last, e.g. (height, width), (height, width, nchan) or
            (Lz, height, width, nchan). A 2D array is treated as single-channel.
        channels: Pair (chan_to_seg, chan2). The first element is the channel to
            segment (0=grayscale, 1=red, 2=green, 3=blue); the second is the optional
            nuclear channel (0=none, 1=red, 2=green, 3=blue). E.g. (0, 0) for
            grayscale, (2, 3) for cells in green and nuclei in blue.

    Behaviour:
        - single-channel input: a zero channel is appended (``channels`` is ignored);
        - ``channels[0] == 0``: all channels are averaged into one, then a zero
          channel is appended;
        - otherwise the requested channels (1-based) are selected, a warning is
          emitted for each selected channel whose value range is zero, and a zero
          channel is appended when only one channel was selected.

    Returns:
        float32 array with the channel axis (always of size 2) moved to the front,
        e.g. (2, height, width) or (2, Lz, height, width).
    """
    x = x.astype(np.float32)
    if x.ndim < 3:
        x = x[:, :, np.newaxis]

    if x.shape[-1] == 1:
        x = np.concatenate((x, np.zeros_like(x)), axis=-1)
    elif channels[0] == 0:
        # Grayscale requested: average all channels into one.
        x = x.mean(axis=-1, keepdims=True)
        x = np.concatenate((x, np.zeros_like(x)), axis=-1)
    else:
        # channels are 1-based; convert to 0-based indices of the last axis.
        channels_index = [channels[0] - 1]
        if channels[1] > 0:
            channels_index.append(channels[1] - 1)
        x = x[..., channels_index]
        for i in range(x.shape[-1]):
            if np.ptp(x[..., i]) == 0.0:
                if i == 0:
                    warnings.warn("'chan to seg' has value range of ZERO")
                else:
                    warnings.warn("'chan2 (opt)' has value range of ZERO, can instead set chan2 to 0")
        if x.shape[-1] == 1:
            x = np.concatenate((x, np.zeros_like(x)), axis=-1)

    # Move channels to the front; identical to transpose(2, 0, 1) for 3D input and
    # also works for channel-last z-stacks.
    return np.moveaxis(x, -1, 0)


def resize_image(image, height: Optional[int] = None, width: Optional[int] = None, resize: Optional[float] = None,
                 interpolation=cv2.INTER_LINEAR, no_channels=False):
    """
    Resize an image, or a stack of images, with OpenCV.

    The target size is either given explicitly as (height, width), or derived by
    multiplying the current spatial size by ``resize`` (truncated to int).

    Args:
        image: (Y, X, nchan) image or (Lz, Y, X, nchan) stack; with
            ``no_channels=True`` a (Y, X) image or (Lz, Y, X) stack. Stacks are
            resized slice by slice.
        height: Target height. Used only when ``width`` is also given.
        width: Target width. Used only when ``height`` is also given.
        resize: Scale factor, or a (y_factor, x_factor) list / tuple / array. Used
            when ``height`` or ``width`` is missing.
        interpolation: OpenCV interpolation flag (default cv2.INTER_LINEAR).
        no_channels: True when ``image`` has no trailing channel axis; decides which
            axes are spatial.

    Returns:
        Resized array of shape (height, width[, nchan]) or (Lz, height, width[, nchan]).
        NOTE(review): OpenCV may drop a trailing channel axis of size 1 — confirm
        callers cope with that.

    Raises:
        ValueError: If neither a complete (height, width) nor ``resize`` is given.
    """
    if height is None or width is None:
        if resize is None:
            error_message = 'must give size to resize to or factor to use for resizing'
            transforms_logger.critical(error_message)
            raise ValueError(error_message)
        if not isinstance(resize, (list, tuple, np.ndarray)):
            resize = [resize, resize]
        if no_channels:
            height = int(image.shape[-2] * resize[-2])
            width = int(image.shape[-1] * resize[-1])
        else:
            height = int(image.shape[-3] * resize[-2])
            width = int(image.shape[-2] * resize[-1])

    # cv2.resize works on one 2D image (optionally with channels); resize stacks per slice.
    is_stack = image.ndim > 3 or (no_channels and image.ndim > 2)
    if is_stack:
        return np.stack([cv2.resize(plane, (width, height), interpolation=interpolation) for plane in image])
    return cv2.resize(image, (width, height), interpolation=interpolation)


def pad_image(x: np.ndarray, div: int = 16):
    """
    Zero-pad the last two (spatial) axes for test-time inference.

    Each spatial axis receives ``div // 2`` pixels on both sides plus whatever is
    needed to reach the next multiple of ``div``; that extra amount is split as
    evenly as possible, with the odd pixel going to the end. For even ``div``
    (default 16) the padded spatial sizes are multiples of ``div``. Leading axes
    (channels, z, ...) are not padded.

    Args:
        x: Array whose last two axes are (height, width), e.g. (nchan, height, width)
            or (nchan, Lz, height, width).
        div: Divisor the padded spatial sizes should be a multiple of (default 16).

    Returns:
        output: Zero-padded copy of ``x``.
        y_sub: Row indices of ``output`` that correspond to the rows of ``x``.
        x_sub: Column indices of ``output`` that correspond to the columns of ``x``.
    """
    height, width = x.shape[-2:]
    # Amount needed to round each spatial size up to a multiple of `div`.
    extra_h = -height % div
    extra_w = -width % div
    top = div // 2 + extra_h // 2
    bottom = div // 2 + extra_h - extra_h // 2
    left = div // 2 + extra_w // 2
    right = div // 2 + extra_w - extra_w // 2

    pads = [(0, 0)] * (x.ndim - 2) + [(top, bottom), (left, right)]
    output = np.pad(x, pads, mode='constant')

    y_sub = np.arange(top, top + height)
    x_sub = np.arange(left, left + width)
    return output, y_sub, x_sub