3v324v23's picture
add
c310e19
raw
history blame
7.39 kB
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from maskrcnn_benchmark.layers import nms as _box_nms
from .bounding_box import BoxList
from maskrcnn_benchmark.structures.segmentation_mask import SegmentationMask
import numpy as np
import shapely
from shapely.geometry import Polygon,MultiPoint
def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="score"):
    """
    Apply non-maximum suppression to a boxlist, ranking boxes by the
    scores stored in one of its fields.

    Arguments:
        boxlist (BoxList): boxes to filter; must carry ``score_field``.
        nms_thresh (float): threshold handed to the NMS kernel. A value
            <= 0 disables suppression and the input is returned untouched.
        max_proposals (int): if > 0, only the first ``max_proposals``
            indices returned by the NMS kernel are kept.
        score_field (str): name of the field holding the per-box scores.

    Returns:
        BoxList: the surviving boxes, in the same mode as the input.
    """
    # A non-positive threshold means "no NMS": hand the input straight back.
    if nms_thresh <= 0:
        return boxlist

    # The NMS kernel is fed "xyxy" boxes; remember the caller's mode so the
    # result can be converted back afterwards.
    original_mode = boxlist.mode
    xyxy_boxlist = boxlist.convert("xyxy")

    kept_indices = _box_nms(
        xyxy_boxlist.bbox,
        xyxy_boxlist.get_field(score_field),
        nms_thresh,
    )
    # NOTE(review): truncation assumes _box_nms returns indices ordered by
    # decreasing score -- confirm against the layer implementation.
    if max_proposals > 0:
        kept_indices = kept_indices[:max_proposals]

    return xyxy_boxlist[kept_indices].convert(original_mode)
def remove_small_boxes(boxlist, min_size):
    """
    Drop every box whose width or height is below ``min_size``.

    Arguments:
        boxlist (BoxList): boxes to filter.
        min_size (int): minimum allowed width and height.

    Returns:
        BoxList: the subset of ``boxlist`` with both sides >= ``min_size``.
    """
    # TODO maybe add an API for querying the ws / hs
    # In "xywh" mode the last two columns are the widths and heights.
    _x, _y, widths, heights = boxlist.convert("xywh").bbox.unbind(dim=1)
    large_enough = (widths >= min_size) & (heights >= min_size)
    # nonzero() yields an [K, 1] index tensor; flatten it to [K] for indexing.
    kept_indices = torch.nonzero(large_enough).squeeze(1)
    return boxlist[kept_indices]
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def boxlist_iou(boxlist1, boxlist2):
    """Compute the pairwise intersection over union of two sets of boxes.

    The box order must be (xmin, ymin, xmax, ymax).

    Arguments:
        boxlist1: (BoxList) bounding boxes, sized [N,4].
        boxlist2: (BoxList) bounding boxes, sized [M,4].

    Returns:
        (tensor) iou, sized [N,M].

    Raises:
        RuntimeError: if the two boxlists report different image sizes.

    Reference:
        https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
    """
    if boxlist1.size != boxlist2.size:
        raise RuntimeError(
            "boxlists should have same image size, got {}, {}".format(
                boxlist1, boxlist2
            )
        )

    areas_a = boxlist1.area()
    areas_b = boxlist2.area()
    boxes_a = boxlist1.bbox
    boxes_b = boxlist2.bbox

    # Broadcasting [N,1,2] against [M,2] gives every pair's overlap corners.
    top_left = torch.max(boxes_a[:, None, :2], boxes_b[:, :2])  # [N,M,2]
    bottom_right = torch.min(boxes_a[:, None, 2:], boxes_b[:, 2:])  # [N,M,2]

    # Extents are computed as (max - min + 1); negative extents mean the
    # pair does not overlap and are clamped to zero.
    TO_REMOVE = 1
    overlap_wh = (bottom_right - top_left + TO_REMOVE).clamp(min=0)  # [N,M,2]

    intersection = overlap_wh[..., 0] * overlap_wh[..., 1]  # [N,M]
    union = areas_a[:, None] + areas_b - intersection
    return intersection / union
