Spaces:
Sleeping
Sleeping
| """ | |
| Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. | |
| """ | |
| import os, time, datetime | |
| import numpy as np | |
| from scipy.stats import mode | |
| import cv2 | |
| import torch | |
| from torch import nn | |
| from torch.nn.functional import conv2d, interpolate | |
| from tqdm import trange | |
| from pathlib import Path | |
| import logging | |
| denoise_logger = logging.getLogger(__name__) | |
| from cellpose import transforms, resnet_torch, utils, io | |
| from cellpose.core import run_net | |
| from cellpose.resnet_torch import CPnet | |
| from cellpose.models import CellposeModel, model_path, normalize_default, assign_device, check_mkl | |
def _build_model_names():
    """Enumerate the names of the built-in image restoration models.

    Every segmentation model type gets one model per restoration task. All
    types except "cyto3" additionally get one variant per loss tag
    ("per", "seg", "rec") and a single "aniso_<type>" model.
    """
    names = []
    for seg_model in ("cyto3", "cyto2", "nuclei"):
        has_variants = seg_model != "cyto3"
        for task in ("denoise", "deblur", "upsample", "oneclick"):
            names.append(f"{task}_{seg_model}")
            if has_variants:
                names.extend(f"{task}_{loss_tag}_{seg_model}"
                             for loss_tag in ("per", "seg", "rec"))
        if has_variants:
            names.append(f"aniso_{seg_model}")
    return names


MODEL_NAMES = _build_model_names()

# element-wise losses used by the reconstruction and segmentation loss functions
criterion = nn.MSELoss(reduction="mean")
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
def deterministic(seed=0):
    """Seed all random number generators so that generated test data is reproducible.

    Args:
        seed (int, optional): Seed given to the Python, NumPy and PyTorch
            (CPU and CUDA) generators. Defaults to 0.
    """
    import random

    random.seed(seed)  # Python random module
    np.random.seed(seed)  # NumPy module
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers the multi-GPU case
    # disable cuDNN autotuning and request deterministic kernels
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def loss_fn_rec(lbl, y):
    """Reconstruction loss between target lbl and prediction y.

    Args:
        lbl (torch.Tensor): Target tensor.
        y (torch.Tensor): Predicted tensor, same shape as lbl.

    Returns:
        torch.Tensor: Mean squared error between y and lbl, scaled by 80.
    """
    return 80. * criterion(y, lbl)
def loss_fn_seg(lbl, y):
    """Segmentation loss between true labels lbl and prediction y.

    Args:
        lbl (torch.Tensor): Labels; channel 0 is thresholded at 0.5 to form the
            binary target, the remaining channels are the regression target.
        y (torch.Tensor): Prediction; channels 0-1 are regressed against
            5 * lbl[:, 1:], channel 2 holds the logits for the binary target.

    Returns:
        torch.Tensor: Half the mean squared error on the regression channels
        plus the binary cross-entropy (with logits) on channel 2.
    """
    regression_target = 5. * lbl[:, 1:]
    binary_target = (lbl[:, 0] > .5).float()
    regression_loss = criterion(y[:, :2], regression_target) / 2.
    binary_loss = criterion2(y[:, 2], binary_target)
    return regression_loss + binary_loss
def get_sigma(Tdown):
    """Compute channel-by-channel correlation matrices for the perceptual loss.

    Args:
        Tdown (list): List of tensors output by each downsampling block of the
            network; the last two dimensions of each are treated as spatial.

    Returns:
        list: One tensor of shape (batch, nchan, nchan) per input tensor.
    """
    Sigma = []
    for feat in Tdown:
        # z-score every channel over its spatial dimensions
        centered = feat - feat.mean((-2, -1), keepdim=True)
        zscored = centered / centered.std((-2, -1), keepdim=True)
        npix = zscored.shape[-2] * zscored.shape[-1]
        Sigma.append(torch.einsum("bnxy, bmxy -> bnm", zscored, zscored) / npix)
    return Sigma
def imstats(X, net1):
    """Compute the image correlation matrices used as perceptual-loss targets.

    Args:
        X (torch.Tensor): Input image tensor.
        net1: Cellpose net; called as net1(X) and expected to return three
            outputs, the last being the list of downsampling-block tensors.

    Returns:
        list: Correlation matrices from get_sigma, detached from the graph.
    """
    _, _, down_feats = net1(X)
    return [sigma.detach() for sigma in get_sigma(down_feats)]
def loss_fn_per(img, net1, yl):
    """Perceptual loss for image restoration.

    Args:
        img (torch.Tensor): Image passed through net1 to obtain the target
            correlation matrices (no gradient flows through them).
        net1 (torch.nn.Module): Perceptual loss net (Cellpose segmentation net).
        yl (list): Downsampling-block tensors to compare against the target;
            their correlation matrices are computed with get_sigma.

    Returns:
        torch.Tensor: Mean perceptual loss over the batch.
    """
    target_sigma = imstats(img, net1)
    test_sigma = get_sigma(yl)
    losses = torch.zeros(len(target_sigma[0]), device=img.device)
    for level, target in enumerate(target_sigma):
        # normalise each level by the variance of its target matrix entries
        scale = target.std((1, 2)) + 1e-6
        sq_err = ((test_sigma[level] - target)**2).mean((1, 2))
        losses = losses + sq_err / scale**2
    return losses.mean()
def test_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
    """Evaluate the restoration loss without gradients.

    Both networks are put in eval mode and the whole computation runs under
    torch.no_grad().

    Args:
        net0 (torch.nn.Module): The image restoration network.
        X (torch.Tensor): The input image tensor.
        net1 (torch.nn.Module, optional): The segmentation network, needed when the
            perceptual or segmentation weight is positive. Defaults to None.
        img (torch.Tensor, optional): Clean image tensor for perceptual or
            reconstruction loss. Defaults to None.
        lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for
            segmentation loss. Defaults to None.
        lam (list, optional): Weights for the (perceptual, segmentation,
            reconstruction) losses. Defaults to [1., 1.5, 0.].

    Returns:
        tuple: The total loss and the (unweighted) perceptual loss.
    """
    net0.eval()
    if net1 is not None:
        net1.eval()
    loss = torch.zeros(1, device=X.device)
    loss_per = torch.zeros(1, device=X.device)

    with torch.no_grad():
        restored = net0(X)[0]
        if lam[2] > 0.:
            loss += lam[2] * loss_fn_rec(img, restored)
        use_seg, use_per = lam[1] > 0., lam[0] > 0.
        if use_seg or use_per:
            # a single forward pass of the segmentation net serves both terms
            y, _, ydown = net1(restored)
            if use_seg:
                loss += lam[1] * loss_fn_seg(lbl, y)
            if use_per:
                loss_per = loss_fn_per(img, net1, ydown)
                loss += lam[0] * loss_per
    return loss, loss_per
def train_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
    """Compute the restoration loss for a training step.

    net0 is put in train mode, net1 (if given) in eval mode; gradients are
    tracked through both forward passes.

    Args:
        net0 (torch.nn.Module): The image restoration network.
        X (torch.Tensor): The input image tensor.
        net1 (torch.nn.Module, optional): The segmentation network, needed when the
            perceptual or segmentation weight is positive. Defaults to None.
        img (torch.Tensor, optional): Clean image tensor for perceptual or
            reconstruction loss. Defaults to None.
        lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for
            segmentation loss. Defaults to None.
        lam (list, optional): Weights for the (perceptual, segmentation,
            reconstruction) losses. Defaults to [1., 1.5, 0.].

    Returns:
        tuple: The total loss and the (unweighted) perceptual loss.
    """
    net0.train()
    if net1 is not None:
        net1.eval()
    loss = torch.zeros(1, device=X.device)
    loss_per = torch.zeros(1, device=X.device)

    restored = net0(X)[0]
    if lam[2] > 0.:
        loss += lam[2] * loss_fn_rec(img, restored)
    use_seg, use_per = lam[1] > 0., lam[0] > 0.
    if use_seg or use_per:
        # a single forward pass of the segmentation net serves both terms
        y, _, ydown = net1(restored)
        if use_seg:
            loss += lam[1] * loss_fn_seg(lbl, y)
        if use_per:
            loss_per = loss_fn_per(img, net1, ydown)
            loss += lam[0] * loss_per
    return loss, loss_per
