Spaces:
Runtime error
Runtime error
File size: 10,750 Bytes
3e06e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
import torch
from torch import Tensor
from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.structures.bbox import bbox_overlaps
from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean
from ..utils import multi_apply, unpack_gt_instances
from .gfl_head import GFLHead
@MODELS.register_module()
class LDHead(GFLHead):
    """Localization distillation head.

    It utilizes the learned bbox distributions to transfer the localization
    dark knowledge from teacher to student. Original paper: `Localization
    Distillation for Object Detection. <https://arxiv.org/abs/2102.12252>`_

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        loss_ld (:obj:`ConfigDict` or dict): Config of Localization
            Distillation Loss (LD), T is the temperature for distillation.
        **kwargs: Remaining keyword arguments, forwarded unchanged to
            :class:`GFLHead`.
    """

    def __init__(self,
                 num_classes: int,
                 in_channels: int,
                 loss_ld: ConfigType = dict(
                     type='LocalizationDistillationLoss',
                     loss_weight=0.25,
                     T=10),
                 **kwargs) -> None:
        super().__init__(
            num_classes=num_classes, in_channels=in_channels, **kwargs)
        # NOTE(review): the default ``loss_ld`` dict object is shared between
        # all instances; this is only safe as long as ``MODELS.build`` does
        # not mutate its config argument.
        self.loss_ld = MODELS.build(loss_ld)

    def loss_by_feat_single(self, anchors: Tensor, cls_score: Tensor,
                            bbox_pred: Tensor, labels: Tensor,
                            label_weights: Tensor, bbox_targets: Tensor,
                            stride: Tuple[int], soft_targets: Tensor,
                            avg_factor: float) -> Tuple[Tensor, ...]:
        """Calculate the loss of a single scale level based on the features
        extracted by the detection head.

        Args:
            anchors (Tensor): Box reference for each scale level with shape
                (N, num_total_anchors, 4).
            cls_score (Tensor): Cls and quality joint scores for each scale
                level with shape (N, num_classes, H, W).
            bbox_pred (Tensor): Box distribution logits for each scale
                level with shape (N, 4*(n+1), H, W), n is max value of
                integral set.
            labels (Tensor): Labels of each anchor with shape
                (N, num_total_anchors).
            label_weights (Tensor): Label weights of each anchor with shape
                (N, num_total_anchors).
            bbox_targets (Tensor): BBox regression targets of each anchor
                with shape (N, num_total_anchors, 4).
            stride (tuple): Stride in this scale level; the h and w strides
                must be equal.
            soft_targets (Tensor): Soft BBox regression targets (teacher
                distribution logits), laid out like ``bbox_pred``:
                (N, 4*(n+1), H, W).
            avg_factor (float): Average factor that is used to average
                the classification loss. When using sampling method,
                avg_factor is usually the sum of positive and negative
                priors. When using `PseudoSampler`, `avg_factor` is usually
                equal to the number of positive priors.

        Returns:
            tuple[Tensor]: ``(loss_cls, loss_bbox, loss_dfl, loss_ld,
            weight_sum)``, where ``weight_sum`` is the sum of the quality
            weights of the positive samples at this level (0 when there are
            no positives). It is used by the caller to normalize the bbox and
            DFL losses.
        """
        assert stride[0] == stride[1], 'h stride is not equal to w stride!'
        num_bins = self.reg_max + 1

        # Flatten everything to one row per anchor (channels-last).
        anchors = anchors.reshape(-1, 4)
        cls_score = cls_score.permute(0, 2, 3, 1).reshape(
            -1, self.cls_out_channels)
        bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4 * num_bins)
        soft_targets = soft_targets.permute(0, 2, 3, 1).reshape(
            -1, 4 * num_bins)
        bbox_targets = bbox_targets.reshape(-1, 4)
        labels = labels.reshape(-1)
        label_weights = label_weights.reshape(-1)

        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
        bg_class_ind = self.num_classes
        pos_inds = ((labels >= 0)
                    & (labels < bg_class_ind)).nonzero().squeeze(1)
        # IoU quality target for the joint cls/quality loss; stays 0 for
        # negatives.
        score = label_weights.new_zeros(labels.shape)

        if len(pos_inds) > 0:
            pos_bbox_targets = bbox_targets[pos_inds]
            pos_bbox_pred = bbox_pred[pos_inds]
            pos_anchors = anchors[pos_inds]
            # Work in units of the feature-map stride.
            pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]

            # Per-positive weight: the highest (detached) class probability.
            weight_targets = cls_score.detach().sigmoid()
            weight_targets = weight_targets.max(dim=1)[0][pos_inds]
            pos_bbox_pred_corners = self.integral(pos_bbox_pred)
            pos_decode_bbox_pred = self.bbox_coder.decode(
                pos_anchor_centers, pos_bbox_pred_corners)
            pos_decode_bbox_targets = pos_bbox_targets / stride[0]
            score[pos_inds] = bbox_overlaps(
                pos_decode_bbox_pred.detach(),
                pos_decode_bbox_targets,
                is_aligned=True)

            # One row of `num_bins` logits per box side (4 per positive).
            pred_corners = pos_bbox_pred.reshape(-1, num_bins)
            pos_soft_targets = soft_targets[pos_inds]
            soft_corners = pos_soft_targets.reshape(-1, num_bins)
            target_corners = self.bbox_coder.encode(pos_anchor_centers,
                                                    pos_decode_bbox_targets,
                                                    self.reg_max).reshape(-1)
            # Repeat each positive's weight for its 4 side distributions.
            corner_weights = weight_targets[:, None].expand(-1, 4).reshape(-1)

            # regression loss
            loss_bbox = self.loss_bbox(
                pos_decode_bbox_pred,
                pos_decode_bbox_targets,
                weight=weight_targets,
                avg_factor=1.0)

            # dfl loss
            loss_dfl = self.loss_dfl(
                pred_corners,
                target_corners,
                weight=corner_weights,
                avg_factor=4.0)

            # ld loss: student vs. teacher side distributions
            loss_ld = self.loss_ld(
                pred_corners,
                soft_corners,
                weight=corner_weights,
                avg_factor=4.0)
        else:
            # Zero losses that stay connected to the graph.
            loss_ld = bbox_pred.sum() * 0
            loss_bbox = bbox_pred.sum() * 0
            loss_dfl = bbox_pred.sum() * 0
            weight_targets = bbox_pred.new_tensor(0)

        # cls (qfl) loss
        loss_cls = self.loss_cls(
            cls_score, (labels, score),
            weight=label_weights,
            avg_factor=avg_factor)

        return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()

    def loss(self, x: List[Tensor], out_teacher: Tuple[Tensor],
             batch_data_samples: SampleList) -> dict:
        """Compute the student losses given the teacher outputs.

        Args:
            x (list[Tensor]): Features from FPN.
            out_teacher (tuple[Tensor]): The output of teacher. Element
                ``[1]`` is used as the soft regression targets (presumably
                the teacher's per-level bbox distribution logits — confirm
                against the detector that calls this).
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict[str, Tensor]: A dictionary of loss components, as returned
            by :meth:`loss_by_feat`.
        """
        outputs = unpack_gt_instances(batch_data_samples)
        batch_gt_instances, batch_gt_instances_ignore, batch_img_metas \
            = outputs

        outs = self(x)
        soft_targets = out_teacher[1]
        loss_inputs = outs + (batch_gt_instances, batch_img_metas,
                              soft_targets)
        losses = self.loss_by_feat(
            *loss_inputs, batch_gt_instances_ignore=batch_gt_instances_ignore)

        return losses

    def loss_by_feat(
            self,
            cls_scores: List[Tensor],
            bbox_preds: List[Tensor],
            batch_gt_instances: InstanceList,
            batch_img_metas: List[dict],
            soft_targets: List[Tensor],
            batch_gt_instances_ignore: OptInstanceList = None) -> dict:
        """Compute losses of the head.

        Args:
            cls_scores (list[Tensor]): Cls and quality scores for each scale
                level with shape (N, num_classes, H, W).
            bbox_preds (list[Tensor]): Box distribution logits for each scale
                level with shape (N, 4*(n+1), H, W), n is max value of
                integral set.
            batch_gt_instances (list[:obj:`InstanceData`]): Batch of
                gt_instance. It usually includes ``bboxes`` and ``labels``
                attributes.
            batch_img_metas (list[dict]): Meta information of each image,
                e.g., image size, scaling factor, etc.
            soft_targets (list[Tensor]): Soft BBox regression targets, one
                tensor per scale level.
            batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional):
                Batch of gt_instances_ignore. It includes ``bboxes``
                attribute data that is ignored during training and testing.
                Defaults to None.

        Returns:
            dict[str, Tensor]: A dictionary of loss components with keys
            ``loss_cls``, ``loss_bbox``, ``loss_dfl`` and ``loss_ld``, each
            a per-level list.
        """
        featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
        assert len(featmap_sizes) == self.prior_generator.num_levels

        device = cls_scores[0].device
        anchor_list, valid_flag_list = self.get_anchors(
            featmap_sizes, batch_img_metas, device=device)

        cls_reg_targets = self.get_targets(
            anchor_list,
            valid_flag_list,
            batch_gt_instances,
            batch_img_metas,
            batch_gt_instances_ignore=batch_gt_instances_ignore)
        (anchor_list, labels_list, label_weights_list, bbox_targets_list,
         _, cls_avg_factor) = cls_reg_targets

        # Normalizer for the classification loss (presumably averaged across
        # distributed workers by `reduce_mean`).
        cls_avg_factor = reduce_mean(
            torch.tensor(cls_avg_factor, dtype=torch.float,
                         device=device)).item()

        (losses_cls, losses_bbox, losses_dfl, losses_ld,
         weight_sums) = multi_apply(
             self.loss_by_feat_single,
             anchor_list,
             cls_scores,
             bbox_preds,
             labels_list,
             label_weights_list,
             bbox_targets_list,
             self.prior_generator.strides,
             soft_targets,
             avg_factor=cls_avg_factor)

        # The bbox and DFL losses are normalized by the total positive-sample
        # weight over all levels. The epsilon avoids dividing by zero when no
        # level has positives (each such level contributes a 0 weight sum).
        # The LD loss is deliberately left un-normalized here.
        bbox_avg_factor = sum(weight_sums) + 1e-6
        bbox_avg_factor = reduce_mean(bbox_avg_factor).item()
        losses_bbox = [x / bbox_avg_factor for x in losses_bbox]
        losses_dfl = [x / bbox_avg_factor for x in losses_dfl]

        return dict(
            loss_cls=losses_cls,
            loss_bbox=losses_bbox,
            loss_dfl=losses_dfl,
            loss_ld=losses_ld)
|