import enum
from copy import deepcopy

import numpy as np
from skimage import img_as_ubyte
from skimage.transform import rescale, resize

try:
    from detectron2 import model_zoo
    from detectron2.config import get_cfg
    from detectron2.engine import DefaultPredictor

    DETECTRON_INSTALLED = True
except ImportError:
    # Best-effort optional dependency: SegmentationMask refuses to construct without it.
    print("Detectron v2 is not installed")
    DETECTRON_INSTALLED = False

from .countless.countless2d import zero_corrected_countless


class ObjectMask:
    """A binary object mask stored as a tight crop plus its placement on a canvas.

    ``height``/``width`` are the canvas size. ``up``/``down``/``left``/``right`` are
    the half-open row/column bounds of ``self.mask`` on that canvas. After
    ``shift``/``rescale`` the bounds may extend past the canvas; ``crop_to_canvas``
    clips them back.
    """

    def __init__(self, mask):
        self.height, self.width = mask.shape
        (self.up, self.down), (self.left, self.right) = self._get_limits(mask)
        self.mask = mask[self.up:self.down, self.left:self.right].copy()

    @staticmethod
    def _get_limits(mask):
        """Return ``((up, down), (left, right))``, the half-open bounds of the non-zero area.

        For an all-zero mask ``argmax`` yields 0 from both ends, so the full extent
        ``(0, size)`` is returned along each axis.
        """
        def indicator_limits(indicator):
            lower = indicator.argmax()
            upper = len(indicator) - indicator[::-1].argmax()
            return lower, upper

        vertical_limits = indicator_limits(mask.any(axis=1))
        horizontal_limits = indicator_limits(mask.any(axis=0))
        return vertical_limits, horizontal_limits

    def _clean(self):
        """Reset to an empty mask at the origin (used when the object leaves the canvas)."""
        self.up, self.down, self.left, self.right = 0, 0, 0, 0
        self.mask = np.empty((0, 0))

    def horizontal_flip(self, inplace=False):
        """Mirror the cropped mask left-right; bounds are unchanged. Returns the flipped object."""
        if not inplace:
            flipped = deepcopy(self)
            return flipped.horizontal_flip(inplace=True)

        self.mask = self.mask[:, ::-1]
        return self

    def vertical_flip(self, inplace=False):
        """Mirror the cropped mask top-bottom; bounds are unchanged. Returns the flipped object."""
        if not inplace:
            flipped = deepcopy(self)
            return flipped.vertical_flip(inplace=True)

        self.mask = self.mask[::-1, :]
        return self

    def image_center(self):
        """Return the ``(y, x)`` center of the bounding box in canvas coordinates."""
        y_center = self.up + (self.down - self.up) / 2
        x_center = self.left + (self.right - self.left) / 2
        return y_center, x_center

    def rescale(self, scaling_factor, inplace=False):
        """Rescale the mask about its bounding-box center.

        The mask is resampled with nearest-neighbour interpolation, thresholded at
        0.5, and re-cropped to its non-zero area. The new bounds are placed so the
        box center is preserved (rounded to integer pixels).
        """
        if not inplace:
            scaled = deepcopy(self)
            return scaled.rescale(scaling_factor, inplace=True)

        # Center is taken from the bounds *before* they are updated below.
        y_center, x_center = self.image_center()

        scaled_mask = rescale(self.mask.astype(float), scaling_factor, order=0) > 0.5
        (up, down), (left, right) = self._get_limits(scaled_mask)
        self.mask = scaled_mask[up:down, left:right]

        mask_height, mask_width = self.mask.shape
        self.up = int(round(y_center - mask_height / 2))
        self.down = self.up + mask_height
        self.left = int(round(x_center - mask_width / 2))
        self.right = self.left + mask_width
        return self

    def crop_to_canvas(self, vertical=True, horizontal=True, inplace=False):
        """Clip the mask to the canvas along the selected axes.

        If the object lies entirely outside the canvas along a clipped axis, it is
        reset to an empty mask (see ``_clean``).
        """
        if not inplace:
            cropped = deepcopy(self)
            cropped.crop_to_canvas(vertical=vertical, horizontal=horizontal, inplace=True)
            return cropped

        if vertical:
            if self.up >= self.height or self.down <= 0:
                self._clean()
            else:
                cut_up, cut_down = max(-self.up, 0), max(self.down - self.height, 0)
                if cut_up != 0:
                    self.mask = self.mask[cut_up:]
                    self.up = 0
                if cut_down != 0:
                    self.mask = self.mask[:-cut_down]
                    self.down = self.height

        if horizontal:
            if self.left >= self.width or self.right <= 0:
                self._clean()
            else:
                cut_left, cut_right = max(-self.left, 0), max(self.right - self.width, 0)
                if cut_left != 0:
                    self.mask = self.mask[:, cut_left:]
                    self.left = 0
                if cut_right != 0:
                    self.mask = self.mask[:, :-cut_right]
                    self.right = self.width

        return self

    def restore_full_mask(self, allow_crop=False):
        """Return a full ``(height, width)`` bool canvas with the (clipped) mask pasted in.

        :param allow_crop: if True, ``self`` is clipped in place; otherwise a copy is clipped.
        """
        cropped = self.crop_to_canvas(inplace=allow_crop)
        mask = np.zeros((cropped.height, cropped.width), dtype=bool)
        mask[cropped.up:cropped.down, cropped.left:cropped.right] = cropped.mask
        return mask

    def shift(self, vertical=0, horizontal=0, inplace=False):
        """Translate the bounds by the given pixel offsets (no clipping)."""
        if not inplace:
            shifted = deepcopy(self)
            return shifted.shift(vertical=vertical, horizontal=horizontal, inplace=True)

        self.up += vertical
        self.down += vertical
        self.left += horizontal
        self.right += horizontal
        return self

    def area(self):
        """Return the number of set pixels in the cropped mask."""
        return self.mask.sum()