def img_norm(imgi):
    """Rescale each image channel so its 1st percentile maps to 0 and its 99th to 1.

    Channels whose 99th - 1st percentile range is not above 1e-3 are left
    unchanged. The values are written in place, so the input tensor may be
    modified when the reshape below returns a view.

    Args:
        imgi (torch.Tensor): Input image tensor; dimensions after the first two
            are flattened for the percentile computation.

    Returns:
        torch.Tensor: Normalized image tensor with the input's shape.
    """
    orig_shape = imgi.shape
    flat = imgi.reshape(orig_shape[0], orig_shape[1], -1)
    quantiles = torch.tensor([0.01, 0.99], device=flat.device)
    # low / high each have shape (nimg, nchan, 1)
    low, high = torch.quantile(flat, quantiles, dim=-1, keepdim=True)
    for chan in range(flat.shape[1]):
        valid = (high[:, chan, 0] - low[:, chan, 0]) > 1e-3
        span = high[valid, chan] - low[valid, chan]
        flat[valid, chan] = (flat[valid, chan] - low[valid, chan]) / span
    return flat.reshape(orig_shape)
def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7,
              ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None,
              ds=None, uniform_blur=False, partial_blur=False):
    """Adds blur, downsampling and poisson noise to the input image.

    The stages run in this order: gaussian blur, downsample + bilinear
    upsample back to the input size, poisson noise, percentile
    renormalization (img_norm). Images selected for downsampling are always
    blurred.

    Args:
        lbl (torch.Tensor): The input image tensor of shape (nimg, nchan, Ly, Lx).
        alpha (float, optional): The shape parameter of the gamma distribution used for generating poisson noise. Defaults to 4.
        beta (float, optional): The rate parameter of the gamma distribution used for generating poisson noise. Defaults to 0.7.
        poisson (float, optional): The probability of adding poisson noise to the image. Defaults to 0.7.
        blur (float, optional): The probability of adding gaussian blur to the image. Defaults to 0.7.
        gblur (float, optional): The scale factor for the gaussian blur. Defaults to 1.0.
        downsample (float, optional): The probability of downsampling the image. Defaults to 0.7.
        ds_max (int, optional): The maximum downsampling factor. Defaults to 7.
        diams (torch.Tensor, optional): The diameter of the objects in the image; 30 for every image if None. Defaults to None.
        pscale (torch.Tensor, optional): The scale factor for the poisson noise, instead of sampling. Defaults to None.
        iso (bool, optional): Whether to use isotropic gaussian blur. If False, the X sigma is a tenth of
            the Y sigma and only the Y axis is downsampled. Defaults to True.
        sigma0 (torch.Tensor, optional): The standard deviation of the gaussian filter for the Y axis, instead of sampling. Defaults to None.
        sigma1 (torch.Tensor, optional): The standard deviation of the gaussian filter for the X axis, instead of sampling.
            Must be given whenever sigma0 is given. Defaults to None.
        ds (torch.Tensor, optional): The downsampling factor for each image, instead of sampling. Defaults to None.
        uniform_blur (bool, optional): If True (and iso), draw the blur sigma from a uniform distribution
            instead of a clipped exponential. Defaults to False.
        partial_blur (bool, optional): If True, blend the blurred image into the original with a horizontal
            ramp over the left 85% of the width (full blur at x=0, none at the end of the ramp). Defaults to False.

    Returns:
        torch.Tensor: The noisy image tensor of the same shape as the input image.
    """
    device = lbl.device
    imgi = torch.zeros_like(lbl)
    Ly, Lx = lbl.shape[-2:]

    diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device)
    # broadcast a user-supplied downsampling factor to one entry per image
    ds = ds * torch.ones(
        (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds

    # downsample: choose which images are downsampled (ii) and by how much (ds)
    ii = []
    idownsample = np.random.rand(len(lbl)) < downsample
    if (ds is None and idownsample.sum() > 0.) or not iso:
        ds = torch.ones(len(lbl), dtype=torch.long, device=device)
        ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),),
                                        device=device)
        ii = torch.nonzero(ds > 1).flatten()
    elif ds is not None and (ds > 1).sum():
        ii = torch.nonzero(ds > 1).flatten()

    # add gaussian blur
    iblur = torch.rand(len(lbl), device=device) < blur
    iblur[ii] = True  # downsampled images are always blurred first
    if iblur.sum() > 0:
        if sigma0 is None:
            if uniform_blur and iso:
                xr = torch.rand(len(lbl), device=device)
                if len(ii) > 0:
                    xr[ii] = ds[ii].float() / 2. / gblur
                sigma0 = diams[iblur] / 30. * gblur * (1 / gblur +
                                                       (1 - 1 / gblur) * xr[iblur])
                sigma1 = sigma0.clone()
            elif not iso:
                xr = torch.rand(len(lbl), device=device)
                if len(ii) > 0:
                    xr[ii] = (ds[ii].float()) / gblur
                    xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35
                    xr[ii] = torch.clip(xr[ii], 0.05, 1.5)
                sigma0 = diams[iblur] / 30. * gblur * xr[iblur]
                sigma1 = sigma0.clone() / 10.
            else:
                xrand = np.random.exponential(1, size=iblur.sum())
                xrand = np.clip(xrand * 0.5, 0.1, 1.0)
                xrand *= gblur
                sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to(
                    device)
                sigma1 = sigma0.clone()
        else:
            sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device)
            sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device)

        # create gaussian filter; taps cover offsets -xr+1 .. xr-1 (2*xr - 1 taps)
        xr = max(8, sigma0.max().long() * 2)
        center = xr - 1  # index of the zero-offset tap
        gfilt0 = torch.exp(-torch.arange(-xr + 1, xr, device=device)**2 /
                           (2 * sigma0.unsqueeze(-1)**2))
        gfilt0 /= gfilt0.sum(axis=-1, keepdims=True)
        gfilt1 = torch.zeros_like(gfilt0)
        gfilt1[sigma1 == sigma0] = gfilt0[sigma1 == sigma0]
        gfilt1[sigma1 != sigma0] = torch.exp(
            -torch.arange(-xr + 1, xr, device=device)**2 /
            (2 * sigma1[sigma1 != sigma0].unsqueeze(-1)**2))
        # sigma1 == 0 means "no blur along X": use an identity kernel whose single
        # non-zero tap sits on the kernel centre so the image is not shifted
        gfilt1[sigma1 == 0] = 0.
        gfilt1[sigma1 == 0, center] = 1.
        gfilt1 /= gfilt1.sum(axis=-1, keepdims=True)
        gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1)
        gfilt /= gfilt.sum(axis=(1, 2), keepdims=True)

        # one 2D kernel per blurred image, applied to all of its channels via groups
        lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1),
                          padding=gfilt.shape[-1] // 2,
                          groups=gfilt.shape[0]).transpose(1, 0)
        if partial_blur:
            imgi[iblur] = lbl[iblur].clone()
            Lxc = int(Lx * 0.85)
            # build the blend mask on the same device as the images
            ym, xm = torch.meshgrid(
                torch.zeros(Ly, dtype=torch.float32, device=device),
                torch.arange(0, Lxc, dtype=torch.float32, device=device),
                indexing="ij")
            # NOTE(review): by operator precedence this multiplies by 0.001**2 rather
            # than dividing by 2*0.001**2; kept as-is since it yields a smooth ramp
            # after the rescaling below -- confirm intent
            mask = torch.exp(-(ym**2 + xm**2) / 2 * (0.001**2))
            mask -= mask.min()
            mask /= mask.max()
            lbl_blur_crop = lbl_blur[:, :, :, :Lxc]
            imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask +
                                       (1 - mask) * imgi[iblur, :, :, :Lxc])
        else:
            imgi[iblur] = lbl_blur

    imgi[~iblur] = lbl[~iblur]

    # apply downsample: subsample by ds[k], then interpolate back to full size
    for k in ii:
        i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]]
        imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear")

    # add poisson noise
    ipoisson = np.random.rand(len(lbl)) < poisson
    if ipoisson.sum() > 0:
        if pscale is None:
            m = torch.distributions.gamma.Gamma(alpha, beta)
            pscale = torch.clamp(m.rsample(sample_shape=(ipoisson.sum(),)), 1.)
            pscale = pscale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device)
        else:
            pscale = pscale * torch.ones((ipoisson.sum(), 1, 1, 1), device=device)
        imgi[ipoisson] = torch.poisson(pscale * imgi[ipoisson])

    # renormalize
    imgi = img_norm(imgi)

    return imgi
