HarborYuan's picture
add omg code
b34d1d6
raw
history blame
No virus
1.53 kB
from typing import Optional, Union
import torch
from mmdet.models.task_modules.assigners.match_cost import BaseMatchCost
from mmengine.structures import InstanceData
from torch import Tensor
from mmdet.registry import TASK_UTILS
@TASK_UTILS.register_module()
class FlexibleClassificationCost(BaseMatchCost):
    """Classification match cost with a separately activated last channel.

    All channels of the predicted scores except the last one are normalised
    jointly with a softmax; the last channel is squashed on its own with a
    sigmoid. The two parts are concatenated back together and the cost of
    matching a prediction to a ground truth is the negated value found at
    the ground-truth label index.

    Args:
        weight (Union[float, int]): Factor the cost matrix is multiplied
            by. It is forwarded to ``BaseMatchCost.__init__``. Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1) -> None:
        super().__init__(weight=weight)

    def __call__(self,
                 pred_instances: InstanceData,
                 gt_instances: InstanceData,
                 img_meta: Optional[dict] = None,
                 **kwargs) -> Tensor:
        """Compute the classification match cost.

        Args:
            pred_instances (:obj:`InstanceData`): ``scores`` inside holds the
                predicted classification logits, of shape
                (num_queries, num_class). The last channel is treated
                differently from the others (sigmoid instead of softmax).
                NOTE(review): presumably an IoU-like score - confirm
                against the head that produces it.
            gt_instances (:obj:`InstanceData`): ``labels`` inside should have
                shape (num_gt, ). They are used as column indices into the
                activated score matrix.
            img_meta (Optional[dict]): Accepted for interface compatibility;
                not read by this cost. Defaults to None.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor: Match cost matrix of shape (num_preds, num_gts).
        """
        logits = pred_instances.scores
        target_labels = gt_instances.labels

        # Softmax over every channel but the last; the last channel gets
        # its own independent sigmoid activation.
        class_probs = torch.softmax(logits[..., :-1], dim=-1)
        last_prob = torch.sigmoid(logits[..., -1:])

        # Restore the original channel layout so label indices line up.
        probs = torch.cat((class_probs, last_prob), dim=-1)

        # A higher probability for the GT label means a lower (more
        # negative) cost.
        return -probs[:, target_labels] * self.weight