class RigidnessMode(enum.Enum):
    soft = 0
    rigid = 1


class SegmentationMask:
    """Generate inpainting masks by relocating object silhouettes from a panoptic segmentation.

    Detectron2's COCO panoptic FPN finds "thing" segments. Sufficiently small
    segments are rescaled, optionally flipped, and shifted to positions that avoid
    overlapping earlier variants and (depending on ``rigidness_mode``) foreground
    objects.
    """

    def __init__(self, confidence_threshold=0.5, rigidness_mode=RigidnessMode.rigid,
                 max_object_area=0.3, min_mask_area=0.02, downsample_levels=6, num_variants_per_mask=4,
                 max_mask_intersection=0.5, max_foreground_coverage=0.5, max_foreground_intersection=0.5,
                 max_hidden_area=0.2, max_scale_change=0.25, horizontal_flip=True,
                 max_vertical_shift=0.1, position_shuffle=True):
        """
        :param confidence_threshold: float; threshold for confidence of the panoptic segmentator to allow for
            the instance.
        :param rigidness_mode: RigidnessMode object
            when soft, checks intersection only with the object from which the mask_object was produced
            when rigid, checks intersection with any foreground class object
        :param max_object_area: float; objects covering at least this fraction of the image are not used as
            mask_object.
        :param min_mask_area: float; produced masks covering at most this fraction of the image are discarded.
        :param downsample_levels: int; the segmentation is downsampled to width 2**downsample_levels before
            searching for shifted masks.
        :param num_variants_per_mask: int; maximal number of the masks for the same object.
        :param max_mask_intersection: float; maximum allowed area fraction of intersection for 2 masks
            produced by horizontal shift of the same mask_object; higher value -> more diversity.
        :param max_foreground_coverage: float; maximum allowed area fraction of intersection for foreground
            object to be covered by mask; lower value -> less the objects are covered.
        :param max_foreground_intersection: float; maximum allowed area of intersection for the mask with
            foreground object; lower value -> mask is more on the background than on the objects.
        :param max_hidden_area: upper bound on the part of the object hidden by shifting it outside the
            image area.
        :param max_scale_change: allowed relative scale change for the mask_object.
        :param horizontal_flip: whether horizontal flips are allowed.
        :param max_vertical_shift: amount of vertical movement allowed, as a fraction of image height.
        :param position_shuffle: whether candidate horizontal positions are tried in random order
            (otherwise left to right).
        """
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        if not DETECTRON_INSTALLED:
            raise ImportError('Cannot use SegmentationMask without detectron2')
        self.cfg = get_cfg()
        self.cfg.merge_from_file(model_zoo.get_config_file("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"))
        self.cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml")
        self.cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = confidence_threshold
        self.predictor = DefaultPredictor(self.cfg)

        self.rigidness_mode = RigidnessMode(rigidness_mode)
        self.max_object_area = max_object_area
        self.min_mask_area = min_mask_area
        self.downsample_levels = downsample_levels
        self.num_variants_per_mask = num_variants_per_mask
        self.max_mask_intersection = max_mask_intersection
        self.max_foreground_coverage = max_foreground_coverage
        self.max_foreground_intersection = max_foreground_intersection
        self.max_hidden_area = max_hidden_area
        self.position_shuffle = position_shuffle

        self.max_scale_change = max_scale_change
        self.horizontal_flip = horizontal_flip
        self.max_vertical_shift = max_vertical_shift

    def get_segmentation(self, img):
        """Run the panoptic predictor; return ``(panoptic_seg, segments_info)``."""
        im = img_as_ubyte(img)
        panoptic_seg, segment_info = self.predictor(im)["panoptic_seg"]
        return panoptic_seg, segment_info

    @staticmethod
    def _is_power_of_two(n):
        return (n != 0) and (n & (n - 1) == 0)

    def identify_candidates(self, panoptic_seg, segments_info):
        """Return ids of "thing" segments whose area fraction is below ``max_object_area``."""
        total_pixels = np.prod(panoptic_seg.shape)
        potential_mask_ids = []
        for segment in segments_info:
            if not segment["isthing"]:
                continue
            # Count on the tensor side instead of copying the whole mask to CPU.
            area = (panoptic_seg == segment["id"]).sum().item() / total_pixels
            if area >= self.max_object_area:
                continue
            potential_mask_ids.append(segment["id"])
        return potential_mask_ids

    def downsample_mask(self, mask):
        """Downsample a label map whose sides are powers of two, using zero-corrected COUNTLESS.

        ``zero_corrected_countless`` is applied ``log2(width) - downsample_levels``
        times. It presumably halves each side per call, giving width
        ``2**downsample_levels``.

        :raises ValueError: if a side is not a power of two or the image is too small.
        """
        height, width = mask.shape
        if not (self._is_power_of_two(height) and self._is_power_of_two(width)):
            raise ValueError("Image sides are not power of 2.")

        num_iterations = width.bit_length() - 1 - self.downsample_levels
        if num_iterations < 0:
            raise ValueError(f"Width is lower than 2^{self.downsample_levels}.")

        if height.bit_length() - 1 < num_iterations:
            raise ValueError("Height is too low to perform downsampling")

        downsampled = mask
        for _ in range(num_iterations):
            downsampled = zero_corrected_countless(downsampled)

        return downsampled

    def _augmentation_params(self):
        """Sample a random scale, horizontal-flip flag and relative vertical shift."""
        scaling_factor = np.random.uniform(1 - self.max_scale_change, 1 + self.max_scale_change)
        if self.horizontal_flip:
            horizontal_flip = bool(np.random.choice(2))
        else:
            horizontal_flip = False
        vertical_shift = np.random.uniform(-self.max_vertical_shift, self.max_vertical_shift)

        return {
            "scaling_factor": scaling_factor,
            "horizontal_flip": horizontal_flip,
            "vertical_shift": vertical_shift
        }

    def _get_intersection(self, mask_array, mask_object):
        """Return the overlap of a full-canvas ``mask_array`` with ``mask_object`` (cropped region)."""
        intersection = mask_array[
            mask_object.up:mask_object.down, mask_object.left:mask_object.right
        ] & mask_object.mask
        return intersection

    def _check_masks_intersection(self, aug_mask, total_mask_area, prev_masks):
        """Return True if ``aug_mask`` overlaps every mask in ``prev_masks`` little enough.

        Overlap is measured as a fraction of each previous mask's area, and as
        ``1 - (visible non-overlapping area) / total_mask_area`` for the candidate.
        Both must not exceed ``max_mask_intersection``.
        """
        for existing_mask in prev_masks:
            intersection_area = self._get_intersection(existing_mask, aug_mask).sum()
            intersection_existing = intersection_area / existing_mask.sum()
            intersection_current = 1 - (aug_mask.area() - intersection_area) / total_mask_area
            if (intersection_existing > self.max_mask_intersection) or \
                    (intersection_current > self.max_mask_intersection):
                return False
        return True

    def _check_foreground_intersection(self, aug_mask, foreground):
        """Return True if ``aug_mask`` neither covers too much of any foreground object
        nor lies too much on foreground objects.
        """
        aug_area = aug_mask.area()
        for existing_mask in foreground:
            intersection_area = self._get_intersection(existing_mask, aug_mask).sum()
            intersection_existing = intersection_area / existing_mask.sum()
            if intersection_existing > self.max_foreground_coverage:
                return False
            # An empty candidate cannot lie on the foreground; skip the 0/0 ratio
            # (the previous NaN comparison evaluated to False, i.e. "passes" as well).
            if aug_area > 0 and intersection_area / aug_area > self.max_foreground_intersection:
                return False
        return True

    def _move_mask(self, mask, foreground):
        """Search for up to ``num_variants_per_mask`` placements of the object in ``mask``.

        :param mask: bool array (downsampled canvas) of the object to move.
        :param foreground: list of bool arrays of the same shape to avoid.
        :return: list of augmentation-parameter dicts (scaling_factor, horizontal_flip,
            vertical_shift, horizontal_shift; shifts are fractions of canvas size).
            The search stops at the first variant for which no horizontal position fits.
        """
        # Obtaining properties of the original mask_object:
        orig_mask = ObjectMask(mask)

        chosen_masks = []
        chosen_parameters = []
        # to fix the case when resizing gives mask_object consisting only of False
        scaling_factor_lower_bound = 0.

        for _ in range(self.num_variants_per_mask):
            # Obtaining augmentation parameters and applying them to the downscaled mask_object
            augmentation_params = self._augmentation_params()
            augmentation_params["scaling_factor"] = min([
                augmentation_params["scaling_factor"],
                2 * min(orig_mask.up, orig_mask.height - orig_mask.down) / orig_mask.height + 1.,
                2 * min(orig_mask.left, orig_mask.width - orig_mask.right) / orig_mask.width + 1.
            ])
            augmentation_params["scaling_factor"] = max([
                augmentation_params["scaling_factor"], scaling_factor_lower_bound
            ])

            aug_mask = deepcopy(orig_mask)
            aug_mask.rescale(augmentation_params["scaling_factor"], inplace=True)
            if augmentation_params["horizontal_flip"]:
                aug_mask.horizontal_flip(inplace=True)
            total_aug_area = aug_mask.area()
            if total_aug_area == 0:
                scaling_factor_lower_bound = 1.
                continue

            # Fix if the element vertical shift is too strong and shown area is too small:
            vertical_area = aug_mask.mask.sum(axis=1) / total_aug_area  # share of area taken by rows
            # number of rows which are allowed to be hidden from upper and lower parts of image respectively
            max_hidden_up = np.searchsorted(vertical_area.cumsum(), self.max_hidden_area)
            max_hidden_down = np.searchsorted(vertical_area[::-1].cumsum(), self.max_hidden_area)
            # correcting vertical shift, so not too much area will be hidden
            augmentation_params["vertical_shift"] = np.clip(
                augmentation_params["vertical_shift"],
                -(aug_mask.up + max_hidden_up) / aug_mask.height,
                (aug_mask.height - aug_mask.down + max_hidden_down) / aug_mask.height
            )
            # Applying vertical shift:
            vertical_shift = int(round(aug_mask.height * augmentation_params["vertical_shift"]))
            aug_mask.shift(vertical=vertical_shift, inplace=True)
            aug_mask.crop_to_canvas(vertical=True, horizontal=False, inplace=True)

            # Choosing horizontal shift; budget left after the vertical crop already hid some area:
            max_hidden_area = self.max_hidden_area - (1 - aug_mask.area() / total_aug_area)
            horizontal_area = aug_mask.mask.sum(axis=0) / total_aug_area
            max_hidden_left = np.searchsorted(horizontal_area.cumsum(), max_hidden_area)
            max_hidden_right = np.searchsorted(horizontal_area[::-1].cumsum(), max_hidden_area)
            allowed_shifts = np.arange(-max_hidden_left, aug_mask.width -
                                       (aug_mask.right - aug_mask.left) + max_hidden_right + 1)
            # Convert target left positions into shifts relative to the current left bound.
            allowed_shifts = - (aug_mask.left - allowed_shifts)

            if self.position_shuffle:
                np.random.shuffle(allowed_shifts)

            # Loop-invariant: chosen_masks only changes right before the inner loop breaks.
            prev_masks = [mask] + chosen_masks
            for horizontal_shift in allowed_shifts:
                aug_mask_left = deepcopy(aug_mask)
                aug_mask_left.shift(horizontal=horizontal_shift, inplace=True)
                aug_mask_left.crop_to_canvas(inplace=True)

                is_mask_suitable = (
                    self._check_masks_intersection(aug_mask_left, total_aug_area, prev_masks)
                    and self._check_foreground_intersection(aug_mask_left, foreground)
                )
                if is_mask_suitable:
                    aug_draw = aug_mask_left.restore_full_mask()
                    chosen_masks.append(aug_draw)
                    augmentation_params["horizontal_shift"] = horizontal_shift / aug_mask_left.width
                    chosen_parameters.append(augmentation_params)
                    break
            else:
                # No admissible horizontal position for this variant: stop searching.
                break

        return chosen_parameters

    def _prepare_mask(self, mask):
        """Resize a label map (nearest neighbour) so each side becomes a power of two.

        Sides that are already powers of two are kept; others are enlarged to the
        next power of two. Returns an int32 array.
        """
        height, width = mask.shape
        target_width = width if self._is_power_of_two(width) else (1 << width.bit_length())
        target_height = height if self._is_power_of_two(height) else (1 << height.bit_length())

        return resize(mask.astype('float32'), (target_height, target_width), order=0, mode='edge').round().astype('int32')

    def get_masks(self, im, return_panoptic=False):
        """Produce relocated-object masks for image ``im``.

        Placements are searched on a downsampled segmentation. The found parameters
        are then re-applied to the full-resolution object mask.

        :return: list of uint8 masks with the panoptic segmentation's shape, each covering
            more than ``min_mask_area`` of it; plus the panoptic segmentation as numpy
            if ``return_panoptic``.
        :raises ValueError: from ``downsample_mask`` if the image is too small.
        """
        panoptic_seg, segments_info = self.get_segmentation(im)
        potential_mask_ids = self.identify_candidates(panoptic_seg, segments_info)

        panoptic_seg_scaled = self._prepare_mask(panoptic_seg.detach().cpu().numpy())
        downsampled = self.downsample_mask(panoptic_seg_scaled)
        scene_objects = []
        for segment in segments_info:
            if not segment["isthing"]:
                continue
            mask = downsampled == segment["id"]
            if not np.any(mask):
                continue
            scene_objects.append(mask)

        mask_set = []
        for mask_id in potential_mask_ids:
            mask = downsampled == mask_id
            if not np.any(mask):
                continue

            if self.rigidness_mode is RigidnessMode.soft:
                foreground = [mask]
            elif self.rigidness_mode is RigidnessMode.rigid:
                foreground = scene_objects
            else:
                # Previously referenced an undefined local name and raised NameError.
                raise ValueError(f'Unexpected rigidness_mode: {self.rigidness_mode}')

            masks_params = self._move_mask(mask, foreground)

            full_mask = ObjectMask((panoptic_seg == mask_id).detach().cpu().numpy())

            for params in masks_params:
                aug_mask = deepcopy(full_mask)
                aug_mask.rescale(params["scaling_factor"], inplace=True)
                if params["horizontal_flip"]:
                    aug_mask.horizontal_flip(inplace=True)

                vertical_shift = int(round(aug_mask.height * params["vertical_shift"]))
                horizontal_shift = int(round(aug_mask.width * params["horizontal_shift"]))
                aug_mask.shift(vertical=vertical_shift, horizontal=horizontal_shift, inplace=True)
                aug_mask = aug_mask.restore_full_mask().astype('uint8')
                if aug_mask.mean() <= self.min_mask_area:
                    continue
                mask_set.append(aug_mask)

        if return_panoptic:
            return mask_set, panoptic_seg.detach().cpu().numpy()
        else:
            return mask_set