def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7,
                                   downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30,
                                   ds_max=7, uniform_blur=False, iso=True, rotate=True,
                                   device=torch.device("cuda"), xy=(224, 224),
                                   nchan_noise=1, keep_raw=True):
    """
    Applies random rotation, resizing, and noise to the input data.

    Each image is first resized to a random object diameter on a 340x340
    canvas, noise is added to the first nchan_noise channels, and the result is
    then passed through transforms.random_rotate_and_resize.

    Args:
        data (list or numpy.ndarray): The input images, each indexed as (nchan, Ly, Lx).
        labels (numpy.ndarray, optional): The flow and cellprob labels associated with the data
            (3 channels per image). Defaults to None.
        diams (float or numpy.ndarray, optional): The diameter of the objects; 30 if None. Defaults to None.
        poisson (float, optional): The Poisson noise probability. Defaults to 0.7.
        blur (float, optional): The blur probability. Defaults to 0.7.
        downsample (float, optional): The downsample probability. Defaults to 0.0.
        beta (float, optional): The beta value for the poisson noise distribution. Defaults to 0.7.
        gblur (float, optional): The Gaussian blur level. Defaults to 1.0.
        diam_mean (float, optional): The mean diameter. Defaults to 30.
        ds_max (int, optional): The maximum downsample value. Defaults to 7.
        uniform_blur (bool, optional): Forwarded to add_noise to select uniformly sampled blur widths.
            Defaults to False.
        iso (bool, optional): Whether to apply isotropic augmentation; rotation is disabled when False.
            Defaults to True.
        rotate (bool, optional): Whether to apply rotation augmentation. Defaults to True.
        device (torch.device, optional): The device to use. If None, cuda is used when available,
            otherwise mps when available. Defaults to torch.device("cuda").
        xy (tuple, optional): The size of the output image. Defaults to (224, 224).
        nchan_noise (int, optional): The number of channels to add noise to. Defaults to 1.
        keep_raw (bool, optional): Whether to keep the raw image as extra channels after the
            noisy ones. Defaults to True.

    Returns:
        torch.Tensor: The augmented image and augmented noisy/blurry/downsampled version of image.
        torch.Tensor: The augmented labels.
        float: The scale factor applied to the image.
    """
    if device is None:
        if torch.cuda.is_available():
            device = torch.device('cuda')
        elif torch.backends.mps.is_available():
            device = torch.device('mps')

    diams = 30 if diams is None else diams
    # target diameters are drawn log-uniformly in [diam_mean / 2, 2 * diam_mean]
    random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1))
    random_rsc = diams / random_diam
    xy0 = (340, 340)
    nchan = data[0].shape[0]
    data_new = np.zeros((len(data), (1 + keep_raw) * nchan, xy0[0], xy0[1]), "float32")
    labels_new = np.zeros((len(data), 3, xy0[0], xy0[1]), "float32")
    for i in range(len(data)):
        sc = random_rsc[i]
        img = data[i]
        lbl = labels[i] if labels is not None else None

        # create affine transform to resize, with a random crop offset when the
        # rescaled image is larger than the canvas
        Ly, Lx = img.shape[-2:]
        dxy = np.maximum(0, np.array([Lx / sc - xy0[1], Ly / sc - xy0[0]]))
        dxy = (np.random.rand(2,) - .5) * dxy
        cc = np.array([Lx / 2, Ly / 2])
        cc1 = cc - np.array([Lx - xy0[1], Ly - xy0[0]]) / 2 + dxy
        pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
        pts2 = np.float32(
            [cc1, cc1 + np.array([1, 0]) / sc, cc1 + np.array([0, 1]) / sc])
        M = cv2.getAffineTransform(pts1, pts2)

        # apply to image
        for c in range(nchan):
            img_rsz = cv2.warpAffine(img[c], M, xy0, flags=cv2.INTER_LINEAR)
            data_new[i, c] = img_rsz
            if keep_raw:
                data_new[i, c + nchan] = img_rsz

        if lbl is not None:
            # apply to labels (nearest-neighbour for channel 0, linear for the rest)
            labels_new[i, 0] = cv2.warpAffine(lbl[0], M, xy0, flags=cv2.INTER_NEAREST)
            labels_new[i, 1] = cv2.warpAffine(lbl[1], M, xy0, flags=cv2.INTER_LINEAR)
            labels_new[i, 2] = cv2.warpAffine(lbl[2], M, xy0, flags=cv2.INTER_LINEAR)

    rsc = random_diam / diam_mean

    # add noise before augmentations
    img = torch.from_numpy(data_new).to(device)
    img = torch.clamp(img, 0.)
    # just add noise to cyto if nchan_noise=1
    img[:, :nchan_noise] = add_noise(
        img[:, :nchan_noise], poisson=poisson, blur=blur, ds_max=ds_max, iso=iso,
        downsample=downsample, beta=beta, gblur=gblur, uniform_blur=uniform_blur,
        diams=torch.from_numpy(random_diam).to(device).float())
    img = img.cpu().numpy()

    # augmentations
    img, lbl, scale = transforms.random_rotate_and_resize(
        img,
        Y=labels_new,
        xy=xy,
        rotate=rotate if iso else False,
        rescale=rsc,
        scale_range=0.5)
    img = torch.from_numpy(img).to(device)
    lbl = torch.from_numpy(lbl).to(device)

    return img, lbl, scale
def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
    """
    Creates a Cellpose network with a single input channel.

    Weights are copied from a checkpoint by parameter name. Entries of the
    first downsampling block ("res_down_0") whose shape has a size-2 leading
    dimension or size-2 second dimension are reduced to their first slice, so
    that they fit the single-channel network.

    Args:
        device (str or torch.device): The device to run the network on.
        model_type (str, optional): The type of Cellpose model to use. Defaults to "cyto2".
        pretrained_model (str, optional): The path to a pretrained model file. If the path does
            not exist, the value is used as model_type instead. Defaults to None.

    Returns:
        torch.nn.Module: The Cellpose network with a single input channel.
    """
    if pretrained_model is not None and not os.path.exists(pretrained_model):
        model_type = pretrained_model
        pretrained_model = None
    nbase = [32, 64, 128, 256]
    nchan = 1
    net1 = resnet_torch.CPnet([nchan, *nbase], nout=3, sz=3).to(device)
    filename = model_path(model_type,
                          0) if pretrained_model is None else pretrained_model
    # map_location lets checkpoints saved on another device load on this host
    weights = torch.load(filename, weights_only=True, map_location=device)
    zp = 0
    print(filename)
    # state_dict tensors share storage with the parameters, so copy_ updates the net
    state = net1.state_dict()
    for name, param in state.items():
        if ("res_down_0.conv.conv_0" not in name and
                "res_down_0.proj" not in name and name != "diam_mean" and
                name != "diam_labels"):
            param.copy_(weights[name])
        elif "res_down_0" in name:
            src = weights[name]
            if len(src.shape) > 0:
                new_weight = torch.zeros_like(param)
                if src.shape[0] == 2:
                    new_weight[:] = src[0]
                elif len(src.shape) > 1 and src.shape[1] == 2:
                    new_weight[:, zp] = src[:, 0]
                else:
                    new_weight = src
            else:
                new_weight = src
            param.copy_(new_weight)
    return net1
