# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import torch

from detectron2.layers import cat


def get_point_coords_from_point_annotation(instances):
    """
    Load point coords and their corresponding labels from point annotation.

    Args:
        instances (list[Instances]): A list of N Instances, where N is the number of images
            in the batch. These instances are in 1:1 correspondence with the pred_mask_logits.
            The ground-truth labels (class, box, mask, ...) associated with each instance are
            stored in fields.
    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) that contains the coordinates of P
            sampled points.
        point_labels (Tensor): A tensor of shape (N, P) that contains the labels of P
            sampled points. `point_labels` takes 3 possible values:
            - 0: the point belongs to background
            - 1: the point belongs to the object
            - -1: the point is ignored during training
    """
    point_coords_list = []
    point_labels_list = []
    for instances_per_image in instances:
        if len(instances_per_image) == 0:
            continue
        point_coords = instances_per_image.gt_point_coords.to(torch.float32)
        point_labels = instances_per_image.gt_point_labels.to(torch.float32).clone()
        proposal_boxes_per_image = instances_per_image.proposal_boxes.tensor

        # Convert point coordinate system; ground-truth points are in image coordinates.
        point_coords_wrt_box = get_point_coords_wrt_box(proposal_boxes_per_image, point_coords)

        # Ignore points that fall outside the predicted boxes.
        point_ignores = (
            (point_coords_wrt_box[:, :, 0] < 0)
            | (point_coords_wrt_box[:, :, 0] > 1)
            | (point_coords_wrt_box[:, :, 1] < 0)
            | (point_coords_wrt_box[:, :, 1] > 1)
        )
        point_labels[point_ignores] = -1

        point_coords_list.append(point_coords_wrt_box)
        point_labels_list.append(point_labels)

    return (
        cat(point_coords_list, dim=0),
        cat(point_labels_list, dim=0),
    )


def get_point_coords_wrt_box(boxes_coords, point_coords):
    """
    Convert image-level absolute coordinates to box-normalized [0, 1] x [0, 1] point coordinates.

    Args:
        boxes_coords (Tensor): A tensor of shape (R, 4) that contains bounding box
            coordinates in XYXY format.
        point_coords (Tensor): A tensor of shape (R, P, 2) that contains image-level
            absolute coordinates of P sampled points.
    Returns:
        point_coords_wrt_box (Tensor): A tensor of shape (R, P, 2) that contains
            [0, 1] x [0, 1] box-normalized coordinates of the P sampled points.
    """
    with torch.no_grad():
        point_coords_wrt_box = point_coords.clone()
        point_coords_wrt_box[:, :, 0] -= boxes_coords[:, None, 0]
        point_coords_wrt_box[:, :, 1] -= boxes_coords[:, None, 1]
        point_coords_wrt_box[:, :, 0] = point_coords_wrt_box[:, :, 0] / (
            boxes_coords[:, None, 2] - boxes_coords[:, None, 0]
        )
        point_coords_wrt_box[:, :, 1] = point_coords_wrt_box[:, :, 1] / (
            boxes_coords[:, None, 3] - boxes_coords[:, None, 1]
        )
    return point_coords_wrt_box
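

if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the original library code; the box and
    # point values below are illustrative assumptions. It shows how
    # get_point_coords_wrt_box maps absolute image coordinates into the [0, 1] x [0, 1]
    # frame of each box: a 100x50 box anchored at (10, 20) sends its top-left corner to
    # (0, 0), its bottom-right corner to (1, 1), and its center to (0.5, 0.5).
    boxes = torch.tensor([[10.0, 20.0, 110.0, 70.0]])  # (R=1, 4) XYXY image coords
    points = torch.tensor([[[10.0, 20.0], [110.0, 70.0], [60.0, 45.0]]])  # (R=1, P=3, 2)
    print(get_point_coords_wrt_box(boxes, points))
    # Expected output: tensor([[[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]]])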