Spaces:
Starting
on
A10G
Starting
on
A10G
from typing import Optional, Union | |
import torch | |
from mmdet.models.task_modules.assigners.match_cost import BaseMatchCost | |
from mmengine.structures import InstanceData | |
from torch import Tensor | |
from mmdet.registry import TASK_UTILS | |
class FlexibleClassificationCost(BaseMatchCost): | |
def __init__(self, weight: Union[float, int] = 1) -> None: | |
super().__init__(weight=weight) | |
def __call__(self, | |
pred_instances: InstanceData, | |
gt_instances: InstanceData, | |
img_meta: Optional[dict] = None, | |
**kwargs) -> Tensor: | |
"""Compute match cost. | |
Args: | |
pred_instances (:obj:`InstanceData`): ``scores`` inside is | |
predicted classification logits, of shape | |
(num_queries, num_class). | |
gt_instances (:obj:`InstanceData`): ``labels`` inside should have | |
shape (num_gt, ). | |
img_meta (Optional[dict]): _description_. Defaults to None. | |
Returns: | |
Tensor: Match Cost matrix of shape (num_preds, num_gts). | |
""" | |
_pred_scores = pred_instances.scores | |
gt_labels = gt_instances.labels | |
pred_scores = _pred_scores[..., :-1] | |
iou_score = _pred_scores[..., -1:] | |
pred_scores = pred_scores.softmax(-1) | |
iou_score = iou_score.sigmoid() | |
pred_scores = torch.cat([pred_scores, iou_score], dim=-1) | |
cls_cost = -pred_scores[:, gt_labels] | |
return cls_cost * self.weight | |