class CellposeDenoiseModel():
    """ model to run Cellpose and Image restoration

    Holds a DenoiseModel (self.dn) for restoration and a CellposeModel
    (self.cp) for segmentation of the restored images.

    Args:
        gpu (bool, optional): Passed to both models. Defaults to False.
        pretrained_model (bool or str, optional): Passed to the CellposeModel. Defaults to False.
        model_type (str, optional): Passed to the CellposeModel. Defaults to None.
        restore_type (str, optional): Passed to the DenoiseModel as model_type. Defaults to "denoise_cyto3".
        nchan (int, optional): Passed to the CellposeModel. Defaults to 2.
        chan2_restore (bool, optional): Passed to the DenoiseModel as chan2. Defaults to False.
        device (torch.device, optional): Passed to both models. Defaults to None.
    """

    def __init__(self, gpu=False, pretrained_model=False, model_type=None,
                 restore_type="denoise_cyto3", nchan=2,
                 chan2_restore=False, device=None):
        self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore,
                               device=device)
        self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan,
                                pretrained_model=pretrained_model, device=device)

    def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
             normalize=True, rescale=None, diameter=None, tile_overlap=0.1,
             augment=False, resample=True, invert=False, flow_threshold=0.4,
             cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
             min_size=15, niter=None, interp=True, bsize=224, flow3D_smooth=0):
        """
        Restore array or list of images using the image restoration model, and then segment.

        Args:
            x (list, np.ndarray): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
            batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
                (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
            channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
                First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
                Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
                For instance, to segment grayscale images, input [0,0]. To segment images with cells
                in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
                image with cells in green and nuclei in blue, input [[0,0], [2,3]].
                Defaults to None.
            channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
                if None, channels dimension is attempted to be automatically determined. Defaults to None.
            z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
                if None, z dimension is attempted to be automatically determined. Defaults to None.
            normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
                can also pass dictionary of parameters (all keys are optional, default values shown):
                    - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
                    - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
                    - "normalize"=True ; run normalization (if False, all following parameters ignored)
                    - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
                    - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
                    - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
                Defaults to True.
            rescale (float, optional): resize factor for each image, if None, set to 1.0;
                (only used if diameter is None). Defaults to None.
            diameter (float, optional): diameter for each image,
                if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
            tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
            augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
            resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
            invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
            flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
            cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
            do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False.
            anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
            stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
            min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
            flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0.
                NOTE(review): accepted here but not forwarded to the segmentation call below -- confirm.
            niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None.
            interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True.
            bsize (int, optional): passed to both the restoration and the segmentation eval. Defaults to 224.

        Returns:
            A tuple containing (masks, flows, styles, imgs); masks: labelled image(s), where 0=no masks; 1,2,...=mask labels;
            flows: list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration;
            styles: style vector summarizing each image of size 256;
            imgs: Restored images.

        Raises:
            ValueError: if normalize is neither a bool nor a dict.
        """
        # Build per-call parameter dicts; the shared module-level normalize_default
        # must never be modified, otherwise settings leak into later calls.
        if isinstance(normalize, dict):
            normalize_params = {**normalize_default, **normalize}
            # NOTE(review): SOURCE indentation was lost; invert is applied on both
            # the dict and the bool path here -- confirm against upstream
            normalize_params["invert"] = invert
            # user-supplied (special) normalization applies to restoration only
            seg_normalize_params = dict(normalize_default)
        elif not isinstance(normalize, bool):
            raise ValueError("normalize parameter must be a bool or a dict")
        else:
            normalize_params = {**normalize_default, "normalize": normalize,
                                "invert": invert}
            # without special options, segmentation uses the same settings
            seg_normalize_params = dict(normalize_params)

        img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels,
                                   channel_axis=channel_axis, z_axis=z_axis,
                                   do_3D=do_3D,
                                   normalize=normalize_params, rescale=rescale,
                                   diameter=diameter,
                                   tile_overlap=tile_overlap, bsize=bsize)

        # change channels for segmentation
        if channels is not None:
            channels_new = [0, 0] if channels[0] == 0 else [1, 2]
        else:
            channels_new = None
        # change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean)
        diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
        # restored arrays with more than 3 dims are treated as stacks with z first
        is_stack = (not isinstance(img_restore, list) and img_restore.ndim > 3 and
                    img_restore.shape[0] > 0)
        masks, flows, styles = self.cp.eval(
            img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
            z_axis=0 if is_stack else None,
            normalize=seg_normalize_params, rescale=rescale, diameter=diameter,
            tile_overlap=tile_overlap, augment=augment, resample=resample,
            invert=invert, flow_threshold=flow_threshold,
            cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy,
            stitch_threshold=stitch_threshold, min_size=min_size, niter=niter,
            interp=interp, bsize=bsize)

        return masks, flows, styles, img_restore
| class DenoiseModel(): | |
| """ | |
| DenoiseModel class for denoising images using Cellpose denoising model. | |
| Args: | |
| gpu (bool, optional): Whether to use GPU for computation. Defaults to False. | |
| pretrained_model (bool or str or Path, optional): Pretrained model to use for denoising. | |
| Can be a string or path. Defaults to False. | |
| nchan (int, optional): Number of channels in the input images, all Cellpose 3 models were trained with nchan=1. Defaults to 1. | |
| model_type (str, optional): Type of pretrained model to use ("denoise_cyto3", "deblur_cyto3", "upsample_cyto3", ...). Defaults to None. | |
| chan2 (bool, optional): Whether to use a separate model for the second channel. Defaults to False. | |
| diam_mean (float, optional): Mean diameter of the objects in the images. Defaults to 30.0. | |
| device (torch.device, optional): Device to use for computation. Defaults to None. | |
| Attributes: | |
| nchan (int): Number of channels in the input images. | |
| diam_mean (float): Mean diameter of the objects in the images. | |
| net (CPnet): Cellpose network for denoising. | |
| pretrained_model (bool or str or Path): Pretrained model path to use for denoising. | |
| net_chan2 (CPnet or None): Cellpose network for the second channel, if applicable. | |
| net_type (str): Type of the denoising network. | |
| Methods: | |
| eval(x, batch_size=8, channels=None, channel_axis=None, z_axis=None, | |
| normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1) | |
| Denoise array or list of images using the denoising model. | |
| _eval(net, x, normalize=True, rescale=None, diameter=None, tile=True, | |
| tile_overlap=0.1) | |
| Run denoising model on a single channel. | |
| """ | |
def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
             chan2=False, diam_mean=30., device=None):
    """Resolve the restoration weights, select the device and build the network(s).

    A built-in model is used when model_type is given or when
    pretrained_model does not exist on disk; names not in MODEL_NAMES fall
    back to "denoise_cyto3". Otherwise pretrained_model is loaded as a
    user-supplied checkpoint path.

    Raises:
        ValueError: if pretrained_model is truthy but neither a str nor a Path.
    """
    self.nchan = nchan
    if pretrained_model and not isinstance(pretrained_model, (str, Path)):
        raise ValueError("pretrained_model must be a string or path")
    self.diam_mean = diam_mean

    builtin = True
    wants_builtin = model_type is not None or (pretrained_model and
                                               not os.path.exists(pretrained_model))
    if wants_builtin:
        pretrained_model_string = "denoise_cyto3" if model_type is None else model_type
        if pretrained_model_string not in MODEL_NAMES:
            pretrained_model_string = "denoise_cyto3"
        pretrained_model = model_path(pretrained_model_string)
        if pretrained_model and not os.path.exists(pretrained_model):
            denoise_logger.warning("pretrained model has incorrect path")
        denoise_logger.info(f">> {pretrained_model_string} << model set to be used")
        self.diam_mean = 17. if "nuclei" in pretrained_model_string else 30.
    elif pretrained_model:
        builtin = False
        pretrained_model_string = pretrained_model
        denoise_logger.info(f">>>> loading model {pretrained_model_string}")

    # assign network device: an explicit device wins over the gpu flag
    self.mkldnn = None
    if device is None:
        self.device, self.gpu = assign_device(use_torch=True, gpu=gpu)
    else:
        self.device = device
        self.gpu = self.device.type == "cuda"
    if not self.gpu:
        self.mkldnn = check_mkl(True)

    # create network
    self.nclasses = 1
    self.nbase = [nchan, 32, 64, 128, 256]
    # NOTE(review): the net receives the diam_mean argument, not self.diam_mean
    # (which may have been reset above for built-in models) -- confirm
    self.net = CPnet(self.nbase, self.nclasses, sz=3, mkldnn=self.mkldnn,
                     max_pool=True, diam_mean=diam_mean).to(self.device)

    self.pretrained_model = pretrained_model
    self.net_chan2 = None
    if self.pretrained_model:
        self.net.load_model(self.pretrained_model, device=self.device)
        denoise_logger.info(
            f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)"
        )
        if chan2 and builtin:
            # second-channel model: same task prefix, "_nuclei" weights
            task_prefix = os.path.split(self.pretrained_model)[-1].split("_")[0]
            chan2_path = model_path(task_prefix + "_nuclei")
            print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
            self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
                                   mkldnn=self.mkldnn, max_pool=True,
                                   diam_mean=17.).to(self.device)
            self.net_chan2.load_model(chan2_path, device=self.device)
    self.net_type = "cellpose_denoise"
