| """Regionprops features and its augmentations. |
| WindowedRegionFeatures (WRFeatures) is a class that holds regionprops features for a windowed track region. |
| Modified from Trackastra (https://github.com/weigertlab/trackastra) |
| """ |
|
|
| import itertools |
| import logging |
| from collections import OrderedDict |
| from collections.abc import Iterable |
| from functools import reduce |
| from typing import Literal |
|
|
| import joblib |
| import numpy as np |
| import pandas as pd |
| from edt import edt |
| from skimage.measure import regionprops, regionprops_table |
| from tqdm import tqdm |
| from typing import Tuple, Optional, Sequence, Union, List |
| import typing |
| import torch |
| logger = logging.getLogger(__name__) |
|
|
# Presets of per-region properties extracted with
# ``skimage.measure.regionprops_table``.
# NOTE: "border_dist" is not a skimage property; it is computed separately
# with ``_border_dist_fast`` and merged into the table afterwards.
_PROPERTIES = {
    "regionprops": (
        "area",
        "intensity_mean",
        "intensity_max",
        "intensity_min",
        "inertia_tensor",
    ),
    "regionprops2": (
        "equivalent_diameter_area",
        "intensity_mean",
        "inertia_tensor",
        "border_dist",  # custom property, handled outside regionprops_table
    ),
}
|
|
|
|
def _border_dist_fast(mask: np.ndarray, cutoff: float = 5):
    """Compute a per-region border-proximity score.

    An auxiliary image is built that is 1 on the image boundary and decays
    linearly to 0 at ``cutoff`` pixels from every edge; the score of a region
    is the maximum of that image over its pixels, so 1 means the region
    touches the border and 0 means it lies at least ``cutoff`` pixels inside.

    NOTE(review): despite the name, *larger* values mean closer to the
    border.

    Args:
        mask: integer label image (2D or 3D).
        cutoff: ramp width in pixels (cast to int).

    Returns:
        Tuple of float scores, one per labelled region, in the order
        returned by skimage's ``regionprops``.
    """
    cutoff = int(cutoff)
    ndim = mask.ndim
    # linear ramp 0, 1/cutoff, ..., (cutoff-1)/cutoff shared by both ends of every axis
    ramp = np.arange(cutoff, dtype=np.float32) / cutoff
    interior = np.ones(mask.shape, dtype=np.float32)

    for axis in range(ndim):
        size = mask.shape[axis]
        # shape that broadcasts the 1D ramp along `axis`
        bcast = (cutoff,) + (1,) * (ndim - axis - 1)

        lo = [slice(None)] * ndim
        lo[axis] = slice(0, cutoff)
        lo = tuple(lo)
        interior[lo] = np.minimum(interior[lo], ramp.reshape(bcast))

        hi = [slice(None)] * ndim
        hi[axis] = slice(size - cutoff, size)
        hi = tuple(hi)
        interior[hi] = np.minimum(interior[hi], ramp[::-1].reshape(bcast))

    dist = 1 - interior
    return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist))
