import math

import numpy as np

import dynamics
import transforms


def _ov_batch_gradient_style(model, image):
    """Run ``model`` on one batch and return its gradient and style outputs.

    The mapping returned by ``model(image)`` is re-keyed by calling
    ``get_any_name()`` on each key, so the two outputs can be looked up by
    name.

    Args:
        model: Callable taking the batch and returning a mapping whose keys
            expose ``get_any_name()`` (presumably an OpenVINO compiled
            model -- confirm against the caller).
        image: Batch of patches to feed to ``model``.

    Returns:
        The ``(result["gradients"], result["styles"])`` pair.

    Raises:
        KeyError: If the model has no output named ``"gradients"`` or
            ``"styles"``.
    """
    result = model(image)
    result = {k.get_any_name(): v for k, v in result.items()}
    return result["gradients"], result["styles"]


def _ov_tiled_inference(
    model,
    x: np.ndarray,
    patch_size: int = 224,
    tile_overlap: float = 0.1,
    n_classes: int = 3,
    batch_size: int = 64,
):
    """Run ``model`` over overlapping patches of ``x`` and stitch the result.

    The image is padded with ``transforms.pad_image``, split with
    ``transforms.split_in_patches``, pushed through the model in batches of
    at most ``batch_size`` patches, and recombined with
    ``transforms.average_patches``. The stitched output is then cropped with
    the index ranges reported by ``pad_image`` and moved to channel-last
    order.

    Args:
        model: Model passed to ``_ov_batch_gradient_style``.
        x: 3-dimensional image array; the last two axes are treated as
            height and width (the first is presumably channels -- confirm).
        patch_size: Forwarded to ``transforms.split_in_patches``.
        tile_overlap: Forwarded to ``transforms.split_in_patches``.
        n_classes: Number of output channels allocated per patch.
        batch_size: Maximum number of patches per model call.

    Returns:
        A ``(yf, styles)`` tuple: ``yf`` is the stitched, cropped output
        with the channel axis moved last; ``styles`` is the mean of the
        per-patch style vectors, scaled to unit L2 norm.

    Raises:
        ValueError: If ``x`` is not 3-dimensional.
    """
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if x.ndim != 3:
        raise ValueError(
            f"expected a 3-dimensional image array, got {x.ndim} dimension(s)"
        )

    x, y_sub, x_sub = transforms.pad_image(x)
    # Crop window for the stitched output. It must be built here because
    # ``y_sub`` / ``x_sub`` are rebound by ``split_in_patches`` below.
    slc = (
        slice(0, n_classes + 1),
        slice(y_sub[0], y_sub[-1] + 1),
        slice(x_sub[0], x_sub[-1] + 1),
    )

    patches, y_sub, x_sub = transforms.split_in_patches(
        x,
        patch_size=patch_size,
        tile_overlap=tile_overlap,
    )
    _, height, width = x.shape
    n_y, n_x, n_channels, patch_height, patch_width = patches.shape
    n_patches = n_y * n_x
    patches = np.reshape(
        patches, (n_patches, n_channels, patch_height, patch_width)
    )
    y = np.zeros((n_patches, n_classes, patch_height, patch_width))

    styles = None
    for k in range(math.ceil(n_patches / batch_size)):
        batch_indexes = np.arange(
            batch_size * k,
            min(n_patches, batch_size * k + batch_size),
        )
        y0, style = _ov_batch_gradient_style(
            model=model,
            image=patches[batch_indexes],
        )
        y[batch_indexes] = y0

        # ``sum`` returns a fresh array, so accumulating into it neither
        # counts the first patch twice nor writes into the model's own
        # output (the previous ``styles = style[0]`` seed did both).
        batch_style_sum = style.sum(axis=0)
        if styles is None:
            styles = batch_style_sum
        else:
            styles += batch_style_sum

    styles /= n_patches

    yf = transforms.average_patches(y, y_sub, x_sub, height, width)
    yf = yf[:, :x.shape[1], :x.shape[2]]
    # Scale the mean style vector to unit L2 norm.
    styles /= (styles**2).sum()**0.5
    yf = np.transpose(yf[slc], (1, 2, 0))
    return yf, styles


def ov_inference(
    model,
    x: np.ndarray,
    rescale: float = 1.,
    cell_probability_threshold: float = .0,
    flow_threshold: float = .4,
    interp: bool = False,
) -> np.ndarray:
    """Segment ``x`` by running tiled inference and ``dynamics.compute_masks``.

    Channel 2 of the tiled output is used as the cell probability and
    channels 0-1 (moved to channel-first order) as the gradients. Mask
    computation always runs with ``device='cpu'`` and ``use_gpu=False``.

    Args:
        model: Model passed to ``_ov_tiled_inference``.
        x: 3-dimensional image array.
        rescale: Scales the iteration budget: ``n_iter = (1 / rescale) * 200``.
        cell_probability_threshold: Forwarded to ``dynamics.compute_masks``.
        flow_threshold: Forwarded to ``dynamics.compute_masks``.
        interp: Forwarded to ``dynamics.compute_masks``.

    Returns:
        The first value returned by ``dynamics.compute_masks``, squeezed.

    Raises:
        ValueError: If ``x`` is not 3-dimensional.
        ZeroDivisionError: If ``rescale`` is zero.
    """
    # The style vector is not needed for mask computation.
    y, _ = _ov_tiled_inference(model=model, x=x)
    cell_probability = y[:, :, 2]
    gradients = y[:, :, :2].transpose((2, 0, 1))
    mask, _ = dynamics.compute_masks(
        gradients,
        cell_probability,
        # NOTE(review): this is a float (200.0 by default); confirm that
        # ``compute_masks`` accepts a non-integer iteration count.
        n_iter=(1 / rescale) * 200,
        cell_probability_threshold=cell_probability_threshold,
        flow_threshold=flow_threshold,
        interp=interp,
        device='cpu',
        use_gpu=False,
    )
    return mask.squeeze()