YOLO-World4 / yolo_world /models /assigner /task_aligned_assigner.py
csmithxc's picture
Upload 146 files
1530901 verified
raw
history blame
No virus
4.38 kB
# Copyright (c) Tencent Inc. All rights reserved.
import torch
from torch import Tensor
from mmyolo.registry import TASK_UTILS
from mmyolo.models.task_modules.assigners import BatchTaskAlignedAssigner
from mmyolo.models.task_modules.assigners.utils import select_highest_overlaps
@TASK_UTILS.register_module()
class YOLOWorldSegAssigner(BatchTaskAlignedAssigner):
def __init__(self,
num_classes: int,
topk: int = 13,
alpha: float = 1,
beta: float = 6,
eps: float = 1e-7,
use_ciou: bool = False):
super().__init__(num_classes, topk, alpha, beta, eps, use_ciou)
@torch.no_grad()
def forward(
self,
pred_bboxes: Tensor,
pred_scores: Tensor,
priors: Tensor,
gt_labels: Tensor,
gt_bboxes: Tensor,
pad_bbox_flag: Tensor,
) -> dict:
"""Assign gt to bboxes.
The assignment is done in following steps
1. compute alignment metric between all bbox (bbox of all pyramid
levels) and gt
2. select top-k bbox as candidates for each gt
3. limit the positive sample's center in gt (because the anchor-free
detector only can predict positive distance)
Args:
pred_bboxes (Tensor): Predict bboxes,
shape(batch_size, num_priors, 4)
pred_scores (Tensor): Scores of predict bboxes,
shape(batch_size, num_priors, num_classes)
priors (Tensor): Model priors, shape (num_priors, 4)
gt_labels (Tensor): Ground true labels,
shape(batch_size, num_gt, 1)
gt_bboxes (Tensor): Ground true bboxes,
shape(batch_size, num_gt, 4)
pad_bbox_flag (Tensor): Ground truth bbox mask,
1 means bbox, 0 means no bbox,
shape(batch_size, num_gt, 1)
Returns:
assigned_result (dict) Assigned result:
assigned_labels (Tensor): Assigned labels,
shape(batch_size, num_priors)
assigned_bboxes (Tensor): Assigned boxes,
shape(batch_size, num_priors, 4)
assigned_scores (Tensor): Assigned scores,
shape(batch_size, num_priors, num_classes)
fg_mask_pre_prior (Tensor): Force ground truth matching mask,
shape(batch_size, num_priors)
"""
# (num_priors, 4) -> (num_priors, 2)
priors = priors[:, :2]
batch_size = pred_scores.size(0)
num_gt = gt_bboxes.size(1)
assigned_result = {
'assigned_labels':
gt_bboxes.new_full(pred_scores[..., 0].shape, self.num_classes),
'assigned_bboxes':
gt_bboxes.new_full(pred_bboxes.shape, 0),
'assigned_scores':
gt_bboxes.new_full(pred_scores.shape, 0),
'fg_mask_pre_prior':
gt_bboxes.new_full(pred_scores[..., 0].shape, 0)
}
if num_gt == 0:
return assigned_result
pos_mask, alignment_metrics, overlaps = self.get_pos_mask(
pred_bboxes, pred_scores, priors, gt_labels, gt_bboxes,
pad_bbox_flag, batch_size, num_gt)
(assigned_gt_idxs, fg_mask_pre_prior,
pos_mask) = select_highest_overlaps(pos_mask, overlaps, num_gt)
# assigned target
assigned_labels, assigned_bboxes, assigned_scores = self.get_targets(
gt_labels, gt_bboxes, assigned_gt_idxs, fg_mask_pre_prior,
batch_size, num_gt)
# normalize
alignment_metrics *= pos_mask
pos_align_metrics = alignment_metrics.max(axis=-1, keepdim=True)[0]
pos_overlaps = (overlaps * pos_mask).max(axis=-1, keepdim=True)[0]
norm_align_metric = (
alignment_metrics * pos_overlaps /
(pos_align_metrics + self.eps)).max(-2)[0].unsqueeze(-1)
assigned_scores = assigned_scores * norm_align_metric
assigned_result['assigned_labels'] = assigned_labels
assigned_result['assigned_bboxes'] = assigned_bboxes
assigned_result['assigned_scores'] = assigned_scores
assigned_result['fg_mask_pre_prior'] = fg_mask_pre_prior.bool()
assigned_result['assigned_gt_idxs'] = assigned_gt_idxs
return assigned_result