def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
         normalize=True, rescale=None, diameter=None, tile=True, do_3D=False,
         tile_overlap=0.1, bsize=224):
    """
    Restore array or list of images using the image restoration model.

    Args:
        x (list, np.ndarray): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
        batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
            (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
        channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
            First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
            Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
            For instance, to segment grayscale images, input [0,0]. To segment images with cells
            in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
            image with cells in green and nuclei in blue, input [[0,0], [2,3]].
            Defaults to None.
        channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
            if None, channels dimension is attempted to be automatically determined. Defaults to None.
        z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
            if None, z dimension is attempted to be automatically determined. Defaults to None.
        normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
            can also pass dictionary of parameters (all keys are optional, default values shown):
                - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
                - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
                - "normalize"=True ; run normalization (if False, all following parameters ignored)
                - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
                - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
                - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
            Defaults to True.
        rescale (float, optional): resize factor for each image, if None, set to 1.0;
            (only used if diameter is None). Defaults to None.
        diameter (float, optional): diameter for each image,
            if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
        tile (bool, optional): not used by this method itself; it is only passed along
            when the method recurses over a list / stack of images. Defaults to True.
        do_3D (bool, optional): forwarded to transforms.convert_image. Defaults to False.
        tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
        bsize (int, optional): forwarded to the network runner (presumably the tile
            edge length in pixels -- confirm against run_net). Defaults to 224.

    Returns:
        list or np.ndarray: restored image(s). A list input yields a list with one
            restored image per entry; an array holding several images yields a stacked
            array; a single image yields a single restored array.
    """
    if isinstance(x, list) or x.squeeze().ndim == 5:
        # several images: restore them one at a time by recursing on each entry
        tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO)
        nimg = len(x)
        iterator = trange(nimg, file=tqdm_out,
                          mininterval=30) if nimg > 1 else range(nimg)
        sequence_types = (list, np.ndarray)
        imgs = []
        for i in iterator:
            # `channels` is either one [chan, chan2] pair shared by every image,
            # or one such pair per image
            if (channels is not None and len(channels) == len(x) and
                    isinstance(channels[i], sequence_types) and
                    len(channels[i]) == 2):
                channels_i = channels[i]
            else:
                channels_i = channels
            rescale_i = rescale[i] if isinstance(rescale, sequence_types) else rescale
            diameter_i = diameter[i] if isinstance(diameter,
                                                   sequence_types) else diameter
            imgs.append(
                self.eval(x[i], batch_size=batch_size, channels=channels_i,
                          channel_axis=channel_axis, z_axis=z_axis,
                          normalize=normalize, rescale=rescale_i,
                          diameter=diameter_i, tile=tile, do_3D=do_3D,
                          tile_overlap=tile_overlap, bsize=bsize))
        if isinstance(x, np.ndarray):
            imgs = np.array(imgs)
        return imgs

    # single image: reshape so that channels are on the last axis
    x = transforms.convert_image(x, channels, channel_axis=channel_axis,
                                 z_axis=z_axis, do_3D=do_3D, nchan=None)
    squeeze = x.ndim < 4
    if squeeze:
        # add a leading axis so the code below always sees a 4D array
        x = x[np.newaxis, ...]

    # may need to interpolate image before running upsampling
    self.ratio = 1.
    # pretrained_model may be None / False / a path object, so test its string form
    if self.pretrained_model and "upsample" in str(self.pretrained_model):
        Ly, Lx = x.shape[-3:-1]
        if diameter is not None and 3 <= diameter < self.diam_mean:
            self.ratio = self.diam_mean / diameter
            denoise_logger.info(
                f"upsampling image to {self.diam_mean} pixel diameter ({self.ratio:0.2f} times)"
            )
            Lyr, Lxr = int(Ly * self.ratio), int(Lx * self.ratio)
            x = transforms.resize_image(x, Ly=Lyr, Lx=Lxr)
        else:
            denoise_logger.warning(
                f"not interpolating image before upsampling because diameter is set >= {self.diam_mean}"
            )

    self.batch_size = batch_size

    # a positive diameter takes precedence over an explicit rescale factor
    if diameter is not None and diameter > 0:
        rescale = self.diam_mean / diameter
    elif rescale is None:
        rescale = 1.0

    # drop the last channel when it is (nearly) constant or was not requested
    if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0):
        x = x[..., :1]

    for c in range(x.shape[-1]):
        # the second channel is rescaled by an extra 30/17 (the two diam_mean
        # values used for non-nuclei / nuclei models in this class)
        rescale0 = rescale * 30. / 17. if c == 1 else rescale
        # the dedicated second-channel network is used for c > 0 when it exists
        net = self.net if (c == 0 or self.net_chan2 is None) else self.net_chan2
        x[..., c] = self._eval(net, x[..., c:c + 1], batch_size=batch_size,
                               normalize=normalize, rescale=rescale0,
                               tile_overlap=tile_overlap, bsize=bsize)[..., 0]
    x = x[0] if squeeze else x
    return x
def _eval(self, net, x, batch_size=8, normalize=True, rescale=None,
          tile_overlap=0.1, bsize=224):
    """
    Run image restoration model on a single channel.

    Args:
        net: network to run on the images (passed to run_net).
        x (np.ndarray): images for one channel; axis 0 indexes images and axes 1 and 2
            are used as the output height and width.
        batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
            (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
        normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
            can also pass dictionary of parameters (all keys are optional, default values shown):
                - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
                - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
                - "normalize"=True ; run normalization (if False, all following parameters ignored)
                - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
                - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
                - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
            Defaults to True.
        rescale (float, optional): resize factor applied before running the network;
            None is treated as 1.0 (no resizing). Defaults to None.
        tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
        bsize (int, optional): forwarded to run_net (presumably the tile edge length
            in pixels -- confirm against run_net). Defaults to 224.

    Returns:
        np.ndarray: network output resized back to the height and width of x.

    Raises:
        ValueError: if normalize is neither a bool nor a dict.
    """
    if isinstance(normalize, dict):
        normalize_params = {**normalize_default, **normalize}
    elif not isinstance(normalize, bool):
        raise ValueError("normalize parameter must be a bool or a dict")
    else:
        # build a fresh dict: normalize_default is shared module-level state and
        # must not be modified in place
        normalize_params = {**normalize_default, "normalize": normalize}

    tic = time.time()
    shape = x.shape
    nimg = shape[0]

    img = np.asarray(x)
    if normalize_params["normalize"]:
        img = transforms.normalize_img(img, **normalize_params)
    if rescale is not None and rescale != 1.0:
        img = transforms.resize_image(img, rsz=rescale)

    # run the network that was passed in (not unconditionally self.net), so that
    # a dedicated second-channel network is actually used when the caller asks for it
    yf, style = run_net(net, img, batch_size=batch_size, bsize=bsize,
                        tile_overlap=tile_overlap)
    del style  # only the restored image is needed

    # bring the output back to the input height / width
    imgs = transforms.resize_image(yf, shape[1], shape[2])

    net_time = time.time() - tic
    if nimg > 1:
        denoise_logger.info("imgs denoised in %2.2fs" % (net_time))
    return imgs
