# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import math import torch class BoxCoder(object): """ This class encodes and decodes a set of bounding boxes into the representation used for training the regressors. """ def __init__(self, weights, bbox_xform_clip=math.log(1000.0 / 16)): """ Arguments: weights (4-element tuple) bbox_xform_clip (float) """ self.weights = weights self.bbox_xform_clip = bbox_xform_clip def encode(self, reference_boxes, proposals): """ Encode a set of proposals with respect to some reference boxes Arguments: reference_boxes (Tensor): reference boxes proposals (Tensor): boxes to be encoded """ TO_REMOVE = 1 # TODO remove ex_widths = proposals[:, 2] - proposals[:, 0] + TO_REMOVE ex_heights = proposals[:, 3] - proposals[:, 1] + TO_REMOVE ex_ctr_x = proposals[:, 0] + 0.5 * ex_widths ex_ctr_y = proposals[:, 1] + 0.5 * ex_heights gt_widths = reference_boxes[:, 2] - reference_boxes[:, 0] + TO_REMOVE gt_heights = reference_boxes[:, 3] - reference_boxes[:, 1] + TO_REMOVE gt_ctr_x = reference_boxes[:, 0] + 0.5 * gt_widths gt_ctr_y = reference_boxes[:, 1] + 0.5 * gt_heights wx, wy, ww, wh = self.weights targets_dx = wx * (gt_ctr_x - ex_ctr_x) / ex_widths targets_dy = wy * (gt_ctr_y - ex_ctr_y) / ex_heights targets_dw = ww * torch.log(gt_widths / ex_widths) targets_dh = wh * torch.log(gt_heights / ex_heights) targets = torch.stack((targets_dx, targets_dy, targets_dw, targets_dh), dim=1) return targets def decode(self, rel_codes, boxes): """ From a set of original boxes and encoded relative box offsets, get the decoded boxes. Arguments: rel_codes (Tensor): encoded boxes boxes (Tensor): reference boxes. """ boxes = boxes.to(rel_codes.dtype) TO_REMOVE = 1 # TODO remove widths = boxes[:, 2] - boxes[:, 0] + TO_REMOVE heights = boxes[:, 3] - boxes[:, 1] + TO_REMOVE ctr_x = boxes[:, 0] + 0.5 * widths ctr_y = boxes[:, 1] + 0.5 * heights wx, wy, ww, wh = self.weights dx = rel_codes[:, 0::4] / wx dy = rel_codes[:, 1::4] / wy dw = rel_codes[:, 2::4] / ww dh = rel_codes[:, 3::4] / wh # Prevent sending too large values into torch.exp() dw = torch.clamp(dw, max=self.bbox_xform_clip) dh = torch.clamp(dh, max=self.bbox_xform_clip) pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] pred_w = torch.exp(dw) * widths[:, None] pred_h = torch.exp(dh) * heights[:, None] pred_boxes = torch.zeros_like(rel_codes) # x1 pred_boxes[:, 0::4] = pred_ctr_x - 0.5 * pred_w # y1 pred_boxes[:, 1::4] = pred_ctr_y - 0.5 * pred_h # x2 (note: "- 1" is correct; don't be fooled by the asymmetry) pred_boxes[:, 2::4] = pred_ctr_x + 0.5 * pred_w - 1 # y2 (note: "- 1" is correct; don't be fooled by the asymmetry) pred_boxes[:, 3::4] = pred_ctr_y + 0.5 * pred_h - 1 return pred_boxes