"""
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
Modules to compute the matching cost and solve the corresponding LSAP.
Copyright (c) 2024 The D-FINE Authors All Rights Reserved.
"""
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from scipy.optimize import linear_sum_assignment
from ...core import register
from .box_ops import box_cxcywh_to_xyxy, generalized_box_iou

@register()
class HungarianMatcher(nn.Module):
"""This class computes an assignment between the targets and the predictions of the network
For efficiency reasons, the targets don't include the no_object. Because of this, in general,
there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
while the others are un-matched (and thus treated as non-objects).
"""
__share__ = [
"use_focal_loss",
]
    def __init__(self, weight_dict, use_focal_loss=False, alpha=0.25, gamma=2.0):
        """Creates the matcher.

        Params:
            weight_dict: dict with the relative weights of the matching costs:
                "cost_class": weight of the classification error
                "cost_bbox": weight of the L1 error of the bounding box coordinates
                "cost_giou": weight of the GIoU loss of the bounding box
            use_focal_loss: if True, compute the classification cost from sigmoid scores
                with focal-loss-style weighting
            alpha, gamma: focal loss parameters (only used when use_focal_loss is True)
        """
super().__init__()
self.cost_class = weight_dict["cost_class"]
self.cost_bbox = weight_dict["cost_bbox"]
self.cost_giou = weight_dict["cost_giou"]
self.use_focal_loss = use_focal_loss
self.alpha = alpha
self.gamma = gamma
        assert (
            self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0
        ), "all costs can't be 0"

    @torch.no_grad()
    def forward(self, outputs: Dict[str, torch.Tensor], targets, return_topk=False):
        """Performs the matching.

        Params:
            outputs: This is a dict that contains at least these entries:
                "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                    objects in the target) containing the class labels
                "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
            return_topk: if set to a positive integer k, return the top-k one-to-many matching instead

        Returns:
            A dict with key "indices" (or "indices_o2m" when return_topk is set), mapping to a list of
            size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
bs, num_queries = outputs["pred_logits"].shape[:2]
# We flatten to compute the cost matrices in a batch
if self.use_focal_loss:
            out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
else:
out_prob = (
outputs["pred_logits"].flatten(0, 1).softmax(-1)
) # [batch_size * num_queries, num_classes]
out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
# Also concat the target labels and boxes
tgt_ids = torch.cat([v["labels"] for v in targets])
tgt_bbox = torch.cat([v["boxes"] for v in targets])
        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it with 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, so it can be omitted.
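        # With focal loss, the classification cost instead mirrors the focal terms:
        # the positive focal term minus the negative focal term for the target class.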
if self.use_focal_loss:
out_prob = out_prob[:, tgt_ids]
neg_cost_class = (
(1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log())
)
pos_cost_class = (
self.alpha * ((1 - out_prob) ** self.gamma) * (-(out_prob + 1e-8).log())
)
cost_class = pos_cost_class - neg_cost_class
else:
cost_class = -out_prob[:, tgt_ids]
# Compute the L1 cost between boxes
cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        # Compute the GIoU cost between boxes
cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
        # Final cost matrix: weighted sum of the classification, L1 box, and GIoU costs
C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
C = C.view(bs, num_queries, -1).cpu()
sizes = [len(v["boxes"]) for v in targets]
C = torch.nan_to_num(C, nan=1.0)
indices_pre = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
indices = [
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
for i, j in indices_pre
]
# Compute topk indices
if return_topk:
return {
"indices_o2m": self.get_top_k_matches(
C, sizes=sizes, k=return_topk, initial_indices=indices_pre
)
}
return {"indices": indices} # , 'indices_o2m': C.min(-1)[1]}

    def get_top_k_matches(self, C, sizes, k=1, initial_indices=None):
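        """Repeat the Hungarian matching k times to build a one-to-many assignment.

        After each round, the entries matched in that round are masked with a large
        cost (1e6) so the next round selects different predictions. Returns, for each
        image, the concatenated (prediction, target) index tensors of all k rounds.
        """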
indices_list = []
# C_original = C.clone()
for i in range(k):
indices_k = (
[linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
if i > 0
else initial_indices
)
indices_list.append(
[
(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64))
for i, j in indices_k
]
)
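            # Mask the entries matched in this round with a large cost so the next
            # round of linear_sum_assignment picks different predictions.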
for c, idx_k in zip(C.split(sizes, -1), indices_k):
idx_k = np.stack(idx_k)
c[:, idx_k] = 1e6
indices_list = [
(
torch.cat([indices_list[i][j][0] for i in range(k)], dim=0),
torch.cat([indices_list[i][j][1] for i in range(k)], dim=0),
)
for j in range(len(sizes))
]
# C.copy_(C_original)
return indices_list
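

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# The shapes, class count, and weight values below are made-up assumptions;
# because of the relative imports above, this runs only inside the package
# (e.g. via `python -m ...`), not as a standalone script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    matcher = HungarianMatcher(
        weight_dict={"cost_class": 2.0, "cost_bbox": 5.0, "cost_giou": 2.0},
        use_focal_loss=True,
    )
    outputs = {
        "pred_logits": torch.randn(2, 100, 80),  # [batch_size, num_queries, num_classes]
        "pred_boxes": torch.rand(2, 100, 4),  # cxcywh boxes in [0, 1]
    }
    targets = [
        {"labels": torch.tensor([3, 17]), "boxes": torch.rand(2, 4)},
        {"labels": torch.tensor([5]), "boxes": torch.rand(1, 4)},
    ]
    match = matcher(outputs, targets)
    for img_id, (pred_idx, tgt_idx) in enumerate(match["indices"]):
        print(f"image {img_id}: predictions {pred_idx.tolist()} -> targets {tgt_idx.tolist()}")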