# def boxlist_polygon_iou(target, proposal):
# """Compute the intersection over union of two set of boxes.
# The box order must be (xmin, ymin, xmax, ymax).
# Arguments:
# box1: (BoxList) bounding boxes, sized [N,4].
# box2: (BoxList) bounding boxes, sized [M,4].
# Returns:
# (tensor) iou, sized [N,M].
# Reference:
# https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
# """
# if target.size != proposal.size:
# raise RuntimeError(
# "boxlists should have same image size, got {}, {}".format(
# target, proposal
# )
# )
# target_polygon = target.get_field("masks").to_np_polygon()
# proposal_polygon = proposal.get_field("masks").to_np_polygon()
# print(target_polygon)
# print(proposal_polygon)
# polygon_points1 = target_polygon[0].reshape(-1, 2)
# poly1 = Polygon(polygon_points1).convex_hull
# polygon_points2 = proposal_polygon[0].reshape(-1, 2)
# poly2 = Polygon(polygon_points2).convex_hull
# union_poly = np.concatenate((polygon_points1, polygon_points2))
# if not poly1.intersects(poly2): # this test is fast and can accelerate calculation
# iou = 0
# else:
# try:
# inter_area = poly1.intersection(poly2).area
# #union_area = poly1.area + poly2.area - inter_area
# union_area = MultiPoint(union_poly).convex_hull.area
# if union_area == 0:
# return 0
# iou = float(inter_area) / union_area
# except shapely.geos.TopologicalError:
# print('shapely.geos.TopologicalError occured, iou set to 0')
# iou = 0
# return iou
# TODO redundant, remove
def _cat(tensors, dim=0):
    """
    Concatenate a list/tuple of tensors along ``dim``.

    Unlike a plain ``torch.cat`` call, a one-element sequence is returned
    as that very element, so no copy is made.

    Arguments:
        tensors (list | tuple): tensors to join.
        dim (int): dimension along which to concatenate.
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) != 1:
        return torch.cat(tensors, dim)
    # Single element: hand back the original tensor object itself.
    (only_tensor,) = tensors
    return only_tensor
def _cat_mask(masks):
    """
    Merge several mask containers into a single SegmentationMask.

    The polygons returned by each element's ``get_polygons()`` are gathered,
    in order, into one flat list; the image size is taken from the first
    element.

    Arguments:
        masks (sequence): objects exposing ``size`` and ``get_polygons()``
            (presumably SegmentationMask instances -- confirm with callers).
    """
    image_size = masks[0].size
    merged_polygons = [
        polygon for mask in masks for polygon in mask.get_polygons()
    ]
    return SegmentationMask(merged_polygons, image_size)
def cat_boxlist(bboxes):
    """
    Join several BoxLists that share image size, mode and field names into
    one BoxList.

    Arguments:
        bboxes (list[BoxList]): boxlists to concatenate, in order.

    Returns:
        BoxList: all boxes stacked along dim 0, with every extra field
        concatenated the same way ('masks' is merged polygon-wise).
    """
    assert isinstance(bboxes, (list, tuple))
    assert all(isinstance(bbox, BoxList) for bbox in bboxes)

    # The first element is the reference that all others must agree with.
    reference = bboxes[0]

    size = reference.size
    assert all(bbox.size == size for bbox in bboxes)

    mode = reference.mode
    assert all(bbox.mode == mode for bbox in bboxes)

    fields = set(reference.fields())
    assert all(set(bbox.fields()) == fields for bbox in bboxes)

    merged = BoxList(_cat([bbox.bbox for bbox in bboxes], dim=0), size, mode)
    for field in fields:
        per_box_values = [bbox.get_field(field) for bbox in bboxes]
        # Masks are not tensors here, so they take the polygon-merging path.
        if field == 'masks':
            merged.add_field(field, _cat_mask(per_box_values))
        else:
            merged.add_field(field, _cat(per_box_values, dim=0))
    return merged
def cat_boxlist_gt(bboxes):
    """
    Concatenates a list of BoxList (having the same image size) into a
    single BoxList, dropping a leading all-zero placeholder BoxList.

    If the coordinates of the first BoxList sum to zero it is treated as an
    empty placeholder and left out of the result; boxes and every field are
    then built from the remaining BoxLists. When the placeholder is the only
    element there is nothing to replace it with, so it is kept as-is.

    Arguments:
        bboxes (list[BoxList]): boxlists sharing size, mode and field names.

    Returns:
        BoxList: the concatenated boxes with all fields concatenated
        consistently ('masks' is merged polygon-wise).
    """
    assert isinstance(bboxes, (list, tuple))
    assert all(isinstance(bbox, BoxList) for bbox in bboxes)
    size = bboxes[0].size
    # bboxes[1].set_size(size)
    assert all(bbox.size == size for bbox in bboxes)
    mode = bboxes[0].mode
    assert all(bbox.mode == mode for bbox in bboxes)
    fields = set(bboxes[0].fields())
    assert all(set(bbox.fields()) == fields for bbox in bboxes)

    # Decide once whether the first BoxList is a placeholder: each .item()
    # call pulls a scalar back into Python, so it should not run per field.
    # NOTE(review): a zero coordinate sum is a heuristic for "placeholder" --
    # confirm no real box set can sum to exactly zero.
    skip_first = len(bboxes) > 1 and bboxes[0].bbox.sum().item() == 0

    # Use the SAME source list for boxes and for fields so the number of
    # boxes always matches the length of every field.
    sources = bboxes[1:] if skip_first else bboxes

    cat_boxes = BoxList(_cat([bbox.bbox for bbox in sources], dim=0), size, mode)
    for field in fields:
        values = [bbox.get_field(field) for bbox in sources]
        if field == 'masks':
            data = _cat_mask(values)
        else:
            data = _cat(values, dim=0)
        cat_boxes.add_field(field, data)
    return cat_boxes