Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from detectron2.modeling.matcher import Matcher | |
from detectron2.modeling.sampling import subsample_labels | |
from detrex.layers.box_ops import box_iou, box_cxcywh_to_xyxy | |
def sample_topk_per_gt(pr_inds, gt_inds, iou, k):
    """Re-pick the positives of every matched ground truth by IoU rank.

    For each distinct ground-truth index occurring in ``gt_inds`` the number
    of occurrences ``c`` is counted, and the ``min(c, k)`` proposals with the
    highest IoU against that ground truth are returned. Apart from the
    empty-input shortcut, ``pr_inds`` itself is not consulted: only the
    per-ground-truth match counts carried by ``gt_inds`` matter.

    Args:
        pr_inds (Tensor): 1-D proposal indices of the current positive matches.
        gt_inds (Tensor): 1-D ground-truth indices, aligned with ``pr_inds``.
        iou (Tensor): 2-D IoU matrix indexed as ``iou[gt, proposal]``.
        k (int): upper bound on the number of positives kept per ground truth.
            Values larger than the number of proposals are clamped.

    Returns:
        tuple[Tensor, Tensor]: the selected proposal indices and the
        ground-truth index paired with each of them. Entries are grouped by
        ground truth (in the order produced by ``Tensor.unique``) and, inside
        a group, ordered by decreasing IoU.
    """
    if len(gt_inds) == 0:
        return pr_inds, gt_inds

    # ``topk`` raises when k exceeds the size of the reduced dimension, so
    # never ask for more candidates than there are proposals.
    k = min(k, iou.shape[1])

    # Candidate pool: the k best proposals of every ground truth that has at
    # least one match, plus how many matches that ground truth currently has.
    uniq_gt, counts = gt_inds.unique(return_counts=True)
    _, topk_pr = iou[uniq_gt].topk(k, dim=1)
    topk_gt = uniq_gt[:, None].repeat(1, k)

    # Keep only the first ``counts[i]`` candidates of row i. Boolean-mask
    # indexing walks the matrix row by row, which yields the same order as
    # concatenating the per-row prefixes.
    rank = torch.arange(k, device=counts.device)
    keep = rank[None, :] < counts[:, None]
    return topk_pr[keep.to(topk_pr.device)], topk_gt[keep.to(topk_gt.device)]
# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/roi_heads/roi_heads.py#L123 | |
class Stage2Assigner(nn.Module):
    """Assign ground-truth boxes to second-stage queries via IoU matching.

    For every image the IoU between the ground-truth boxes and the detached
    ``init_reference`` boxes is computed, proposals are matched with a
    detectron2 ``Matcher`` (single threshold 0.6, low-quality matches
    allowed), a foreground/background subset is sampled, and the foreground
    part is finally re-picked per ground truth by ``sample_topk_per_gt``.
    """

    def __init__(self, num_queries, max_k=4):
        """
        Args:
            num_queries (int): used as the per-image sampling budget.
            max_k (int): maximum number of positives kept per ground truth.
        """
        super().__init__()
        self.k = max_k
        self.batch_size_per_image = num_queries
        self.positive_fraction = 0.25
        self.bg_label = 400  # number > 91 to filter out later
        self.proposal_matcher = Matcher(
            thresholds=[0.6],
            labels=[0, 1],
            allow_low_quality_matches=True,
        )

    def _sample_proposals(
        self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor
    ):
        """Label every proposal, then sample a foreground/background subset.

        Args:
            matched_idxs (Tensor): length-N vector holding, per proposal, the
                index of its matched ground truth.
            matched_labels (Tensor): length-N vector with the matcher's label
                for each proposal (0 is treated as background, -1 as ignore).
            gt_classes (Tensor): length-M vector of ground-truth classes.

        Returns:
            Tensor: indices of the sampled proposals (foreground first, then
            background).
            Tensor: the label of every sampled proposal; background entries
            carry ``self.bg_label``.
        """
        if gt_classes.numel() == 0:
            # No ground truth in this image: every proposal is background.
            proposal_labels = torch.zeros_like(matched_idxs) + self.bg_label
        else:
            # Class of the ground truth each proposal was matched to.
            proposal_labels = gt_classes[matched_idxs]
            # Matcher label 0 means unmatched, i.e. background.
            proposal_labels[matched_labels == 0] = self.bg_label
            # Matcher label -1 marks proposals to be ignored.
            proposal_labels[matched_labels == -1] = -1

        fg_idxs, bg_idxs = subsample_labels(
            proposal_labels,
            self.batch_size_per_image,
            self.positive_fraction,
            self.bg_label,
        )
        chosen = torch.cat([fg_idxs, bg_idxs], dim=0)
        return chosen, proposal_labels[chosen]

    def forward(self, outputs, targets, return_cost_matrix=False):
        """Compute the positive (proposal, ground-truth) index pairs per image.

        Args:
            outputs (dict): must provide ``'init_reference'``; entry ``b`` holds
                the proposal boxes of image ``b`` (converted here from cxcywh).
            targets (sequence of dict): one dict per image with ``'boxes'``
                (converted here from cxcywh) and ``'labels'``.
            return_cost_matrix (bool): when true, also return the IoU matrices.

        Returns:
            list of ``(proposal_indices, gt_indices)`` tuples, one per image;
            if ``return_cost_matrix`` is set, a pair ``(indices, ious)`` where
            ``ious`` lists the per-image IoU matrices.
        """
        # COCO categories are from 1 to 90. They set num_classes=91 and apply sigmoid.
        indices, ious = [], []
        for b, target in enumerate(targets):
            gt_boxes = box_cxcywh_to_xyxy(target['boxes'])
            proposal_boxes = box_cxcywh_to_xyxy(outputs['init_reference'][b].detach())
            iou, _ = box_iou(gt_boxes, proposal_boxes)

            # proposal_id -> highest_iou_gt_id, proposal_id -> [1 if iou > 0.6, 0 ow]
            matched_idxs, matched_labels = self.proposal_matcher(iou)
            # sampled proposal ids, and for each of them a class or bg_label
            sampled_idxs, sampled_classes = self._sample_proposals(
                matched_idxs, matched_labels, target['labels']
            )

            is_foreground = sampled_classes != self.bg_label
            fg_proposals = sampled_idxs[is_foreground]
            fg_gts = matched_idxs[fg_proposals]
            fg_proposals, fg_gts = self.postprocess_indices(fg_proposals, fg_gts, iou)

            indices.append((fg_proposals, fg_gts))
            ious.append(iou)

        return (indices, ious) if return_cost_matrix else indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Re-pick up to ``self.k`` positives per ground truth by IoU rank."""
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)
# modified from https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/modeling/proposal_generator/rpn.py#L181 | |
class Stage1Assigner(nn.Module):
    """Assign ground-truth boxes to first-stage anchors via IoU matching.

    Anchors are matched with a detectron2 ``Matcher`` using two thresholds
    (labels 0 / -1 / 1 for below ``t_low`` / in between / above ``t_high``,
    low-quality matches allowed), the labels are subsampled, and the
    surviving positives are re-picked per ground truth by
    ``sample_topk_per_gt``.
    """

    def __init__(self, t_low=0.3, t_high=0.7, max_k=4):
        """
        Args:
            t_low (float): lower IoU threshold handed to the matcher.
            t_high (float): upper IoU threshold handed to the matcher.
            max_k (int): maximum number of positives kept per ground truth.
        """
        super().__init__()
        self.positive_fraction = 0.5
        self.batch_size_per_image = 256
        self.k = max_k
        self.t_low = t_low
        self.t_high = t_high
        self.anchor_matcher = Matcher(
            thresholds=[t_low, t_high],
            labels=[0, -1, 1],
            allow_low_quality_matches=True,
        )

    def _subsample_labels(self, label):
        """Sample positives/negatives and mark everything else as ignored.

        Args:
            label (Tensor): a vector of -1, 0, 1. It is modified in place and
                returned: sampled positives are set to 1, sampled negatives
                to 0, and every other element to the ignore value -1.

        Returns:
            Tensor: the same ``label`` tensor after the in-place update.
        """
        pos_idx, neg_idx = subsample_labels(
            label, self.batch_size_per_image, self.positive_fraction, 0
        )
        # Fill with the ignore label (-1), then set positive and negative labels
        label.fill_(-1)
        label.scatter_(0, pos_idx, 1)
        label.scatter_(0, neg_idx, 0)
        return label

    def forward(self, outputs, targets):
        """Compute the positive (anchor, ground-truth) index pairs per image.

        Args:
            outputs (dict): must provide ``'anchors'``; entry ``b`` holds the
                anchor boxes of image ``b`` (converted here from cxcywh).
            targets (sequence of dict): one dict per image with ``'boxes'``
                (converted here from cxcywh).

        Returns:
            list of ``(anchor_indices, gt_indices)`` tuples, one per image,
            both placed on the device of that image's anchors. Images without
            ground-truth boxes contribute a pair of empty long tensors.
        """
        indices = []
        for b, target in enumerate(targets):
            anchors = outputs['anchors'][b]
            if len(target['boxes']) == 0:
                # Nothing to match against: emit an empty assignment.
                indices.append((
                    torch.tensor([], dtype=torch.long, device=anchors.device),
                    torch.tensor([], dtype=torch.long, device=anchors.device),
                ))
                continue

            iou, _ = box_iou(
                box_cxcywh_to_xyxy(target['boxes']),
                box_cxcywh_to_xyxy(anchors),
            )
            # proposal_id -> highest_iou_gt_id,
            # proposal_id -> [1 if iou > t_high, 0 if iou < t_low, -1 ow]
            matched_idxs, matched_labels = self.anchor_matcher(iou)
            matched_labels = self._subsample_labels(matched_labels)

            # Build the index range on the same device as the label mask so
            # the mask lookup and the gather below never mix devices.
            all_pr_inds = torch.arange(len(anchors), device=matched_labels.device)
            pos_pr_inds = all_pr_inds[matched_labels == 1]
            pos_gt_inds = matched_idxs[pos_pr_inds]

            pos_pr_inds, pos_gt_inds = self.postprocess_indices(pos_pr_inds, pos_gt_inds, iou)
            indices.append((pos_pr_inds.to(anchors.device), pos_gt_inds.to(anchors.device)))
        return indices

    def postprocess_indices(self, pr_inds, gt_inds, iou):
        """Re-pick up to ``self.k`` positives per ground truth by IoU rank."""
        return sample_topk_per_gt(pr_inds, gt_inds, iou, self.k)