def propose_random_square_crop(mask, min_overlap=0.5):
    """Pick a random square crop (side = shorter image side) that covers part of the masked area.

    :param mask: 2-D array; values > 0.5 mark the missing region (mask==0 is known fragment
        and mask==1 is missing).
    :param min_overlap: fraction of the masked extent along the long side used to bound
        the crop's start position.
    :return: ``(x0, y0, x1, y1)`` crop box.
    :raises ValueError: if ``mask`` has no pixels > 0.5.
    """
    height, width = mask.shape
    mask_ys, mask_xs = np.where(mask > 0.5)  # mask==0 is known fragment and mask==1 is missing
    if mask_ys.size == 0:
        raise ValueError("mask has no missing (mask > 0.5) pixels to crop around")

    # Borders are non-negative floats; cast explicitly (int == floor here) for randint.
    if height < width:
        crop_size = height
        obj_left, obj_right = mask_xs.min(), mask_xs.max()
        obj_width = obj_right - obj_left
        left_border = max(0, min(width - crop_size - 1, obj_left + obj_width * min_overlap - crop_size))
        right_border = max(left_border + 1, min(width - crop_size, obj_left + obj_width * min_overlap))
        start_x = np.random.randint(int(left_border), int(right_border))
        return start_x, 0, start_x + crop_size, height
    else:
        crop_size = width
        obj_top, obj_bottom = mask_ys.min(), mask_ys.max()
        obj_height = obj_bottom - obj_top
        top_border = max(0, min(height - crop_size - 1, obj_top + obj_height * min_overlap - crop_size))
        bottom_border = max(top_border + 1, min(height - crop_size, obj_top + obj_height * min_overlap))
        start_y = np.random.randint(int(top_border), int(bottom_border))
        return 0, start_y, width, start_y + crop_size