|
|
|
|
class WRFeatures:
    """regionprops features for a windowed track region."""

    def __init__(
        self,
        coords: np.ndarray,
        labels: np.ndarray,
        timepoints: np.ndarray,
        features: typing.OrderedDict[str, np.ndarray],
    ):
        """Store per-region coordinates, labels, timepoints and feature arrays.

        Args:
            coords: (N, ndim) centroid coordinates; the last axis fixes ndim.
            labels: (N,) region labels.
            timepoints: (N,) timepoint index of each region.
            features: ordered mapping feature name -> (N, k) array.

        Raises:
            ValueError: if coords is not 2D/3D spatial data.
        """
        ndim = coords.shape[-1]
        if ndim not in (2, 3):
            raise ValueError("Only 2D or 3D data is supported")

        self.ndim = ndim
        self.coords = coords
        self.labels = labels
        # shallow copy so later key insertions on the caller's dict don't leak in
        self.features = features.copy()
        self.timepoints = timepoints

    def __repr__(self):
        header = (
            f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)},"
            f" ntimepoints={len(np.unique(self.timepoints))})\n\n"
        )
        rows = "".join(
            f"{name:>20} -> {arr.shape}\n" for name, arr in self.features.items()
        )
        return header + rows

    @property
    def features_stacked(self):
        """All feature arrays concatenated along the last axis."""
        return np.concatenate(list(self.features.values()), axis=-1)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, key):
        try:
            return self.features[key]
        except KeyError:
            raise KeyError(f"Key {key} not found in features") from None

    @classmethod
    def concat(cls, feats: Sequence["WRFeatures"]) -> "WRFeatures":
        """Concatenate multiple WRFeatures into a single one."""
        if not feats:
            raise ValueError("Cannot concatenate empty list of features")
        merged = feats[0]
        for nxt in feats[1:]:
            merged = merged + nxt
        return merged

    def __add__(self, other: "WRFeatures") -> "WRFeatures":
        """Concatenate two WRFeatures."""
        if self.ndim != other.ndim:
            raise ValueError("Cannot concatenate features of different dimensions")
        if self.features.keys() != other.features.keys():
            raise ValueError("Cannot concatenate features with different properties")

        merged_features = OrderedDict(
            (name, np.concatenate([self.features[name], other.features[name]], axis=0))
            for name in self.features
        )
        return WRFeatures(
            coords=np.concatenate([self.coords, other.coords], axis=0),
            labels=np.concatenate([self.labels, other.labels], axis=0),
            timepoints=np.concatenate([self.timepoints, other.timepoints], axis=0),
            features=merged_features,
        )

    @classmethod
    def from_mask_img(
        cls,
        mask: np.ndarray,
        img: np.ndarray,
        properties="regionprops2",
        t_start: int = 0,
    ):
        """Build WRFeatures from a (T, ...) label stack and intensity stack.

        Args:
            mask: (T, H, W) or (T, D, H, W) integer label stack.
            img: intensity stack with the same leading shape as ``mask``.
            properties: key into the ``_PROPERTIES`` presets.
            t_start: offset added to the per-frame timepoint index.

        Raises:
            ValueError: if the frames are not 2D/3D or the preset contains
                "label"/"centroid" (extracted unconditionally).
        """
        img = np.asarray(img)
        mask = np.asarray(mask)

        ndim = mask.ndim - 1
        if ndim not in (2, 3):
            raise ValueError("Only 2D or 3D data is supported")

        props = tuple(_PROPERTIES[properties])
        if "label" in props or "centroid" in props:
            raise ValueError(
                f"label and centroid should not be in properties {props}"
            )

        use_border_dist = "border_dist" in props
        if use_border_dist:
            # border_dist is computed separately, not via regionprops_table
            props = tuple(p for p in props if p != "border_dist")

        table_props = ("label", "centroid", *props)
        frames = []
        for t, (mask_t, img_t) in enumerate(zip(mask, img)):
            frame = pd.DataFrame(
                regionprops_table(mask_t, intensity_image=img_t, properties=table_props)
            )
            frame["timepoint"] = t + t_start

            if use_border_dist:
                frame["border_dist"] = _border_dist_fast(mask_t)

            frames.append(frame)
        df = pd.concat(frames)

        if use_border_dist:
            props = (*props, "border_dist")

        timepoints = df["timepoint"].values.astype(np.int32)
        labels = df["label"].values.astype(np.int32)
        coords = df[[f"centroid-{i}" for i in range(ndim)]].values.astype(np.float32)

        # group table columns by property prefix (e.g. "inertia_tensor-0-0", ...)
        features = OrderedDict(
            (
                p,
                np.stack(
                    [
                        df[c].values.astype(np.float32)
                        for c in df.columns
                        if c.startswith(p)
                    ],
                    axis=-1,
                ),
            )
            for p in props
        )

        return cls(
            coords=coords, labels=labels, timepoints=timepoints, features=features
        )