def train(net, train_data=None, train_labels=None, train_files=None, test_data=None,
          test_labels=None, test_files=None, train_probs=None, test_probs=None,
          lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None,
          save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0,
          iso=True, uniform_blur=False, downsample=0., ds_max=7,
          learning_rate=0.005, n_epochs=500,
          weight_decay=0.00001, batch_size=8, nimg_per_epoch=None,
          nimg_test_per_epoch=None, model_name=None):
    """
    Train an image restoration network on synthetically degraded images.

    Images are degraded with random_rotate_and_resize_noise and the network is
    optimized with AdamW on the loss returned by train_loss, which is also given a
    one-channel segmentation network (one_chan_cellpose) and the weights in lam.

    Args:
        net: network to train; must expose device, nchan, diam_mean, parameters()
            and save_model().
        train_data (list, optional): training images held in memory. If None,
            train_files is used instead.
        train_labels (list, optional): labels matching train_data; element 0 onwards
            is used for diameters and elements [1:] as training targets.
        train_files (list, optional): image files, read with io.imread; labels are
            read from the same path with the last 4 characters replaced by "_flows.tif".
        test_data, test_labels, test_files: same as above for the test set (optional).
        train_probs, test_probs (np.ndarray, optional): sampling probability of each
            image; uniform when None.
        lam (list, optional): loss weights, in the order perceptual ("per"),
            segmentation ("seg"), reconstruction ("rec"). Not modified.
        scale_range (float, optional): not used by this function.
        seg_model_type (str, optional): pretrained model given to one_chan_cellpose.
        save_path (str, optional): folder to save the network in; nothing is saved
            when None.
        save_every (int, optional): save every this many epochs (and at the last epoch).
        save_each (bool, optional): if True, each save goes to its own file with an
            "_epoch_NNNN" suffix.
        poisson, beta, blur, gblur, downsample (float or list): noise settings passed
            to random_rotate_and_resize_noise. Scalars define one noise condition;
            equal-length lists define several, one per index.
        iso, uniform_blur, ds_max: passed to random_rotate_and_resize_noise.
        learning_rate (float, optional): peak learning rate. The schedule is a linear
            warm-up over 10 epochs, a constant phase, then ten halvings of 10 epochs each.
        n_epochs (int, optional): number of epochs.
        weight_decay (float, optional): AdamW weight decay.
        batch_size (int, optional): number of images per optimizer step.
        nimg_per_epoch (int, optional): images sampled per epoch; all when None.
        nimg_test_per_epoch (int, optional): test images sampled per evaluation;
            all when None.
        model_name (str, optional): file name for the saved network; generated from
            the settings and the current time when None.

    Returns:
        tuple: (filename, train_losses, test_losses). filename is save_path joined
            with the model name, or None when save_path is None; the loss lists hold
            one entry per reported epoch.
    """
    # net properties
    device = net.device
    nchan = net.nchan
    diam_mean = net.diam_mean.item()

    # stack the noise settings so that each row holds one setting and each column
    # one noise condition; scalars become a single condition
    args = np.array([poisson, beta, blur, gblur, downsample])
    if args.ndim == 1:
        args = args[:, np.newaxis]
    poisson, beta, blur, gblur, downsample = args
    nnoise = len(poisson)

    # returned as-is (None) when no save_path is given
    filename = save_path
    if save_path is not None:
        if model_name is None:
            d = datetime.datetime.now()
            filename = ""
            lstrs = ["per", "seg", "rec"]
            for l, s in zip(lam, lstrs):
                filename += f"{s}_{l:.2f}_"
            if not iso:
                filename += "aniso_"
            if poisson.sum() > 0:
                filename += "poisson_"
            if blur.sum() > 0:
                filename += "blur_"
            if downsample.sum() > 0:
                filename += "downsample_"
            filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f")
            filename = os.path.join(save_path, filename)
        else:
            filename = os.path.join(save_path, model_name)
        print(filename)

    for i in range(nnoise):
        denoise_logger.info(
            f"poisson: {poisson[i]: 0.2f}, beta: {beta[i]: 0.2f}, blur: {blur[i]: 0.2f}, gblur: {gblur[i]: 0.2f}, downsample: {downsample[i]: 0.2f}"
        )

    # segmentation network handed to the loss functions
    net1 = one_chan_cellpose(device=device, pretrained_model=seg_model_type)

    # learning rate schedule: 10 warm-up epochs, constant phase, 10 x 10 halving
    # epochs; max(0, ...) keeps short runs (n_epochs < 100) from requesting a
    # negative-length array, and the schedule is always at least n_epochs long
    learning_rate_const = learning_rate
    LR = np.linspace(0, learning_rate_const, 10)
    LR = np.append(LR, learning_rate_const * np.ones(max(0, n_epochs - 100)))
    for i in range(10):
        LR = np.append(LR, LR[-1] / 2 * np.ones(10))
    learning_rate = LR

    optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate[0],
                                  weight_decay=weight_decay)

    # per-image diameters (floored at 5) for training and, if present, test images
    if train_data is not None:
        nimg = len(train_data)
        diam_train = np.array(
            [utils.diameters(train_labels[k])[0] for k in trange(len(train_labels))])
        diam_train[diam_train < 5] = 5.
        if test_data is not None:
            diam_test = np.array(
                [utils.diameters(test_labels[k])[0] for k in trange(len(test_labels))])
            diam_test[diam_test < 5] = 5.
            nimg_test = len(test_data)
    else:
        nimg = len(train_files)
        denoise_logger.info(">>> using files instead of loading dataset")
        train_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in train_files]
        denoise_logger.info(">>> computing diameters")
        diam_train = np.array([
            utils.diameters(io.imread(train_labels_files[k])[0])[0]
            for k in trange(len(train_labels_files))
        ])
        diam_train[diam_train < 5] = 5.
        if test_files is not None:
            nimg_test = len(test_files)
            test_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in test_files]
            diam_test = np.array([
                utils.diameters(io.imread(test_labels_files[k])[0])[0]
                for k in trange(len(test_labels_files))
            ])
            diam_test[diam_test < 5] = 5.

    has_test = test_files is not None or test_data is not None

    train_probs = 1. / nimg * np.ones(nimg,
                                      "float64") if train_probs is None else train_probs
    if has_test:
        test_probs = 1. / nimg_test * np.ones(
            nimg_test, "float64") if test_probs is None else test_probs

    tic = time.time()

    nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
    if has_test:
        nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch

    train_losses, test_losses = [], []
    for iepoch in range(n_epochs):
        # seed per epoch so that the sampling is reproducible; numpy is re-seeded
        # after the draw so augmentation randomness does not depend on the draw
        np.random.seed(iepoch)
        rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
                                 p=train_probs)
        torch.manual_seed(iepoch)
        np.random.seed(iepoch)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate[iepoch]
        lavg, lavg_per, nsum = 0, 0, 0
        # each outer step loads batch_size images for every noise condition
        for ibatch in range(0, nimg_per_epoch, batch_size * nnoise):
            inds = rperm[ibatch:ibatch + batch_size * nnoise]
            if train_data is None:
                imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds]
                lbls = [io.imread(train_labels_files[i])[1:] for i in inds]
            else:
                imgs = [train_data[i][:nchan] for i in inds]
                lbls = [train_labels[i][1:] for i in inds]

            # degrade consecutive chunks of the batch with the noise conditions
            # taken in random order
            rnoise = np.random.permutation(nnoise)
            for i, inoise in enumerate(rnoise):
                if i * batch_size < len(imgs):
                    imgi, lbli, scale = random_rotate_and_resize_noise(
                        imgs[i * batch_size:(i + 1) * batch_size],
                        lbls[i * batch_size:(i + 1) * batch_size],
                        diam_train[inds][i * batch_size:(i + 1) * batch_size].copy(),
                        poisson=poisson[inoise],
                        beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso,
                        downsample=downsample[inoise], uniform_blur=uniform_blur,
                        diam_mean=diam_mean, ds_max=ds_max,
                        device=device)
                    if i == 0:
                        img = imgi
                        lbl = lbli
                    else:
                        img = torch.cat((img, imgi), axis=0)
                        lbl = torch.cat((lbl, lbli), axis=0)

            # shuffle so that each optimizer step mixes the noise conditions
            if nnoise > 0:
                iperm = np.random.permutation(img.shape[0])
                img, lbl = img[iperm], lbl[iperm]

            for i in range(nnoise):
                optimizer.zero_grad()
                imgi = img[i * batch_size:(i + 1) * batch_size]
                lbli = lbl[i * batch_size:(i + 1) * batch_size]
                if imgi.shape[0] > 0:
                    # the first nchan channels are the network input; the remaining
                    # channels are passed to the loss as `img`
                    loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1,
                                                img=imgi[:, nchan:], lbl=lbli, lam=lam)
                    loss.backward()
                    optimizer.step()
                    lavg += loss.item() * imgi.shape[0]
                    lavg_per += loss_per.item() * imgi.shape[0]

            nsum += len(img)

        # report on the first 10 epochs and on every 5th epoch afterwards
        if iepoch % 5 == 0 or iepoch < 10:
            lavg = lavg / nsum
            lavg_per = lavg_per / nsum
            if has_test:
                lavgt, nsum = 0., 0
                # fixed seed: the same test images are drawn at every evaluation
                np.random.seed(42)
                rperm = np.random.choice(np.arange(0, nimg_test),
                                         size=(nimg_test_per_epoch,), p=test_probs)
                # cycle through the noise conditions across epochs
                inoise = iepoch % nnoise
                torch.manual_seed(inoise)
                for ibatch in range(0, nimg_test_per_epoch, batch_size):
                    inds = rperm[ibatch:ibatch + batch_size]
                    if test_data is None:
                        imgs = [
                            np.maximum(0,
                                       io.imread(test_files[i])[:nchan]) for i in inds
                        ]
                        lbls = [io.imread(test_labels_files[i])[1:] for i in inds]
                    else:
                        imgs = [test_data[i][:nchan] for i in inds]
                        lbls = [test_labels[i][1:] for i in inds]
                    img, lbl, scale = random_rotate_and_resize_noise(
                        imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise],
                        beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise],
                        iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur,
                        diam_mean=diam_mean, ds_max=ds_max, device=device)
                    loss, loss_per = test_loss(net, img[:, :nchan], net1=net1,
                                               img=img[:, nchan:], lbl=lbl, lam=lam)

                    lavgt += loss.item() * img.shape[0]
                    nsum += len(img)
                lavgt = lavgt / nsum
                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f"
                    % (iepoch, time.time() - tic, lavg, lavg_per, lavgt,
                       learning_rate[iepoch]))
                test_losses.append(lavgt)
            else:
                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch]))
            train_losses.append(lavg)

        if save_path is not None:
            if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
                if save_each:  # separate files as model progresses
                    filename0 = str(filename) + f"_epoch_{iepoch:04d}"
                else:
                    filename0 = filename
                denoise_logger.info(f"saving network parameters to {filename0}")
                net.save_model(filename0)

    return filename, train_losses, test_losses
