Spaces:
Build error
Build error
| import cv2 | |
| import copy | |
| import torch | |
| import numpy as np | |
| from maskrcnn_benchmark.layers.misc import interpolate | |
| import pycocotools.mask as mask_utils | |
# Flip modes accepted by the `transpose` methods of the classes in this file.
# FLIP_LEFT_RIGHT mirrors along the x axis (width), FLIP_TOP_BOTTOM along the
# y axis (height).
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1
| """ ABSTRACT | |
| Segmentations come in either: | |
| 1) Binary masks | |
| 2) Polygons | |
| Binary masks can be represented in a contiguous array | |
| and operations can be carried out more efficiently, | |
| therefore BinaryMaskList handles them together. | |
| Polygons are handled separately for each instance, | |
| by PolygonInstance and instances are handled by | |
| PolygonList. | |
| SegmentationMask is supposed to represent both, | |
| therefore it wraps the functions of BinaryMaskList | |
| and PolygonList to make it transparent. | |
| """ | |
class BinaryMaskList(object):
    """
    This class handles binary masks for all objects in the image.

    The masks are kept in one tensor of shape [num_instances, H, W]
    (``self.masks``); ``self.size`` is the image size as (width, height).
    """

    def __init__(self, masks, size):
        """
        Arguments:
            masks: Either torch.tensor of [num_instances, H, W]
                or list of torch.tensors of [H, W] with num_instances elems,
                or RLE (Run Length Encoding) - interpreted as list of dicts,
                or BinaryMaskList.
                An empty list/tuple yields zero instances.
            size: absolute image size, width first

        After initialization, a hard copy will be made, to leave the
        initializing source data intact.

        Raises:
            RuntimeError: if `masks` (or its first element) is of a type
                that cannot be interpreted.
            AssertionError: if the mask tensor does not match `size`.
        """
        if isinstance(masks, torch.Tensor):
            # The raw data representation is passed as argument
            masks = masks.clone()
        elif isinstance(masks, (list, tuple)):
            if len(masks) == 0:
                # No instances: keep the [N, H, W] layout with N == 0
                masks = torch.empty(
                    [0, int(size[1]), int(size[0])], dtype=torch.uint8
                )
            elif isinstance(masks[0], torch.Tensor):
                # Instances go on the leading axis: [num_instances, H, W]
                masks = torch.stack(masks, dim=0).clone()
            elif isinstance(masks[0], dict) and "counts" in masks[0]:
                # RLE interpretation
                rles = list(masks)
                if isinstance(rles[0]["counts"], (list, tuple)):
                    # Uncompressed RLE must be compressed before decoding
                    rles = mask_utils.frPyObjects(
                        rles, int(size[1]), int(size[0])
                    )
                decoded = mask_utils.decode(rles)  # [H, W, num_instances]
                masks = torch.tensor(decoded).permute(2, 0, 1).contiguous()
            else:
                raise RuntimeError(
                    "Type of `masks[0]` could not be interpreted: %s"
                    % type(masks[0])
                )
        elif isinstance(masks, BinaryMaskList):
            # just hard copy the BinaryMaskList instance's underlying data
            masks = masks.masks.clone()
        else:
            raise RuntimeError(
                "Type of `masks` argument could not be interpreted:%s" % type(masks)
            )

        if len(masks.shape) == 2:
            # if only a single instance mask is passed
            masks = masks[None]

        assert len(masks.shape) == 3
        assert masks.shape[1] == size[1], "%s != %s" % (masks.shape[1], size[1])
        assert masks.shape[2] == size[0], "%s != %s" % (masks.shape[2], size[0])

        self.masks = masks
        self.size = tuple(size)

    def transpose(self, method):
        """
        Return a new BinaryMaskList flipped left-right or top-bottom.

        Raises:
            NotImplementedError: if `method` is neither FLIP_LEFT_RIGHT
                nor FLIP_TOP_BOTTOM.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )
        # masks are [N, H, W]: axis 1 is height, axis 2 is width
        dim = 1 if method == FLIP_TOP_BOTTOM else 2
        flipped_masks = self.masks.flip(dim)
        return BinaryMaskList(flipped_masks, self.size)

    def crop(self, box):
        """
        Return a new BinaryMaskList restricted to `box` (xmin, ymin, xmax, ymax).

        Coordinates are rounded to integers and clamped to the image; the
        resulting crop is always at least one pixel wide and high.
        """
        assert isinstance(box, (list, tuple, torch.Tensor)), str(type(box))
        # box is assumed to be xyxy
        current_width, current_height = self.size
        xmin, ymin, xmax, ymax = [round(float(b)) for b in box]

        assert xmin <= xmax and ymin <= ymax, str(box)
        xmin = min(max(xmin, 0), current_width - 1)
        ymin = min(max(ymin, 0), current_height - 1)

        xmax = min(max(xmax, 0), current_width)
        ymax = min(max(ymax, 0), current_height)

        # guarantee a non-empty crop
        xmax = max(xmax, xmin + 1)
        ymax = max(ymax, ymin + 1)

        width, height = xmax - xmin, ymax - ymin
        cropped_masks = self.masks[:, ymin:ymax, xmin:xmax]
        cropped_size = width, height
        return BinaryMaskList(cropped_masks, cropped_size)

    def resize(self, size):
        """
        Return a new BinaryMaskList bilinearly resized to `size`.

        Arguments:
            size: (width, height), or a single number used for both.
        """
        try:
            iter(size)
        except TypeError:
            assert isinstance(size, (int, float))
            size = size, size
        width, height = map(int, size)

        assert width > 0
        assert height > 0

        if len(self.masks) == 0:
            # Nothing to interpolate; just produce an empty tensor of the
            # requested spatial size with the same dtype/device.
            resized_masks = self.masks.new_empty((0, height, width))
        else:
            # Height comes first here!
            resized_masks = torch.nn.functional.interpolate(
                input=self.masks[None].float(),
                size=(height, width),
                mode="bilinear",
                align_corners=False,
            )[0].type_as(self.masks)
        resized_size = width, height
        return BinaryMaskList(resized_masks, resized_size)

    def convert_to_polygon(self):
        """Return a PolygonList built from the external contours of the masks."""
        contours = self._findContours()
        return PolygonList(contours, self.size)

    def to(self, *args, **kwargs):
        """No-op kept for interface compatibility; returns `self` unchanged."""
        return self

    def _findContours(self):
        """
        Extract external contours of every mask with OpenCV.

        Returns:
            A list with one entry per instance; each entry is a list of
            flat coordinate lists as produced by reshaping each contour.
        """
        contours = []
        # OpenCV works on host memory
        masks = self.masks.detach().cpu().numpy()
        for mask in masks:
            mask = cv2.UMat(mask)
            # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x
            # return (contours, hierarchy); contours is second-to-last in both.
            contour = cv2.findContours(
                mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_TC89_L1
            )[-2]

            reshaped_contour = []
            for entity in contour:
                assert len(entity.shape) == 3
                assert entity.shape[1] == 1, "Hierarchical contours are not allowed"
                reshaped_contour.append(entity.reshape(-1).tolist())
            contours.append(reshaped_contour)
        return contours

    def __len__(self):
        return len(self.masks)

    def __getitem__(self, index):
        # Probably it can cause some overhead
        # but preserves consistency
        masks = self.masks[index].clone()
        return BinaryMaskList(masks, self.size)

    def __iter__(self):
        return iter(self.masks)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self.masks))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={})".format(self.size[1])
        return s
class PolygonInstance(object):
    """
    This class holds a set of polygons that represents a single instance
    of an object mask. The object can be represented as a set of
    polygons.

    ``self.polygons`` is a list of 1-D float32 tensors with interleaved
    coordinates (x0, y0, x1, y1, ...); ``self.size`` is (width, height).
    """

    def __init__(self, polygons, size):
        """
        Arguments:
            polygons: a list of lists of numbers, or a PolygonInstance.
                The first level refers to all the polygons that compose the
                object, and the second level to the polygon coordinates.
                Polygons with fewer than 6 values (3 points) are dropped.
            size: absolute image size, width first

        Raises:
            RuntimeError: if `polygons` is neither a list/tuple nor a
                PolygonInstance.
        """
        if isinstance(polygons, (list, tuple)):
            valid_polygons = []
            for p in polygons:
                p = torch.as_tensor(p, dtype=torch.float32)
                if len(p) >= 6:  # 3 * 2 coordinates
                    valid_polygons.append(p)
            polygons = valid_polygons

        elif isinstance(polygons, PolygonInstance):
            # shallow copy: the list is new, the tensors are shared
            polygons = copy.copy(polygons.polygons)
        else:
            raise RuntimeError(
                "Type of argument `polygons` is not allowed:%s" % (type(polygons))
            )

        # NOTE: coordinates are deliberately not checked against `size`;
        # a bounds check here used to crash training too often.

        self.polygons = polygons
        self.size = tuple(size)

    def transpose(self, method):
        """
        Return a new PolygonInstance flipped left-right or top-bottom.

        Raises:
            NotImplementedError: for any other `method`.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        flipped_polygons = []
        width, height = self.size
        if method == FLIP_LEFT_RIGHT:
            dim = width
            idx = 0
        elif method == FLIP_TOP_BOTTOM:
            dim = height
            idx = 1

        for poly in self.polygons:
            p = poly.clone()
            TO_REMOVE = 1
            # idx selects the x (0) or y (1) entries of the interleaved coords
            p[idx::2] = dim - poly[idx::2] - TO_REMOVE
            flipped_polygons.append(p)

        return PolygonInstance(flipped_polygons, size=self.size)

    def crop(self, box):
        """
        Return a new PolygonInstance expressed relative to `box`
        (xmin, ymin, xmax, ymax).

        The box is clamped to the image; polygon coordinates are only
        shifted by the box origin, not clipped to the box.
        """
        assert isinstance(box, (list, tuple, torch.Tensor)), str(type(box))

        # box is assumed to be xyxy
        current_width, current_height = self.size
        xmin, ymin, xmax, ymax = map(float, box)

        assert xmin <= xmax and ymin <= ymax, str(box)
        xmin = min(max(xmin, 0), current_width - 1)
        ymin = min(max(ymin, 0), current_height - 1)

        xmax = min(max(xmax, 0), current_width)
        ymax = min(max(ymax, 0), current_height)

        # guarantee a non-degenerate crop
        xmax = max(xmax, xmin + 1)
        ymax = max(ymax, ymin + 1)

        w, h = xmax - xmin, ymax - ymin

        cropped_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[0::2] = p[0::2] - xmin  # .clamp(min=0, max=w)
            p[1::2] = p[1::2] - ymin  # .clamp(min=0, max=h)
            cropped_polygons.append(p)

        return PolygonInstance(cropped_polygons, size=(w, h))

    def resize(self, size):
        """
        Return a new PolygonInstance scaled to `size`.

        Arguments:
            size: (width, height), or a single number used for both.
        """
        try:
            iter(size)
        except TypeError:
            assert isinstance(size, (int, float))
            size = size, size

        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))

        if ratios[0] == ratios[1]:
            # uniform scaling: every coordinate uses the same factor
            ratio = ratios[0]
            scaled_polys = [p * ratio for p in self.polygons]
            return PolygonInstance(scaled_polys, size)

        ratio_w, ratio_h = ratios
        scaled_polygons = []
        for poly in self.polygons:
            p = poly.clone()
            p[0::2] *= ratio_w
            p[1::2] *= ratio_h
            scaled_polygons.append(p)

        return PolygonInstance(scaled_polygons, size=size)

    def convert_to_binarymask(self):
        """Rasterise the polygons with pycocotools and return the merged mask tensor."""
        width, height = self.size
        # formatting for COCO PythonAPI
        polygons = [p.numpy() for p in self.polygons]
        rles = mask_utils.frPyObjects(polygons, height, width)
        rle = mask_utils.merge(rles)
        mask = mask_utils.decode(rle)
        mask = torch.from_numpy(mask)
        return mask

    def __len__(self):
        return len(self.polygons)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_groups={}, ".format(len(self.polygons))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={})".format(self.size[1])
        return s
