# NOTE: "Spaces: Build error" was a hosting-page status banner captured
# together with this file; it is not part of the module.
# Copyright (c) Tencent Inc. All rights reserved.
import torch
from torch import Tensor
from mmyolo.registry import TASK_UTILS
from mmyolo.models.task_modules.assigners import BatchTaskAlignedAssigner
from mmyolo.models.task_modules.assigners.utils import select_highest_overlaps
# NOTE(review): ``TASK_UTILS`` is imported by this file but this class is not
# registered with it in the code visible here -- confirm that registration
# happens elsewhere if the assigner is meant to be built from a config.
class YOLOWorldSegAssigner(BatchTaskAlignedAssigner):
    """Batched task-aligned assigner that also reports the matched gt index.

    Reuses ``get_pos_mask`` and ``get_targets`` inherited from
    ``BatchTaskAlignedAssigner`` and additionally returns, under the
    ``'assigned_gt_idxs'`` key, the per-prior gt index produced by
    ``select_highest_overlaps``.
    """

    def __init__(self,
                 num_classes: int,
                 topk: int = 13,
                 alpha: float = 1,
                 beta: float = 6,
                 eps: float = 1e-7,
                 use_ciou: bool = False):
        """Forward every argument unchanged to the parent assigner.

        Args:
            num_classes (int): Number of classes; ``forward`` uses it as the
                fill value for labels when there are no gt boxes.
            topk (int): Passed to ``BatchTaskAlignedAssigner``.
            alpha (float): Passed to ``BatchTaskAlignedAssigner``.
            beta (float): Passed to ``BatchTaskAlignedAssigner``.
            eps (float): Passed to ``BatchTaskAlignedAssigner``; ``forward``
                adds ``self.eps`` to the normalisation denominator.
            use_ciou (bool): Passed to ``BatchTaskAlignedAssigner``.
        """
        super().__init__(num_classes, topk, alpha, beta, eps, use_ciou)

    # Label assignment only produces training targets, so no autograd graph
    # should be recorded for it.
    @torch.no_grad()
    def forward(
        self,
        pred_bboxes: Tensor,
        pred_scores: Tensor,
        priors: Tensor,
        gt_labels: Tensor,
        gt_bboxes: Tensor,
        pad_bbox_flag: Tensor,
    ) -> dict:
        """Assign gt to bboxes.

        The assignment is done in following steps

        1. compute alignment metric between all bbox (bbox of all pyramid
           levels) and gt
        2. select top-k bbox as candidates for each gt
        3. limit the positive sample's center in gt (because the anchor-free
           detector only can predict positive distance)

        Args:
            pred_bboxes (Tensor): Predicted bboxes,
                shape(batch_size, num_priors, 4)
            pred_scores (Tensor): Scores of predicted bboxes,
                shape(batch_size, num_priors, num_classes)
            priors (Tensor): Model priors, shape (num_priors, 4)
            gt_labels (Tensor): Ground truth labels,
                shape(batch_size, num_gt, 1)
            gt_bboxes (Tensor): Ground truth bboxes,
                shape(batch_size, num_gt, 4)
            pad_bbox_flag (Tensor): Ground truth bbox mask,
                1 means bbox, 0 means no bbox,
                shape(batch_size, num_gt, 1)

        Returns:
            assigned_result (dict) Assigned result:
                assigned_labels (Tensor): Assigned labels,
                    shape(batch_size, num_priors)
                assigned_bboxes (Tensor): Assigned boxes,
                    shape(batch_size, num_priors, 4)
                assigned_scores (Tensor): Assigned scores,
                    shape(batch_size, num_priors, num_classes)
                fg_mask_pre_prior (Tensor): Force ground truth matching mask,
                    shape(batch_size, num_priors)
                assigned_gt_idxs (Tensor): Per-prior gt index returned by
                    ``select_highest_overlaps`` (presumably
                    shape(batch_size, num_priors) -- confirm against that
                    helper).

            When ``num_gt == 0`` every entry is a placeholder created from
            ``gt_bboxes``: labels are filled with ``self.num_classes``,
            boxes / scores / ``fg_mask_pre_prior`` are zeros in the dtype of
            ``gt_bboxes`` (the mask is NOT bool on this path), and
            ``assigned_gt_idxs`` is an int64 zero tensor with one entry per
            prior.
        """
        # (num_priors, 4) -> (num_priors, 2)
        priors = priors[:, :2]

        batch_size = pred_scores.size(0)
        num_gt = gt_bboxes.size(1)

        if num_gt == 0:
            # Nothing to match. Return the same set of keys as the regular
            # path so callers can read 'assigned_gt_idxs' unconditionally.
            per_prior_shape = pred_scores.shape[:-1]
            return {
                'assigned_labels':
                gt_bboxes.new_full(per_prior_shape, self.num_classes),
                'assigned_bboxes':
                gt_bboxes.new_full(pred_bboxes.shape, 0),
                'assigned_scores':
                gt_bboxes.new_full(pred_scores.shape, 0),
                'fg_mask_pre_prior':
                gt_bboxes.new_full(per_prior_shape, 0),
                'assigned_gt_idxs':
                gt_bboxes.new_full(per_prior_shape, 0, dtype=torch.long),
            }

        pos_mask, alignment_metrics, overlaps = self.get_pos_mask(
            pred_bboxes, pred_scores, priors, gt_labels, gt_bboxes,
            pad_bbox_flag, batch_size, num_gt)

        (assigned_gt_idxs, fg_mask_pre_prior,
         pos_mask) = select_highest_overlaps(pos_mask, overlaps, num_gt)

        # assigned target
        assigned_labels, assigned_bboxes, assigned_scores = self.get_targets(
            gt_labels, gt_bboxes, assigned_gt_idxs, fg_mask_pre_prior,
            batch_size, num_gt)

        # normalize: rescale each alignment metric by the ratio of the
        # largest positive overlap to the largest positive metric along the
        # last axis, then reduce over axis -2 to get one factor per prior.
        alignment_metrics *= pos_mask
        pos_align_metrics = alignment_metrics.max(dim=-1, keepdim=True)[0]
        pos_overlaps = (overlaps * pos_mask).max(dim=-1, keepdim=True)[0]
        norm_align_metric = (
            alignment_metrics * pos_overlaps /
            (pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1)
        assigned_scores = assigned_scores * norm_align_metric

        return {
            'assigned_labels': assigned_labels,
            'assigned_bboxes': assigned_bboxes,
            'assigned_scores': assigned_scores,
            'fg_mask_pre_prior': fg_mask_pre_prior.bool(),
            'assigned_gt_idxs': assigned_gt_idxs,
        }