Spaces:
Runtime error
Runtime error
File size: 7,015 Bytes
3e99b05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import torch
import torch.nn as nn
from detectron2.modeling.matcher import Matcher
from detectron2.modeling.sampling import subsample_labels
from detrex.layers.box_ops import box_iou, box_cxcywh_to_xyxy
def sample_topk_per_gt(pr_inds, gt_inds, iou, k):
    """Re-select up to ``k`` highest-IoU proposals for each matched ground truth.

    For every distinct ground-truth index in ``gt_inds`` the proposals are
    ranked by the corresponding row of ``iou``.  A ground truth that occurs
    ``c`` times in ``gt_inds`` keeps its ``min(c, k)`` best-ranked proposals.

    Note that the returned proposal indices come from the IoU ranking, not
    from ``pr_inds``; ``pr_inds`` is only returned (unchanged) when
    ``gt_inds`` is empty.

    Args:
        pr_inds (Tensor): proposal indices of the current positive matches.
        gt_inds (Tensor): ground-truth index paired with each entry of
            ``pr_inds``.
        iou (Tensor): 2-D matrix indexed as ``iou[gt, proposal]``.
        k (int): maximum number of proposals kept per ground truth.  Values
            larger than the number of proposals are clamped.

    Returns:
        tuple[Tensor, Tensor]: selected proposal indices and, aligned with
        them, the ground-truth index each one is assigned to.
    """
    if len(gt_inds) == 0:
        return pr_inds, gt_inds

    # ``topk`` raises if k exceeds the size of the ranked dimension, so never
    # ask for more candidates than there are proposals.
    k = min(k, iou.shape[1])

    # Find the top-k proposals for each distinct ground truth.
    gt_unique, counts = gt_inds.unique(return_counts=True)
    _, topk_pr = iou[gt_unique].topk(k, dim=1)
    gt_tiled = gt_unique[:, None].repeat(1, k)

    # Keep only as many matches as each ground truth originally had.
    kept_pr = torch.cat([row[:c] for c, row in zip(counts, topk_pr)])
    kept_gt = torch.cat([row[:c] for c, row in zip(counts, gt_tiled)])
    return kept_pr, kept_gt
# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/roi_heads/roi_heads.py#L123
class Stage2Assigner(nn.Module):
    """IoU-based assigner that pairs second-stage proposals with ground truths.

    Proposals are read from ``outputs['init_reference']``, matched to the
    ground-truth boxes with a single IoU threshold of 0.6, subsampled, and
    finally limited to at most ``max_k`` proposals per ground truth.
    """

    def __init__(self, num_queries, max_k=4):
        """
        Args:
            num_queries (int): number of proposals sampled per image.
            max_k (int): maximum number of proposals kept per ground truth.
        """
        super().__init__()
        self.k = max_k
        self.batch_size_per_image = num_queries
        self.positive_fraction = 0.25
        # Sentinel class id for background: a number > 91 to filter out later.
        self.bg_label = 400
        self.proposal_matcher = Matcher(
            thresholds=[0.6],
            labels=[0, 1],
            allow_low_quality_matches=True,
        )

    def _sample_proposals(
        self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor
    ):
        """Label every proposal and draw a foreground/background sample.

        Args:
            matched_idxs (Tensor): length-N vector holding, for each proposal,
                the index in [0, M) of its best-matched ground truth.
            matched_labels (Tensor): length-N vector with the matcher's label
                for each proposal.
            gt_classes (Tensor): length-M vector of ground-truth classes.

        Returns:
            Tensor: indices of the sampled proposals, each in [0, N).
            Tensor: classification label of each sampled proposal; either a
                ground-truth class or ``self.bg_label`` for background.
        """
        if gt_classes.numel() == 0:
            # No ground truth at all: every proposal is background.
            proposal_classes = torch.zeros_like(matched_idxs) + self.bg_label
        else:
            # Class of the ground truth each proposal was matched to.
            proposal_classes = gt_classes[matched_idxs]
            # Matcher label 0 means unmatched, i.e. background.
            proposal_classes[matched_labels == 0] = self.bg_label
            # Matcher label -1 means ignore.
            proposal_classes[matched_labels == -1] = -1

        fg_idxs, bg_idxs = subsample_labels(
            proposal_classes,
            self.batch_size_per_image,
            self.positive_fraction,
            self.bg_label,
        )
        chosen = torch.cat([fg_idxs, bg_idxs], dim=0)
        return chosen, proposal_classes[chosen]

    def forward(self, outputs, targets, return_cost_matrix=False):
        """Compute (proposal index, ground-truth index) pairs for every image.

        Args:
            outputs (dict): must contain ``'init_reference'``, indexable per
                image, holding proposal boxes in cxcywh format.
            targets (list[dict]): one dict per image with ``'boxes'``
                (cxcywh) and ``'labels'``.
            return_cost_matrix (bool): when true, also return the per-image
                IoU matrices.

        Returns:
            list[tuple[Tensor, Tensor]]: per image, the positive proposal
            indices and the matching ground-truth indices.  When
            ``return_cost_matrix`` is true, a ``(indices, ious)`` tuple is
            returned instead.
        """
        # COCO categories are from 1 to 90. They set num_classes=91 and apply sigmoid.
        per_image_indices = []
        per_image_ious = []
        for b, target in enumerate(targets):
            gt_boxes = box_cxcywh_to_xyxy(target['boxes'])
            proposal_boxes = box_cxcywh_to_xyxy(outputs['init_reference'][b].detach())
            iou, _ = box_iou(gt_boxes, proposal_boxes)

            # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.6, 0 ow]
            matched_idxs, matched_labels = self.proposal_matcher(iou)
            # sampled proposal ids, sampled_id -> [0, num_classes) + [bg_label]
            sampled_idxs, sampled_classes = self._sample_proposals(
                matched_idxs, matched_labels, target['labels']
            )

            is_foreground = sampled_classes != self.bg_label
            fg_pr_inds = sampled_idxs[is_foreground]
            fg_gt_inds = matched_idxs[fg_pr_inds]
            fg_pr_inds, fg_gt_inds = self.postprocess_indices(fg_pr_inds, fg_gt_inds, iou)

            per_image_indices.append((fg_pr_inds, fg_gt_inds))
            per_image_ious.append(iou)

        return (per_image_indices, per_image_ious) if return_cost_matrix else per_image_indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Limit the matches to at most ``self.k`` proposals per ground truth."""
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)
# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/proposal_generator/rpn.py#L181
class Stage1Assigner(nn.Module):
    """IoU-based assigner that pairs first-stage anchors with ground truths.

    Anchors are read from ``outputs['anchors']``, matched to the ground-truth
    boxes with a low/high IoU threshold pair, subsampled to a fixed budget of
    256 anchors per image with a positive fraction of 0.5, and finally
    limited to at most ``max_k`` anchors per ground truth.
    """

    def __init__(self, t_low=0.3, t_high=0.7, max_k=4):
        """
        Args:
            t_low (float): lower IoU threshold passed to the matcher.
            t_high (float): upper IoU threshold passed to the matcher.
            max_k (int): maximum number of anchors kept per ground truth.
        """
        super().__init__()
        self.positive_fraction = 0.5
        self.batch_size_per_image = 256
        self.k = max_k
        self.t_low = t_low
        self.t_high = t_high
        self.anchor_matcher = Matcher(
            thresholds=[t_low, t_high],
            labels=[0, -1, 1],
            allow_low_quality_matches=True,
        )

    def _subsample_labels(self, label):
        """Sample positives/negatives and mark everything else as ignored.

        Args:
            label (Tensor): a vector of -1, 0, 1. Modified in-place and
                returned; entries not included in the sample are overwritten
                with the ignore value (-1).

        Returns:
            Tensor: the same ``label`` tensor after in-place modification.
        """
        pos_idx, neg_idx = subsample_labels(
            label, self.batch_size_per_image, self.positive_fraction, 0
        )
        # Fill with the ignore label (-1), then set positive and negative labels.
        label.fill_(-1)
        label.scatter_(0, pos_idx, 1)
        label.scatter_(0, neg_idx, 0)
        return label

    def forward(self, outputs, targets):
        """Compute (anchor index, ground-truth index) pairs for every image.

        Args:
            outputs (dict): must contain ``'anchors'``, indexable per image,
                holding anchor boxes in cxcywh format.
            targets (list[dict]): one dict per image with ``'boxes'`` (cxcywh).

        Returns:
            list[tuple[Tensor, Tensor]]: per image, the positive anchor
            indices and the matching ground-truth indices, both placed on the
            anchors' device.  Images without ground-truth boxes yield a pair
            of empty long tensors.
        """
        indices = []
        for b, target in enumerate(targets):
            anchors = outputs['anchors'][b]
            if len(target['boxes']) == 0:
                # Nothing to match: emit two independent empty index tensors.
                indices.append((
                    torch.tensor([], dtype=torch.long, device=anchors.device),
                    torch.tensor([], dtype=torch.long, device=anchors.device),
                ))
                continue

            iou, _ = box_iou(
                box_cxcywh_to_xyxy(target['boxes']),
                box_cxcywh_to_xyxy(anchors),
            )
            # proposal_id -> highest_iou_gt_id,
            # proposal_id -> [1 if iou > t_high, 0 if iou < t_low, -1 ow]
            matched_idxs, matched_labels = self.anchor_matcher(iou)
            matched_labels = self._subsample_labels(matched_labels)

            # Build the index range on the same device as the mask that
            # selects from it, so no cross-device indexing takes place.
            all_pr_inds = torch.arange(len(anchors), device=matched_labels.device)
            pos_pr_inds = all_pr_inds[matched_labels == 1]
            pos_gt_inds = matched_idxs[pos_pr_inds]

            pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
            pos_pr_inds = pos_pr_inds.to(anchors.device)
            pos_gt_inds = pos_gt_inds.to(anchors.device)
            indices.append((pos_pr_inds, pos_gt_inds))
        return indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Limit the matches to at most ``self.k`` anchors per ground truth."""
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)
|