class PolygonList(object):
    """
    This class handles PolygonInstances for all objects in the image.

    ``self.polygons`` is a list of PolygonInstance objects; instances that
    end up with no valid polygon are dropped at construction time.
    """

    def __init__(self, polygons, size):
        """
        Arguments:
            polygons:
                a list of list of lists of numbers. The first
                level of the list correspond to individual instances,
                the second level to all the polygons that compose the
                object, and the third level to the polygon coordinates.

                OR

                a list of PolygonInstances.

                OR

                a PolygonList (its own size then overrides `size`)

            size: absolute image size

        Raises:
            RuntimeError: if `polygons` is of an unsupported type.
        """
        if isinstance(polygons, (list, tuple)):
            if len(polygons) > 0:
                first = polygons[0]
                if isinstance(first, (list, tuple)):
                    # An instance without any polygon (e.g. contours of an
                    # empty mask) has nothing to type-check.
                    if len(first) > 0:
                        assert isinstance(first[0], (list, tuple)), str(
                            type(first[0])
                        )
                else:
                    assert isinstance(first, PolygonInstance), str(type(first))
        elif isinstance(polygons, PolygonList):
            size = polygons.size
            polygons = polygons.polygons
        else:
            raise RuntimeError(
                "Type of argument `polygons` is not allowed:%s" % (type(polygons))
            )

        assert isinstance(size, (list, tuple)), str(type(size))

        self.polygons = []
        for p in polygons:
            p = PolygonInstance(p, size)
            # drop instances that have no valid polygon left
            if len(p) > 0:
                self.polygons.append(p)

        self.size = tuple(size)

    def transpose(self, method):
        """
        Return a new PolygonList with every instance flipped.

        Raises:
            NotImplementedError: if `method` is neither FLIP_LEFT_RIGHT
                nor FLIP_TOP_BOTTOM.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        flipped_polygons = [polygon.transpose(method) for polygon in self.polygons]
        return PolygonList(flipped_polygons, size=self.size)

    def crop(self, box):
        """
        Return a new PolygonList cropped to `box` (xmin, ymin, xmax, ymax).

        The size of the result is the raw extent of `box`
        (xmax - xmin, ymax - ymin).
        """
        w, h = box[2] - box[0], box[3] - box[1]
        cropped_polygons = [polygon.crop(box) for polygon in self.polygons]
        cropped_size = w, h
        return PolygonList(cropped_polygons, cropped_size)

    def resize(self, size):
        """
        Return a new PolygonList with every instance scaled to `size`.

        Arguments:
            size: (width, height), or a single number used for both.
        """
        try:
            iter(size)
        except TypeError:
            # same scalar convention as PolygonInstance.resize
            size = size, size
        resized_polygons = [polygon.resize(size) for polygon in self.polygons]
        resized_size = size
        return PolygonList(resized_polygons, resized_size)

    def to(self, *args, **kwargs):
        """No-op kept for interface compatibility; returns `self` unchanged."""
        return self

    def convert_to_binarymask(self):
        """Return a BinaryMaskList with one rasterised mask per instance."""
        if len(self) > 0:
            masks = torch.stack([p.convert_to_binarymask() for p in self.polygons])
        else:
            size = self.size
            masks = torch.empty([0, size[1], size[0]], dtype=torch.uint8)

        return BinaryMaskList(masks, size=self.size)

    def __len__(self):
        return len(self.polygons)

    def __getitem__(self, item):
        """
        Select instances by int, slice, iterable of indices, or a
        uint8/bool mask tensor; always returns a PolygonList.
        """
        if isinstance(item, int):
            selected_polygons = [self.polygons[item]]
        elif isinstance(item, slice):
            selected_polygons = self.polygons[item]
        else:
            # advanced indexing on a single dimension
            if isinstance(item, torch.Tensor) and item.dtype in (
                torch.uint8,
                torch.bool,
            ):
                # mask tensor -> list of selected positions
                item = item.nonzero()
                item = item.squeeze(1) if item.numel() > 0 else item
                item = item.tolist()
            selected_polygons = [self.polygons[i] for i in item]
        return PolygonList(selected_polygons, size=self.size)

    def __iter__(self):
        return iter(self.polygons)

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self.polygons))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={})".format(self.size[1])
        return s
class SegmentationMask(object):
    """
    This class stores the segmentations for all objects in the image.
    It wraps BinaryMaskList and PolygonList conveniently.
    """

    def __init__(self, instances, size, mode="poly"):
        """
        Arguments:
            instances: two types
                (1) polygon
                (2) binary mask
            size: (width, height)
            mode: 'poly', 'mask'. if mode is 'mask', convert mask of any format to binary mask

        Raises:
            NotImplementedError: if `mode` is neither 'poly' nor 'mask'.
        """
        assert isinstance(size, (list, tuple))
        assert len(size) == 2
        if isinstance(size[0], torch.Tensor):
            assert isinstance(size[1], torch.Tensor)
            # unwrap 0-d tensors into plain Python numbers
            size = size[0].item(), size[1].item()

        assert isinstance(size[0], (int, float))
        assert isinstance(size[1], (int, float))

        if mode == "poly":
            self.instances = PolygonList(instances, size)
        elif mode == "mask":
            self.instances = BinaryMaskList(instances, size)
        else:
            raise NotImplementedError("Unknown mode: %s" % str(mode))

        self.mode = mode
        self.size = tuple(size)

    def transpose(self, method):
        """Return a new SegmentationMask flipped according to `method`."""
        flipped_instances = self.instances.transpose(method)
        return SegmentationMask(flipped_instances, self.size, self.mode)

    def crop(self, box):
        """Return a new SegmentationMask cropped to `box` (xmin, ymin, xmax, ymax)."""
        cropped_instances = self.instances.crop(box)
        cropped_size = cropped_instances.size
        return SegmentationMask(cropped_instances, cropped_size, self.mode)

    def resize(self, size, *args, **kwargs):
        """
        Return a new SegmentationMask resized to `size`.

        Extra positional/keyword arguments are accepted and ignored.
        """
        resized_instances = self.instances.resize(size)
        # Use the size the container actually produced (as `crop` does), so
        # the wrapper always agrees with the data it wraps.
        resized_size = resized_instances.size
        return SegmentationMask(resized_instances, resized_size, self.mode)

    def to(self, *args, **kwargs):
        """No-op kept for interface compatibility; returns `self` unchanged."""
        return self

    def convert(self, mode):
        """
        Return a SegmentationMask in the requested `mode` ('poly' or 'mask').

        Returns `self` when it is already in that mode.

        Raises:
            NotImplementedError: for an unknown `mode`.
        """
        if mode == self.mode:
            return self

        if mode == "poly":
            converted_instances = self.instances.convert_to_polygon()
        elif mode == "mask":
            converted_instances = self.instances.convert_to_binarymask()
        else:
            raise NotImplementedError("Unknown mode: %s" % str(mode))

        return SegmentationMask(converted_instances, self.size, mode)

    def get_mask_tensor(self):
        """Return the masks as a tensor, rasterising polygons first if needed."""
        instances = self.instances
        if self.mode == "poly":
            instances = instances.convert_to_binarymask()
        # If there is only 1 instance
        return instances.masks.squeeze(0)

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, item):
        selected_instances = self.instances.__getitem__(item)
        return SegmentationMask(selected_instances, self.size, self.mode)

    def __iter__(self):
        # NOTE: the cursor lives on the instance, so two simultaneous
        # iterations over the same object share (and disturb) one position.
        self.iter_idx = 0
        return self

    def __next__(self):
        if self.iter_idx < self.__len__():
            next_segmentation = self.__getitem__(self.iter_idx)
            self.iter_idx += 1
            return next_segmentation
        raise StopIteration()

    next = __next__  # Python 2 compatibility

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_instances={}, ".format(len(self.instances))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s