| """ |
| Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu. |
| """ |
|
|
| import os, time |
| |
| import numpy as np |
| from tqdm import trange |
| import torch |
| from scipy.ndimage import gaussian_filter |
| |
| import cv2 |
|
|
| import logging |
|
|
| models_logger = logging.getLogger(__name__) |
|
|
| from . import transforms, dynamics, utils |
| from .vit import Transformer |
| from .core import assign_device, run_net |
|
|
| |
| |
| |
|
|
| |
|
|
| |
|
|
# Default settings forwarded to transforms.normalize_img; callers may override
# any subset by passing a dict as the `normalize` argument of SegModel.eval.
normalize_default = dict(
    lowhigh=None,            # explicit intensity bounds; None -> percentile-based
    percentile=None,         # percentile bounds; None -> library default
    normalize=True,          # rescale intensities per image
    norm3D=True,             # normalize over whole 3D stack rather than per plane
    sharpen_radius=0,        # 0 disables sharpening
    smooth_radius=0,         # 0 disables smoothing
    tile_norm_blocksize=0,   # 0 disables tile-wise normalization
    tile_norm_smooth3D=1,
    invert=False,            # invert intensities
)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
class SegModel():
    """
    Class representing a Cellpose model.

    Attributes:
        device (torch device): Device used for model running / training.
        gpu (bool): True if the selected device is a CUDA or MPS accelerator.
        net (Transformer): Segmentation network (moved to ``device``).
        pretrained_model (str): Path to pretrained cellpose model ("" if none).

    Methods:
        __init__: Initialize the model and move the network to the device.
        eval: Segment list of images x, or 4D array - Z x C x Y x X;
            returns (masks, flows, styles).
    """

    def __init__(self, gpu=False, pretrained_model="", model_type=None,
                 diam_mean=None, device=None, nchan=None, use_bfloat16=True, vit_checkpoint=None):
        """
        Initialize the model.

        Parameters:
            gpu (bool, optional): Whether or not to run the model on GPU, will check if GPU available.
            pretrained_model (str or list of strings, optional): Full path to pretrained cellpose model(s); None is treated as "".
            model_type (str, optional): Unused here; kept for interface compatibility.
            diam_mean (float, optional): Unused here; kept for interface compatibility.
            device (torch device, optional): Device used for model running / training
                (torch.device("cuda") or torch.device("cpu")); overrides gpu input,
                recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
            nchan (int, optional): Unused here; kept for interface compatibility.
            use_bfloat16 (bool, optional): Use 16bit bfloat precision instead of 32bit for model weights. Default True.
            vit_checkpoint (str, optional): Checkpoint passed to the Transformer backbone.
        """
        # An explicit `device` wins; otherwise let assign_device choose based on `gpu`.
        self.device = assign_device(gpu=gpu)[0] if device is None else device
        # Record whether the chosen device is actually an accelerator.
        if torch.cuda.is_available():
            device_gpu = self.device.type == "cuda"
        elif torch.backends.mps.is_available():
            device_gpu = self.device.type == "mps"
        else:
            device_gpu = False
        self.gpu = device_gpu

        if pretrained_model is None:
            # normalize to "" so downstream checks can rely on a string
            pretrained_model = ""
        self.pretrained_model = pretrained_model
        dtype = torch.bfloat16 if use_bfloat16 else torch.float32
        self.net = Transformer(dtype=dtype, checkpoint=vit_checkpoint).to(self.device)

    def eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None,
             z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
             flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None,
             flow3D_smooth=0, stitch_threshold=0.0,
             min_size=15, max_size_fraction=0.4, niter=None,
             augment=False, tile_overlap=0.1, bsize=256,
             compute_masks=True, progress=None):
        """
        Segment a list of images, or a single image / stack.

        Parameters:
            x (list or ndarray): image(s) to segment; lists and 5D arrays are
                processed one image at a time via recursion.
            feat (ndarray, optional): auxiliary feature maps matched to x.
            normalize (bool or dict): True/False, or a dict of overrides for
                ``normalize_default``.
            diameter (float, list, optional): object diameter; images are rescaled
                so objects are ~30 px across before running the network.
            (remaining parameters are forwarded to the network run and to mask
             computation; see _run_net and _compute_masks)

        Returns:
            masks, flows, styles: for a list/5D input, lists of per-image outputs;
            for a single image, masks (ndarray), flows ([dP, cellprob]) and styles.
        """
        # --- list (or 5D) input: recurse per image and collect outputs -----
        if isinstance(x, list) or x.squeeze().ndim == 5:
            self.timing = []
            masks, styles, flows = [], [], []
            tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
            nimg = len(x)
            iterator = trange(nimg, file=tqdm_out,
                              mininterval=30) if nimg > 1 else range(nimg)
            for i in iterator:
                tic = time.time()
                maski, flowi, stylei = self.eval(
                    x[i],
                    feat=None if feat is None else feat[i],
                    batch_size=batch_size,
                    channel_axis=channel_axis,
                    z_axis=z_axis,
                    normalize=normalize,
                    invert=invert,
                    diameter=diameter[i] if isinstance(diameter, list) or
                    isinstance(diameter, np.ndarray) else diameter,
                    do_3D=do_3D,
                    anisotropy=anisotropy,
                    augment=augment,
                    tile_overlap=tile_overlap,
                    bsize=bsize,
                    resample=resample,
                    flow_threshold=flow_threshold,
                    cellprob_threshold=cellprob_threshold,
                    compute_masks=compute_masks,
                    min_size=min_size,
                    max_size_fraction=max_size_fraction,
                    stitch_threshold=stitch_threshold,
                    flow3D_smooth=flow3D_smooth,
                    progress=progress,
                    niter=niter)
                masks.append(maski)
                flows.append(flowi)
                styles.append(stylei)
                self.timing.append(time.time() - tic)
            return masks, flows, styles

        # --- single image / stack ------------------------------------------
        x = transforms.convert_image(x, channel_axis=channel_axis,
                                     z_axis=z_axis,
                                     do_3D=(do_3D or stitch_threshold > 0))

        if x.ndim < 4:
            x = x[np.newaxis, ...]
        if feat is not None:
            if feat.ndim < 4:
                feat = feat[np.newaxis, ...]
        nimg = x.shape[0]

        # remember original sizes so outputs can be resized back at the end
        image_scaling = None
        Ly_0 = x.shape[1]
        Lx_0 = x.shape[2]
        Lz_0 = None
        if stitch_threshold > 0:
            Lz_0 = x.shape[0]
        if diameter is not None:
            # rescale so objects are ~30 px in diameter (the model's trained scale)
            image_scaling = 30. / diameter
            x = transforms.resize_image(x,
                                        Ly=int(x.shape[1] * image_scaling),
                                        Lx=int(x.shape[2] * image_scaling))
            if feat is not None:
                feat = transforms.resize_image(feat,
                                               Ly=int(feat.shape[1] * image_scaling),
                                               Lx=int(feat.shape[2] * image_scaling))

        # BUGFIX: copy the module-level defaults instead of aliasing them.
        # The in-place updates below used to mutate `normalize_default` itself,
        # leaking settings across calls and model instances.
        normalize_params = {**normalize_default}
        if isinstance(normalize, dict):
            normalize_params = {**normalize_params, **normalize}
        elif not isinstance(normalize, bool):
            raise ValueError("normalize parameter must be a bool or a dict")
        else:
            normalize_params["normalize"] = normalize
        normalize_params["invert"] = invert

        do_normalization = bool(normalize_params["normalize"])
        if nimg > 1 and do_normalization and (stitch_threshold or do_3D):
            # volumetric / stitched data: normalize the full stack up front
            normalize_params["norm3D"] = True if do_3D else normalize_params["norm3D"]
            x = transforms.normalize_img(x, **normalize_params)
            do_normalization = False
        else:
            if normalize_params["norm3D"] and nimg > 1 and do_normalization:
                models_logger.warning(
                    "normalize_params['norm3D'] is True but do_3D is False and stitch_threshold=0, so setting to False"
                )
                normalize_params["norm3D"] = False
            if do_normalization:
                x = transforms.normalize_img(x, **normalize_params)

        if feat is not None:
            if feat.shape[-1] > feat.shape[1]:
                # channels-first -> channels-last
                feat = np.moveaxis(feat, 1, -1)

        # keep the z-anisotropy consistent with the in-plane rescaling above
        if isinstance(anisotropy, (float, int)) and image_scaling:
            anisotropy = image_scaling * anisotropy

        dP, cellprob, styles = self._run_net(
            x,
            feat=feat,
            augment=augment,
            batch_size=batch_size,
            tile_overlap=tile_overlap,
            bsize=bsize,
            do_3D=do_3D,
            anisotropy=anisotropy)

        if resample:
            # resize flows / probabilities back to the pre-rescaling image size
            dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
            cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)

        if compute_masks:
            niter0 = 200
            niter = niter0 if niter is None or niter == 0 else niter
            masks = self._compute_masks(x.shape, dP, cellprob, flow_threshold=flow_threshold,
                                        cellprob_threshold=cellprob_threshold, min_size=min_size,
                                        max_size_fraction=max_size_fraction, niter=niter,
                                        stitch_threshold=stitch_threshold, do_3D=do_3D)
        else:
            masks = np.zeros(0)

        masks = masks.squeeze()

        if image_scaling is not None or anisotropy is not None:
            if compute_masks:
                # nearest-neighbor interpolation so integer label values are preserved
                masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)

        # BUGFIX: return the same (masks, flows, styles) triple that the list /
        # 5D branch above unpacks from the recursive call; previously only
        # `masks` was returned here, which broke every list or 5D input.
        return masks, [dP, cellprob], styles

    def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
        """
        Resize cellprob array to specified dimensions for either 2D or 3D.

        Parameters:
            prob (numpy.ndarray): The cellprobs to resize, either in 2D or 3D. Returns the same ndim as provided.
            to_y_size (int): The target size along the Y-axis.
            to_x_size (int): The target size along the X-axis.
            to_z_size (int, optional): The target size along the Z-axis. Required
                for 3D cellprobs.

        Returns:
            numpy.ndarray: The resized cellprobs array with the same number of dimensions
            as the input.

        Raises:
            ValueError: If the input cellprobs array does not have 2 or 3 dimensions
                after squeezing.
        """
        prob_shape = prob.shape
        prob = prob.squeeze()
        # track whether a singleton axis was dropped so it can be restored below
        squeeze_happened = prob.shape != prob_shape
        prob_shape = np.array(prob_shape)

        if prob.ndim == 2:
            prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
            if squeeze_happened:
                # re-insert the squeezed singleton axis at its original position
                # NOTE(review): assumes exactly one singleton axis was squeezed
                prob = np.expand_dims(prob, int(np.argwhere(prob_shape == 1)))
        elif prob.ndim == 3:
            # resize in-plane (Y, X) first, then swap Z into the row axis to
            # resize along Z, then swap back
            prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
            prob = prob.transpose(1, 0, 2)
            prob = transforms.resize_image(prob, Ly=to_z_size, Lx=to_x_size, no_channels=True)
            prob = prob.transpose(1, 0, 2)
        else:
            raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 2 or 3, prob shape: {prob.shape}')

        return prob

    def _resize_gradients(self, grads: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
        """
        Resize gradient arrays to specified dimensions for either 2D or 3D gradients.

        Parameters:
            grads (np.ndarray): The gradients to resize, either in 2D or 3D. Returns the same ndim as provided.
            to_y_size (int): The target size along the Y-axis.
            to_x_size (int): The target size along the X-axis.
            to_z_size (int, optional): The target size along the Z-axis. Required
                for 3D gradients.

        Returns:
            numpy.ndarray: The resized gradient array with the same number of dimensions
            as the input.

        Raises:
            ValueError: If the input gradient array does not have 3 or 4 dimensions
                after squeezing.
        """
        grads_shape = grads.shape
        grads = grads.squeeze()
        # track whether a singleton axis was dropped so it can be restored below
        squeeze_happened = grads.shape != grads_shape
        grads_shape = np.array(grads_shape)

        if grads.ndim == 3:
            # 2D gradients (2, Y, X): move the flow-component axis last so the
            # resize treats it as channels, then move it back
            grads = np.moveaxis(grads, 0, -1)
            grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
            grads = np.moveaxis(grads, -1, 0)

            if squeeze_happened:
                # NOTE(review): assumes exactly one singleton axis was squeezed
                grads = np.expand_dims(grads, int(np.argwhere(grads_shape == 1)))
        elif grads.ndim == 4:
            # 3D gradients: resize in-plane first, then along Z, restoring the
            # flow-component axis to the front at the end
            grads = grads.transpose(1, 2, 3, 0)
            grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
            grads = grads.transpose(1, 0, 2, 3)
            grads = transforms.resize_image(grads, Ly=to_z_size, Lx=to_x_size, no_channels=False)
            grads = grads.transpose(3, 1, 0, 2)
        else:
            raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 3 or 4, grads shape: {grads.shape}')

        return grads

    def _run_net(self, x, feat=None,
                 augment=False,
                 batch_size=8, tile_overlap=0.1,
                 bsize=224, anisotropy=1.0, do_3D=False):
        """ run network on image x and split output into flows, cellprob, styles """
        tic = time.time()
        shape = x.shape
        nimg = shape[0]

        yf, styles = run_net(self.net, x, feat=feat, bsize=bsize, augment=augment,
                             batch_size=batch_size,
                             tile_overlap=tile_overlap,
                             )
        # last output channel is the cell probability, the two before it are the flows
        cellprob = yf[..., -1]
        dP = yf[..., -3:-1].transpose((3, 0, 1, 2))
        if yf.shape[-1] > 3:
            # extra leading channels override the styles returned by run_net
            styles = yf[..., :-3]

        styles = styles.squeeze()

        net_time = time.time() - tic
        if nimg > 1:
            models_logger.info("network run in %2.2fs" % (net_time))

        return dP, cellprob, styles

    def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_threshold=0.0,
                       min_size=15, max_size_fraction=0.4, niter=None,
                       do_3D=False, stitch_threshold=0.0):
        """ compute masks from flows and cell probability

        Parameters:
            shape (tuple): shape of the (possibly rescaled) input, (nimg, Ly, Lx, ...).
            dP (ndarray): flow fields, indexed as dP[:, i] per plane i.
            cellprob (ndarray): per-plane cell probability, cellprob[i].

        Returns:
            numpy.ndarray: labelled masks (per-plane, or 3D after stitching).
        """
        changed_device_from = None
        if self.device.type == "mps" and do_3D:
            # MPS cannot run the 3D post-processing; temporarily fall back to CPU
            models_logger.warning("MPS does not support 3D post-processing, switching to CPU")
            self.device = torch.device("cpu")
            changed_device_from = "mps"
        Lz, Ly, Lx = shape[:3]
        tic = time.time()

        nimg = shape[0]
        Ly0, Lx0 = cellprob[0].shape
        # only ask dynamics to resize when the flows were not already resampled
        resize = None if Ly0 == Ly and Lx0 == Lx else [Ly, Lx]
        tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
        iterator = trange(nimg, file=tqdm_out,
                          mininterval=30) if nimg > 1 else range(nimg)
        for i in iterator:
            # when stitching multiple planes, keep small masks until after stitching
            min_size0 = min_size if stitch_threshold == 0 or nimg == 1 else -1
            outputs = dynamics.resize_and_compute_masks(
                dP[:, i], cellprob[i],
                niter=niter, cellprob_threshold=cellprob_threshold,
                flow_threshold=flow_threshold, resize=resize,
                min_size=min_size0, max_size_fraction=max_size_fraction,
                device=self.device)
            if i == 0 and nimg > 1:
                # allocate the output stack lazily, once the dtype is known
                masks = np.zeros((nimg, shape[1], shape[2]), outputs.dtype)
            if nimg > 1:
                masks[i] = outputs
            else:
                masks = outputs

        if stitch_threshold > 0 and nimg > 1:
            models_logger.info(
                f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks"
            )
            masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
            masks = utils.fill_holes_and_remove_small_masks(
                masks, min_size=min_size)
        elif nimg > 1:
            models_logger.warning(
                "3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only"
            )

        flow_time = time.time() - tic
        if shape[0] > 1:
            models_logger.info("masks created in %2.2fs" % (flow_time))

        if changed_device_from is not None:
            # BUGFIX: restore the device *before* logging so the message names
            # the device we switch back TO (it previously always logged "cpu").
            self.device = torch.device(changed_device_from)
            models_logger.info("switching back to device %s" % self.device)
        return masks
|
|