File size: 1,531 Bytes
b34d1d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import Optional, Union

import torch
from mmdet.models.task_modules.assigners.match_cost import BaseMatchCost
from mmengine.structures import InstanceData
from torch import Tensor

from mmdet.registry import TASK_UTILS


@TASK_UTILS.register_module()
class FlexibleClassificationCost(BaseMatchCost):
    """Classification match cost for logits with one extra trailing channel.

    All channels except the last are normalised jointly with a softmax, the
    last channel is squashed independently with a sigmoid, and the negated
    activation found at each ground-truth label index is used as the cost.

    Args:
        weight (Union[float, int]): Factor the cost matrix is multiplied by.
            Forwarded to ``BaseMatchCost``. Defaults to 1.
    """

    def __init__(self, weight: Union[float, int] = 1) -> None:
        super().__init__(weight=weight)

    def __call__(self,
                 pred_instances: InstanceData,
                 gt_instances: InstanceData,
                 img_meta: Optional[dict] = None,
                 **kwargs) -> Tensor:
        """Compute the weighted classification match cost.

        Args:
            pred_instances (:obj:`InstanceData`): ``scores`` inside is
                predicted classification logits, of shape
                (num_queries, num_class). The last channel is activated
                separately from the others (presumably an IoU/quality
                logit -- confirm against the head producing it).
            gt_instances (:obj:`InstanceData`): ``labels`` inside should have
                shape (num_gt, ) and is used to index the channel axis.
            img_meta (Optional[dict]): Image meta information. Not read by
                this cost. Defaults to None.

        Returns:
            Tensor: Match Cost matrix of shape (num_preds, num_gts).
        """
        logits = pred_instances.scores
        target_labels = gt_instances.labels

        # Leading channels compete with each other; the trailing channel is
        # an independent score, so it gets its own activation.
        class_probs = torch.softmax(logits[..., :-1], dim=-1)
        extra_prob = torch.sigmoid(logits[..., -1:])
        activations = torch.cat((class_probs, extra_prob), dim=-1)

        # Higher activation at the ground-truth label means lower cost.
        negated = -activations[:, target_labels]
        return negated * self.weight