|
|
|
|
def get_features(
    detections: np.ndarray,
    imgs: Optional[np.ndarray] = None,
    features: Literal["none", "wrfeat"] = "wrfeat",
    ndim: int = 2,
    n_workers=0,
    progbar_class=tqdm,
) -> List[WRFeatures]:
    """Extract per-timepoint WRFeatures from a stack of detections.

    Args:
        detections: (T, ...) integer label stack, one 2D/3D frame per timepoint.
        imgs: (T, ...) intensity stack aligned with ``detections``. Required:
            the ``None`` default used to crash with an ``AttributeError`` inside
            ``_check_dimensions``; it now raises a clear ``ValueError``.
        features: feature preset selector. NOTE(review): currently unused —
            "none" behaves exactly like "wrfeat"; confirm intended semantics.
        ndim: spatial dimensionality of each frame (2 or 3).
        n_workers: if > 0, run extraction in that many loky worker processes.
        progbar_class: tqdm-compatible progress bar class.

    Returns:
        List with one WRFeatures per timepoint (always a list, matching the
        annotation; previously the single-process branch returned a tuple).

    Raises:
        ValueError: if ``imgs`` is None or the stacks have an unexpected
            number of dimensions.
    """
    if imgs is None:
        raise ValueError(
            "imgs must be provided; feature extraction requires an intensity image"
        )
    detections = _check_dimensions(detections, ndim)
    imgs = _check_dimensions(imgs, ndim)
    logger.info(f"Extracting features from {len(detections)} detections")

    # Both branches consume the same (t, (mask, img)) progress iterator.
    frame_iter = progbar_class(
        enumerate(zip(detections, imgs)),
        total=len(imgs),
        desc="Extracting features",
    )
    if n_workers > 0:
        logger.info(f"Using {n_workers} processes for feature extraction")
        results = joblib.Parallel(n_jobs=n_workers, backend="loky")(
            joblib.delayed(WRFeatures.from_mask_img)(
                # copy so workers pickle a single frame, not a view of the stack
                mask=mask[np.newaxis, ...].copy(),
                img=img[np.newaxis, ...].copy(),
                t_start=t,
            )
            for t, (mask, img) in frame_iter
        )
    else:
        logger.info("Using single process for feature extraction")
        results = [
            WRFeatures.from_mask_img(
                mask=mask[np.newaxis, ...],
                img=img[np.newaxis, ...],
                t_start=t,
            )
            for t, (mask, img) in frame_iter
        ]

    return list(results)
|
|
|
|
| def _check_dimensions(x: np.ndarray, ndim: int): |
| if ndim == 2 and not x.ndim == 3: |
| raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data") |
| elif ndim == 3: |
| |
| if x.ndim == 3: |
| x = np.expand_dims(x, axis=1) |
| elif x.ndim == 4: |
| pass |
| else: |
| raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data") |
| return x |
|
|
|
|
def build_windows_sd(
    features: List[WRFeatures], imgs_enc, imgs_stable, boxes, imgs, masks, window_size: int, progbar_class=tqdm
) -> List[dict]:
    """Assemble sliding-window samples from per-timepoint features.

    For every window [t1, t1 + window_size) the per-timepoint WRFeatures are
    concatenated and bundled with the matching slices of the image/mask
    stacks, both as numpy arrays and as torch tensors.

    Args:
        features: one WRFeatures per timepoint.
        imgs_enc, imgs_stable, imgs, masks: per-timepoint stacks, sliced
            to each window.
        boxes: passed through whole. NOTE(review): unlike the other stacks
            it is not sliced to [t1:t2] — confirm this is intentional.
        window_size: number of consecutive timepoints per window.
        progbar_class: tqdm-compatible progress bar class.

    Returns:
        List of window dicts.
    """
    n_windows = len(features) - window_size + 1
    windows = []
    for t1 in progbar_class(
        range(n_windows),
        total=n_windows,
        desc="Building windows",
    ):
        t2 = t1 + window_size
        feat = WRFeatures.concat(features[t1:t2])

        labels = feat.labels
        timepoints = feat.timepoints
        coords = feat.coords
        if len(feat) == 0:
            # keep a well-shaped (0, ndim) coordinate array for empty windows
            coords = np.zeros((0, feat.ndim), dtype=int)

        stacked = feat.features_stacked
        window = {
            "coords": coords,
            "t1": t1,
            "labels": labels,
            "timepoints": timepoints,
            "features": stacked,
            "img_enc": imgs_enc[t1:t2],
            "image_stable": imgs_stable[t1:t2],
            "boxes": boxes,
            "img": imgs[t1:t2],
            "mask": masks[t1:t2],
            "coords_t": torch.tensor(coords, dtype=torch.float32),
            "labels_t": torch.tensor(labels, dtype=torch.int32),
            "timepoints_t": torch.tensor(timepoints, dtype=torch.int64),
            "features_t": torch.tensor(stacked, dtype=torch.float32),
            "img_t": torch.tensor(imgs[t1:t2], dtype=torch.float32),
            "mask_t": torch.tensor(masks[t1:t2], dtype=torch.int32),
        }
        windows.append(window)

    logger.debug(f"Built {len(windows)} track windows.\n")
    return windows
|
|
|
|