if __name__ == "__main__":
    import argparse

    # ---- command-line interface ------------------------------------------
    parser = argparse.ArgumentParser(description="cellpose parameters")

    input_img_args = parser.add_argument_group("input image arguments")
    input_img_args.add_argument("--dir", default=[], type=str,
                                help="folder containing data to run or train on.")
    input_img_args.add_argument("--img_filter", default=[], type=str,
                                help="end string for images to run on")

    model_args = parser.add_argument_group("model arguments")
    model_args.add_argument("--pretrained_model", default=[], type=str,
                            help="pretrained denoising model")

    training_args = parser.add_argument_group("training arguments")
    training_args.add_argument("--test_dir", default=[], type=str,
                               help="folder containing test data (optional)")
    training_args.add_argument("--file_list", default=[], type=str,
                               help="npy file containing list of train and test files")
    training_args.add_argument("--seg_model_type", default="cyto2", type=str,
                               help="model to use for seg training loss")
    training_args.add_argument(
        "--noise_type", default=[], type=str,
        help="noise type to use (if input, then other noise params are ignored)")
    training_args.add_argument("--poisson", default=0.8, type=float,
                               help="fraction of images to add poisson noise to")
    training_args.add_argument("--beta", default=0.7, type=float,
                               help="scale of poisson noise")
    training_args.add_argument("--blur", default=0., type=float,
                               help="fraction of images to blur")
    training_args.add_argument("--gblur", default=1.0, type=float,
                               help="scale of gaussian blurring stddev")
    training_args.add_argument("--downsample", default=0., type=float,
                               help="fraction of images to downsample")
    training_args.add_argument("--ds_max", default=7, type=int,
                               help="max downsampling factor")
    training_args.add_argument("--lam_per", default=1.0, type=float,
                               help="weighting of perceptual loss")
    training_args.add_argument("--lam_seg", default=1.5, type=float,
                               help="weighting of segmentation loss")
    training_args.add_argument("--lam_rec", default=0., type=float,
                               help="weighting of reconstruction loss")
    training_args.add_argument(
        "--diam_mean", default=30., type=float, help=
        "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0"
    )
    training_args.add_argument("--learning_rate", default=0.001, type=float,
                               help="learning rate. Default: %(default)s")
    training_args.add_argument("--n_epochs", default=2000, type=int,
                               help="number of epochs. Default: %(default)s")
    training_args.add_argument(
        "--save_each", default=False, action="store_true",
        help="save each epoch as separate model")
    training_args.add_argument(
        "--nimg_per_epoch", default=0, type=int,
        help="number of images per epoch. Default is length of training images")
    training_args.add_argument(
        "--nimg_test_per_epoch", default=0, type=int,
        help="number of test images per epoch. Default is length of testing images")

    io.logger_setup()

    args = parser.parse_args()
    lams = [args.lam_per, args.lam_seg, args.lam_rec]
    print("lam", lams)

    # ---- noise settings ----------------------------------------------------
    # defaults shared by both paths below; they must exist even when no
    # --noise_type preset is selected because train() is always given them
    uniform_blur = False
    iso = True
    if len(args.noise_type) > 0:
        # a named preset overrides the individual noise arguments
        noise_type = args.noise_type
        if noise_type == "poisson":
            poisson = 0.8
            blur = 0.
            downsample = 0.
            beta = 0.7
            gblur = 1.0
        elif noise_type == "blur_expr":
            poisson = 0.8
            blur = 0.8
            downsample = 0.
            beta = 0.1
            gblur = 0.5
        elif noise_type == "blur":
            poisson = 0.8
            blur = 0.8
            downsample = 0.
            beta = 0.1
            gblur = 10.0
            uniform_blur = True
        elif noise_type == "downsample_expr":
            poisson = 0.8
            blur = 0.8
            downsample = 0.8
            beta = 0.03
            gblur = 1.0
        elif noise_type == "downsample":
            poisson = 0.8
            blur = 0.8
            downsample = 0.8
            beta = 0.03
            gblur = 5.0
            uniform_blur = True
        elif noise_type == "all":
            # three noise conditions trained together
            poisson = [0.8, 0.8, 0.8]
            blur = [0., 0.8, 0.8]
            downsample = [0., 0., 0.8]
            beta = [0.7, 0.1, 0.03]
            gblur = [0., 10.0, 5.0]
            uniform_blur = True
        elif noise_type == "aniso":
            poisson = 0.8
            blur = 0.8
            downsample = 0.8
            beta = 0.1
            gblur = args.ds_max * 1.5
            iso = False
        else:
            raise ValueError(f"{noise_type} noise_type is not supported")
    else:
        poisson, beta = args.poisson, args.beta
        blur, gblur = args.blur, args.gblur
        downsample = args.downsample

    # ---- model ---------------------------------------------------------------
    pretrained_model = None if len(
        args.pretrained_model) == 0 else args.pretrained_model
    model = DenoiseModel(gpu=True, nchan=1, diam_mean=args.diam_mean,
                         pretrained_model=pretrained_model)

    # ---- data ----------------------------------------------------------------
    train_data, labels, train_files, train_probs = None, None, None, None
    test_data, test_labels, test_files, test_probs = None, None, None, None
    if len(args.file_list) == 0:
        # load every image into memory, keep one plane and normalize it
        output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0)
        images, labels, image_names, test_images, test_labels, image_names_test = output
        train_data = []
        for i in range(len(images)):
            img = images[i].astype("float32")
            if img.ndim > 2:
                img = img[0]
            train_data.append(
                np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
        if len(args.test_dir) > 0:
            test_data = []
            for i in range(len(test_images)):
                img = test_images[i].astype("float32")
                if img.ndim > 2:
                    img = img[0]
                test_data.append(
                    np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
        save_path = os.path.join(args.dir, "../models/")
    else:
        root = args.dir
        denoise_logger.info(
            ">>> using file_list (assumes images are normalized and have flows!)")
        dat = np.load(args.file_list, allow_pickle=True).item()
        train_files = dat["train_files"]
        test_files = dat["test_files"]
        train_probs = dat["train_probs"] if "train_probs" in dat else None
        test_probs = dat["test_probs"] if "test_probs" in dat else None
        # re-root the stored paths under --dir when they point somewhere else,
        # keeping only their last three path components
        if str(train_files[0])[:len(str(root))] != str(root):
            for i in range(len(train_files)):
                new_path = root / Path(*train_files[i].parts[-3:])
                if i == 0:
                    print(f"changing path from {train_files[i]} to {new_path}")
                train_files[i] = new_path

            for i in range(len(test_files)):
                new_path = root / Path(*test_files[i].parts[-3:])
                test_files[i] = new_path
        save_path = os.path.join(args.dir, "models/")

    os.makedirs(save_path, exist_ok=True)

    # 0 on the command line means "use all images"
    nimg_per_epoch = None if args.nimg_per_epoch == 0 else args.nimg_per_epoch
    nimg_test_per_epoch = None if args.nimg_test_per_epoch == 0 else args.nimg_test_per_epoch

    # ---- training ------------------------------------------------------------
    # unpack into fresh names so that the imported `model_path` function is not
    # overwritten by the training result
    saved_model_file, train_losses, test_losses = train(
        model.net, train_data=train_data, train_labels=labels, train_files=train_files,
        test_data=test_data, test_labels=test_labels, test_files=test_files,
        train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta,
        blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max,
        iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs,
        learning_rate=args.learning_rate,
        lam=lams,
        seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch,
        nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path)
def seg_train_noisy(model, train_data, train_labels, test_data=None, test_labels=None,
                    poisson=0.8, blur=0.0, downsample=0.0, save_path=None,
                    save_every=100, save_each=False, learning_rate=0.2, n_epochs=500,
                    momentum=0.9, weight_decay=0.00001, SGD=True, batch_size=8,
                    nimg_per_epoch=None, diameter=None, rescale=True, z_masking=False,
                    model_name=None):
    """ train function uses loss function model.loss_fn in models.py

    (data should already be normalized)

    Trains model on degraded versions of train_data: every batch is passed through
    random_rotate_and_resize_noise, only its first output channel is kept, and the
    result is given to model._train_step together with the labels.

    Args:
        model: model object providing _set_optimizer, _set_criterion,
            _set_learning_rate, _train_step, _test_eval, diam_mean, device,
            net_type, mkldnn and a net with diam_labels and save_model.
        train_data (list): training images; train_data[0].shape[0] is logged as the
            number of channels.
        train_labels (list): labels per image; element 0 is used to measure
            diameters and elements [1:] as training targets.
        test_data, test_labels (list, optional): same for the test set.
        poisson, blur, downsample (float, optional): noise settings passed to
            random_rotate_and_resize_noise.
        save_path (str, optional): the network is saved under save_path/models/;
            nothing is saved when None.
        save_every (int, optional): save when iepoch % save_every == 1 and at the
            last epoch.
        save_each (bool, optional): if True, the epoch number is added to the file
            name so that earlier saves are kept.
        learning_rate (float, list or np.ndarray, optional): a single value (a
            schedule is then built when SGD is True) or one value per epoch.
        n_epochs (int, optional): number of epochs.
        momentum, weight_decay, SGD: passed to model._set_optimizer.
        batch_size (int, optional): images per training step.
        nimg_per_epoch (int, optional): images per epoch.
        diameter (float, optional): diameter assigned to every image; measured from
            the labels when None.
        rescale (bool, optional): if True, measured diameters below 5 are set to 5.
        z_masking (bool, optional): if True, randomly zero leading / trailing
            channels of each training image.
        model_name (str, optional): file name for the saved network.

    Returns:
        str or None: path of the last saved network, or save_path when nothing
            was saved.

    Raises:
        ValueError: if learning_rate is an array with more than one dimension or a
            sequence whose length differs from n_epochs.
    """
    d = datetime.datetime.now()

    model.n_epochs = n_epochs
    if isinstance(learning_rate, (list, np.ndarray)):
        if isinstance(learning_rate, np.ndarray) and learning_rate.ndim > 1:
            raise ValueError("learning_rate.ndim must equal 1")
        elif len(learning_rate) != n_epochs:
            raise ValueError(
                "if learning_rate given as list or np.ndarray it must have length n_epochs"
            )
        model.learning_rate = learning_rate
        # most frequent value (smallest on ties), computed with numpy so the result
        # does not depend on the return shape of scipy.stats.mode
        lr_values, lr_counts = np.unique(learning_rate, return_counts=True)
        model.learning_rate_const = lr_values[np.argmax(lr_counts)]
    else:
        model.learning_rate_const = learning_rate
        # set learning rate schedule
        if SGD:
            LR = np.linspace(0, model.learning_rate_const, 10)
            if model.n_epochs > 250:
                LR = np.append(
                    LR, model.learning_rate_const * np.ones(model.n_epochs - 100))
                for i in range(10):
                    LR = np.append(LR, LR[-1] / 2 * np.ones(10))
            else:
                LR = np.append(
                    LR,
                    model.learning_rate_const * np.ones(max(0, model.n_epochs - 10)))
        else:
            LR = model.learning_rate_const * np.ones(model.n_epochs)
        model.learning_rate = LR

    model.batch_size = batch_size
    model._set_optimizer(model.learning_rate[0], momentum, weight_decay, SGD)
    model._set_criterion()

    nimg = len(train_data)

    # compute average cell diameter; diam_train (and diam_test when test data is
    # given) are defined on every path because the loops below always use them
    if diameter is None:
        diam_train = np.array(
            [utils.diameters(train_labels[k][0])[0] for k in range(len(train_labels))])
        diam_train_mean = diam_train[diam_train > 0].mean()
        model.diam_labels = diam_train_mean
        if test_data is not None:
            diam_test = np.array([
                utils.diameters(test_labels[k][0])[0]
                for k in range(len(test_labels))
            ])
        if rescale:
            diam_train[diam_train < 5] = 5.
            if test_data is not None:
                diam_test[diam_test < 5] = 5.
            denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
    else:
        diam_train_mean = diameter
        model.diam_labels = diameter
        if rescale:
            denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
        diam_train = diameter * np.ones(len(train_labels), "float32")
        if test_data is not None:
            diam_test = diameter * np.ones(len(test_labels), "float32")

    denoise_logger.info(
        f">>>> mean of training label mask diameters (saved to model) {diam_train_mean:.3f}"
    )
    model.net.diam_labels.data = torch.ones(1, device=model.device) * diam_train_mean

    nchan = train_data[0].shape[0]
    denoise_logger.info(">>>> training network with %d channel input <<<<" % nchan)
    denoise_logger.info(">>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f" %
                        (model.learning_rate_const, model.batch_size, weight_decay))

    if test_data is not None:
        denoise_logger.info(f">>>> ntrain = {nimg}, ntest = {len(test_data)}")
    else:
        denoise_logger.info(f">>>> ntrain = {nimg}")

    tic = time.time()

    lavg, nsum = 0, 0

    # returned unchanged when nothing gets saved
    filename = save_path
    if save_path is not None:
        _, file_label = os.path.split(save_path)
        file_path = os.path.join(save_path, "models/")

        if not os.path.exists(file_path):
            os.makedirs(file_path)
    else:
        denoise_logger.warning("WARNING: no save_path given, model not saving")

    # cannot train with mkldnn
    model.net.mkldnn = False

    # get indices for each epoch for training
    np.random.seed(0)
    inds_all = np.zeros((0,), "int32")
    # NOTE(review): a requested nimg_per_epoch smaller than nimg is replaced by
    # nimg here, so only values larger than nimg take effect -- confirm intent
    if nimg_per_epoch is None or nimg > nimg_per_epoch:
        nimg_per_epoch = nimg
    denoise_logger.info(f">>>> nimg_per_epoch = {nimg_per_epoch}")
    # concatenate random permutations until there are enough indices for all epochs
    while len(inds_all) < n_epochs * nimg_per_epoch:
        rperm = np.random.permutation(nimg)
        inds_all = np.hstack((inds_all, rperm))

    for iepoch in range(model.n_epochs):
        if SGD:
            model._set_learning_rate(model.learning_rate[iepoch])
        np.random.seed(iepoch)
        rperm = inds_all[iepoch * nimg_per_epoch:(iepoch + 1) * nimg_per_epoch]
        for ibatch in range(0, nimg_per_epoch, batch_size):
            inds = rperm[ibatch:ibatch + batch_size]
            imgi, lbl, scale = random_rotate_and_resize_noise(
                [train_data[i] for i in inds], [train_labels[i][1:] for i in inds],
                poisson=poisson, blur=blur, downsample=downsample,
                diams=diam_train[inds], diam_mean=model.diam_mean)
            imgi = imgi[:, :1]  # keep noisy only
            if z_masking:
                # NOTE(review): imgi has a single channel at this point, so
                # nc // 2 - 1 is negative and np.random.randint would raise --
                # confirm how z_masking is meant to be used
                nc = imgi.shape[1]
                nb = imgi.shape[0]
                ncmin = (np.random.rand(nb) > 0.25) * (np.random.randint(
                    nc // 2 - 1, size=nb))
                ncmax = nc - (np.random.rand(nb) > 0.25) * (np.random.randint(
                    nc // 2 - 1, size=nb))
                for b in range(nb):
                    imgi[b, :ncmin[b]] = 0
                    imgi[b, ncmax[b]:] = 0

            train_loss = model._train_step(imgi, lbl)
            lavg += train_loss
            nsum += len(imgi)

        # report at epoch 5 and every 10th epoch; the running training loss is
        # averaged over the epochs since the previous report and then reset
        if iepoch % 10 == 0 or iepoch == 5:
            lavg = lavg / nsum
            if test_data is not None:
                lavgt, nsum = 0., 0
                # fixed seed so test-time augmentation is the same at every report
                np.random.seed(42)
                rperm = np.arange(0, len(test_data), 1, int)
                for ibatch in range(0, len(test_data), batch_size):
                    inds = rperm[ibatch:ibatch + batch_size]
                    imgi, lbl, scale = random_rotate_and_resize_noise(
                        [test_data[i] for i in inds],
                        [test_labels[i][1:] for i in inds], poisson=poisson, blur=blur,
                        downsample=downsample, diams=diam_test[inds],
                        diam_mean=model.diam_mean)
                    imgi = imgi[:, :1]  # keep noisy only
                    test_loss = model._test_eval(imgi, lbl)
                    lavgt += test_loss
                    nsum += len(imgi)

                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %2.4f, Loss Test %2.4f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, lavgt / nsum,
                     model.learning_rate[iepoch]))
            else:
                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %2.4f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, model.learning_rate[iepoch]))

            lavg, nsum = 0, 0

        if save_path is not None:
            if iepoch == model.n_epochs - 1 or iepoch % save_every == 1:
                # save model at the end
                if save_each:  # separate files as model progresses
                    if model_name is None:
                        filename = "{}_{}_{}_{}".format(
                            model.net_type, file_label,
                            d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch))
                    else:
                        filename = "{}_{}".format(model_name, "epoch_" + str(iepoch))
                else:
                    if model_name is None:
                        filename = "{}_{}_{}".format(model.net_type, file_label,
                                                     d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
                    else:
                        filename = model_name
                filename = os.path.join(file_path, filename)
                denoise_logger.info(f"saving network parameters to {filename}")
                model.net.save_model(filename)

    # reset to mkldnn if available
    model.net.mkldnn = model.mkldnn
    return filename