File size: 31,260 Bytes
e8f2571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 |
# Copyright (c) OpenMMLab. All rights reserved.
# Modified version: adds ground-truth box scaling (no scaling by default).
from typing import Dict, Optional, Tuple,List, Union
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn.init import normal_
from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from mmdet.utils import OptConfigType
# from ..layers import (CdnQueryGenerator, DeformableDetrTransformerEncoder,
# DinoTransformerDecoder, SinePositionalEncoding)
from ..layers import SinePositionalEncoding
from ..layers.transformer.dino_layers import (CdnQueryGenerator, DeformableDetrTransformerEncoder,
DinoTransformerDecoder)
from .deformable_detr import DeformableDETR, MultiScaleDeformableAttention
@MODELS.register_module()
class DINO(DeformableDETR):
    r"""Implementation of `DINO: DETR with Improved DeNoising Anchor Boxes
    for End-to-End Object Detection <https://arxiv.org/abs/2203.03605>`_

    Code is modified from the `official github repo
    <https://github.com/IDEA-Research/DINO>`_.

    Compared with the upstream MMDetection implementation, this version adds
    a configurable base size for the two-stage encoder proposals, optional
    shrinking of ground-truth boxes during training (disabled by default),
    and an optional relaxed validity margin for encoder proposals.

    Args:
        dn_cfg (:obj:`ConfigDict` or dict, optional): Config of denoising
            query generator. Defaults to `None`. Note that the generator is
            always built from it, so passing `None` fails at construction.
        candidate_bboxes_size (float): Width/height of the encoder proposals
            on feature level 0, in the same normalized coordinates as the
            proposal centres; level ``l`` uses
            ``candidate_bboxes_size * 2 ** l``. Defaults to 0.05.
        scale_gt_bboxes_size (float): When > 0, every ground-truth box is
            shrunk by this amount on each side before the losses are
            computed (see :meth:`rescale_gt_bboxes`). Defaults to 0, i.e.
            no shrinking.
        htd_2s (bool): When True, an encoder proposal counts as valid if all
            of its (cx, cy, w, h) lie in (0.0001, 0.9999) instead of the
            default (0.01, 0.99). Defaults to False.
    """

    def __init__(self, *args, dn_cfg: OptConfigType = None,
                 candidate_bboxes_size: float = 0.05,
                 scale_gt_bboxes_size: float = 0,
                 htd_2s: bool = False,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        assert self.as_two_stage, 'as_two_stage must be True for DINO'
        assert self.with_box_refine, 'with_box_refine must be True for DINO'

        if dn_cfg is not None:
            assert 'num_classes' not in dn_cfg and \
                   'num_queries' not in dn_cfg and \
                   'hidden_dim' not in dn_cfg, \
                'The three keyword args `num_classes`, `embed_dims`, and ' \
                '`num_matching_queries` are set in `detector.__init__()`, ' \
                'users should not set them in `dn_cfg` config.'
            # These values are derived from the detector itself. The caller's
            # dn_cfg mapping is updated in place.
            dn_cfg['num_classes'] = self.bbox_head.num_classes
            dn_cfg['embed_dims'] = self.embed_dims
            dn_cfg['num_matching_queries'] = self.num_queries
        self.dn_query_generator = CdnQueryGenerator(**dn_cfg)

        # The options below are only read during the forward pass
        # (`loss` / `gen_encoder_output_proposals`). Setting them after
        # `super().__init__()` is therefore safe.
        self.scale_gt_bboxes_size = scale_gt_bboxes_size
        self.candidate_bboxes_size = candidate_bboxes_size
        self.htd_2s = htd_2s

    def _init_layers(self) -> None:
        """Initialize layers except for backbone, neck and bbox_head."""
        self.positional_encoding = SinePositionalEncoding(
            **self.positional_encoding)
        self.encoder = DeformableDetrTransformerEncoder(**self.encoder)
        self.decoder = DinoTransformerDecoder(**self.decoder)
        self.embed_dims = self.encoder.embed_dims
        self.query_embedding = nn.Embedding(self.num_queries, self.embed_dims)
        # NOTE In DINO, the query_embedding only contains content
        # queries, while in Deformable DETR, the query_embedding
        # contains both content and spatial queries, and in DETR,
        # it only contains spatial queries.

        num_feats = self.positional_encoding.num_feats
        assert num_feats * 2 == self.embed_dims, \
            f'embed_dims should be exactly 2 times of num_feats. ' \
            f'Found {self.embed_dims} and {num_feats}.'

        self.level_embed = nn.Parameter(
            torch.Tensor(self.num_feature_levels, self.embed_dims))
        self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
        self.memory_trans_norm = nn.LayerNorm(self.embed_dims)

    def init_weights(self) -> None:
        """Initialize weights for Transformer and other components."""
        # Deliberately skip DeformableDETR.init_weights and call the
        # grandparent's implementation; DINO initializes its own layers below.
        super(DeformableDETR, self).init_weights()
        for coder in self.encoder, self.decoder:
            for p in coder.parameters():
                if p.dim() > 1:
                    nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        nn.init.xavier_uniform_(self.memory_trans_fc.weight)
        nn.init.xavier_uniform_(self.query_embedding.weight)
        normal_(self.level_embed)

    def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
        """Extract features.

        Args:
            batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W).

        Returns:
            tuple[Tensor]: Tuple of feature maps from neck. Each feature map
            has shape (bs, dim, H, W).
        """
        x = self.backbone(batch_inputs)
        if self.with_neck:
            x = self.neck(x)
        return x

    def loss(self, batch_inputs: Tensor,
             batch_data_samples: SampleList) -> Union[dict, list]:
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input images of shape (bs, dim, H, W).
                These should usually be mean centered and std scaled.
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

        Returns:
            dict: A dictionary of loss components
        """
        # Optionally shrink the GT boxes (in place) before they are used by
        # both the denoising query generator and the head losses.
        if self.scale_gt_bboxes_size > 0:
            batch_data_samples = self.rescale_gt_bboxes(
                batch_data_samples, self.scale_gt_bboxes_size)
        img_feats = self.extract_feat(batch_inputs)
        head_inputs_dict = self.forward_transformer(img_feats,
                                                    batch_data_samples)
        losses = self.bbox_head.loss(
            **head_inputs_dict, batch_data_samples=batch_data_samples)
        return losses

    def predict(self,
                batch_inputs: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing.

        Args:
            batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
            batch_data_samples (List[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
            rescale (bool): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[:obj:`DetDataSample`]: Detection results of the input images.
            Each DetDataSample usually contain 'pred_instances'. And the
            `pred_instances` usually contains following keys.

            - scores (Tensor): Classification scores, has a shape
              (num_instance, )
            - labels (Tensor): Labels of bboxes, has a shape
              (num_instances, ).
            - bboxes (Tensor): Has a shape (num_instances, 4),
              the last dimension 4 arrange as (x1, y1, x2, y2).
        """
        img_feats = self.extract_feat(batch_inputs)
        head_inputs_dict = self.forward_transformer(img_feats,
                                                    batch_data_samples)
        results_list = self.bbox_head.predict(
            **head_inputs_dict,
            rescale=rescale,
            batch_data_samples=batch_data_samples)
        batch_data_samples = self.add_pred_to_datasample(
            batch_data_samples, results_list)
        return batch_data_samples

    def _forward(
            self,
            batch_inputs: Tensor,
            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process. Usually includes backbone, neck and head
        forward without any post-processing.

        Args:
            batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
            batch_data_samples (List[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[Tensor]: A tuple of features from ``bbox_head`` forward.
        """
        img_feats = self.extract_feat(batch_inputs)
        head_inputs_dict = self.forward_transformer(img_feats,
                                                    batch_data_samples)
        results = self.bbox_head.forward(**head_inputs_dict)
        return results

    def forward_transformer(
        self,
        img_feats: Tuple[Tensor],
        batch_data_samples: OptSampleList = None,
    ) -> Dict:
        """Forward process of Transformer.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.
        The difference is that the ground truth in `batch_data_samples` is
        required for the `pre_decoder` to prepare the query of DINO.

        Args:
            img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
                feature map has shape (bs, dim, H, W).
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            dict: The dictionary of bbox_head function inputs, which always
            includes the `hidden_states` of the decoder output and may contain
            `references` including the initial and intermediate references.
        """
        encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
            img_feats, batch_data_samples)

        encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)

        tmp_dec_in, head_inputs_dict = self.pre_decoder(
            **encoder_outputs_dict, batch_data_samples=batch_data_samples)
        decoder_inputs_dict.update(tmp_dec_in)

        decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
        head_inputs_dict.update(decoder_outputs_dict)
        return head_inputs_dict

    def pre_transformer(
            self,
            mlvl_feats: Tuple[Tensor],
            batch_data_samples: OptSampleList = None) -> Tuple[Dict]:
        """Process image features before feeding them to the transformer.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.

        Args:
            mlvl_feats (tuple[Tensor]): Multi-level features that may have
                different resolutions, output from neck. Each feature has
                shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'.
            batch_data_samples (list[:obj:`DetDataSample`], optional): The
                batch data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict]: The first dict contains the inputs of encoder and the
            second dict contains the inputs of decoder.

            - encoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_encoder()`, which includes 'feat', 'feat_mask',
              'feat_pos', 'spatial_shapes', 'level_start_index' and
              'valid_ratios'.
            - decoder_inputs_dict (dict): The keyword args dictionary of
              `self.forward_decoder()`, which includes 'memory_mask',
              'spatial_shapes', 'level_start_index' and 'valid_ratios'.
        """
        batch_size = mlvl_feats[0].size(0)

        # construct binary masks for the transformer.
        assert batch_data_samples is not None
        batch_input_shape = batch_data_samples[0].batch_input_shape
        img_shape_list = [sample.img_shape for sample in batch_data_samples]
        input_img_h, input_img_w = batch_input_shape
        masks = mlvl_feats[0].new_ones((batch_size, input_img_h, input_img_w))
        for img_id in range(batch_size):
            img_h, img_w = img_shape_list[img_id]
            masks[img_id, :img_h, :img_w] = 0
        # NOTE following the official DETR repo, non-zero values representing
        # ignored positions, while zero values means valid positions.

        mlvl_masks = []
        mlvl_pos_embeds = []
        for feat in mlvl_feats:
            mlvl_masks.append(
                F.interpolate(masks[None],
                              size=feat.shape[-2:]).to(torch.bool).squeeze(0))
            mlvl_pos_embeds.append(self.positional_encoding(mlvl_masks[-1]))

        feat_flatten = []
        lvl_pos_embed_flatten = []
        mask_flatten = []
        spatial_shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(
                zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
            batch_size, c, h, w = feat.shape
            # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
            feat = feat.view(batch_size, c, -1).permute(0, 2, 1)
            pos_embed = pos_embed.view(batch_size, c, -1).permute(0, 2, 1)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl]
            mask = mask.flatten(1)
            spatial_shape = (h, w)

            feat_flatten.append(feat)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            mask_flatten.append(mask)
            spatial_shapes.append(spatial_shape)

        # (bs, num_feat_points, dim)
        feat_flatten = torch.cat(feat_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
        mask_flatten = torch.cat(mask_flatten, 1)

        spatial_shapes = torch.as_tensor(  # (num_level, 2)
            spatial_shapes,
            dtype=torch.long,
            device=feat_flatten.device)
        level_start_index = torch.cat((
            spatial_shapes.new_zeros((1, )),  # (num_level)
            spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack(  # (bs, num_level, 2)
            [self.get_valid_ratio(m) for m in mlvl_masks], 1)

        encoder_inputs_dict = dict(
            feat=feat_flatten,
            feat_mask=mask_flatten,
            feat_pos=lvl_pos_embed_flatten,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios)
        decoder_inputs_dict = dict(
            memory_mask=mask_flatten,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios)
        return encoder_inputs_dict, decoder_inputs_dict

    def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                        feat_pos: Tensor, spatial_shapes: Tensor,
                        level_start_index: Tensor,
                        valid_ratios: Tensor) -> Dict:
        """Forward with Transformer encoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.

        Args:
            feat (Tensor): Sequential features, has shape (bs, num_feat_points,
                dim).
            feat_mask (Tensor): ByteTensor, the padding mask of the features,
                has shape (bs, num_feat_points).
            feat_pos (Tensor): The positional embeddings of the features, has
                shape (bs, num_feat_points, dim).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).

        Returns:
            dict: The dictionary of encoder outputs, which includes the
            `memory` of the encoder output, the `memory_mask` and the
            `spatial_shapes`.
        """
        memory = self.encoder(
            query=feat,
            query_pos=feat_pos,
            key_padding_mask=feat_mask,  # for self_attn
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios)
        encoder_outputs_dict = dict(
            memory=memory,
            memory_mask=feat_mask,
            spatial_shapes=spatial_shapes)
        return encoder_outputs_dict

    def pre_decoder(
        self,
        memory: Tensor,
        memory_mask: Tensor,
        spatial_shapes: Tensor,
        batch_data_samples: OptSampleList = None,
    ) -> Tuple[Dict]:
        """Prepare intermediate variables before entering Transformer decoder,
        such as `query`, `query_pos`, and `reference_points`.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points). Will only be used when
                `as_two_stage` is `True`.
            spatial_shapes (Tensor): Spatial shapes of features in all levels.
                With shape (num_levels, 2), last dimension represents (h, w).
                Will only be used when `as_two_stage` is `True`.
            batch_data_samples (list[:obj:`DetDataSample`]): The batch
                data samples. It usually includes information such
                as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
                Defaults to None.

        Returns:
            tuple[dict]: The decoder_inputs_dict and head_inputs_dict.

            - decoder_inputs_dict (dict): The keyword dictionary args of
              `self.forward_decoder()`, which includes 'query', 'memory',
              `reference_points`, and `dn_mask`. The reference points of
              decoder input here are 4D boxes, although it has `points`
              in its name.
            - head_inputs_dict (dict): The keyword dictionary args of the
              bbox_head functions, which includes `enc_outputs_class`,
              `enc_outputs_coord`, and `dn_meta` when `self.training` is
              `True`, else is empty.
        """
        bs = memory.size(0)
        cls_out_features = self.bbox_head.cls_branches[
            self.decoder.num_layers].out_features

        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, memory_mask, spatial_shapes)
        # Modification: the last flattened memory position is excluded from
        # proposal selection. Presumably the last feature level is a single
        # (1x1) token that should not become an encoder proposal; confirm
        # against the neck config. The full `memory` is still passed to the
        # decoder.
        output_memory = output_memory[:, :-1, :]
        output_proposals = output_proposals[:, :-1, :]
        enc_outputs_class = self.bbox_head.cls_branches[
            self.decoder.num_layers](
                output_memory)
        enc_outputs_coord_unact = self.bbox_head.reg_branches[
            self.decoder.num_layers](output_memory) + output_proposals

        # NOTE The DINO selects top-k proposals according to scores of
        # multi-class classification, while DeformDETR, where the input
        # is `enc_outputs_class[..., 0]` selects according to scores of
        # binary classification.
        topk_indices = torch.topk(
            enc_outputs_class.max(-1)[0], k=self.num_queries, dim=1)[1]
        topk_score = torch.gather(
            enc_outputs_class, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features))
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1,
            topk_indices.unsqueeze(-1).repeat(1, 1, 4))
        topk_coords = topk_coords_unact.sigmoid()
        # Decoder reference points do not back-propagate into the encoder.
        topk_coords_unact = topk_coords_unact.detach()

        query = self.query_embedding.weight[:, None, :]
        query = query.repeat(1, bs, 1).transpose(0, 1)
        if self.training:
            dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
                self.dn_query_generator(batch_data_samples)
            query = torch.cat([dn_label_query, query], dim=1)
            reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                         dim=1)
        else:
            reference_points = topk_coords_unact
            dn_mask, dn_meta = None, None
        reference_points = reference_points.sigmoid()

        decoder_inputs_dict = dict(
            query=query,
            memory=memory,
            reference_points=reference_points,
            dn_mask=dn_mask)
        # NOTE DINO calculates encoder losses on scores and coordinates
        # of selected top-k encoder queries, while DeformDETR is of all
        # encoder queries.
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            dn_meta=dn_meta) if self.training else dict()
        return decoder_inputs_dict, head_inputs_dict

    def forward_decoder(self,
                        query: Tensor,
                        memory: Tensor,
                        memory_mask: Tensor,
                        reference_points: Tensor,
                        spatial_shapes: Tensor,
                        level_start_index: Tensor,
                        valid_ratios: Tensor,
                        dn_mask: Optional[Tensor] = None) -> Dict:
        """Forward with Transformer decoder.

        The forward procedure of the transformer is defined as:
        'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
        More details can be found at `TransformerDetector.forward_transformer`
        in `mmdet/detector/base_detr.py`.

        Args:
            query (Tensor): The queries of decoder inputs, has shape
                (bs, num_queries_total, dim), where `num_queries_total` is the
                sum of `num_denoising_queries` and `num_matching_queries` when
                `self.training` is `True`, else `num_matching_queries`.
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            reference_points (Tensor): The initial reference, has shape
                (bs, num_queries_total, 4) with the last dimension arranged as
                (cx, cy, w, h).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).
            level_start_index (Tensor): The start index of each level.
                A tensor has shape (num_levels, ) and can be represented
                as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
            valid_ratios (Tensor): The ratios of the valid width and the valid
                height relative to the width and the height of features in all
                levels, has shape (bs, num_levels, 2).
            dn_mask (Tensor, optional): The attention mask to prevent
                information leakage from different denoising groups and
                matching parts, will be used as `self_attn_mask` of the
                `self.decoder`, has shape (num_queries_total,
                num_queries_total).
                It is `None` when `self.training` is `False`.

        Returns:
            dict: The dictionary of decoder outputs, which includes the
            `hidden_states` of the decoder output and `references` including
            the initial and intermediate reference_points.
        """
        inter_states, references = self.decoder(
            query=query,
            value=memory,
            key_padding_mask=memory_mask,
            self_attn_mask=dn_mask,
            reference_points=reference_points,
            spatial_shapes=spatial_shapes,
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            reg_branches=self.bbox_head.reg_branches)

        # Dimension 1 of `query` is the number of queries; `len(query)` would
        # be the batch size. Equality with `num_queries` during training means
        # no denoising queries were prepended, i.e. no ground truth on this
        # GPU.
        if self.training and query.size(1) == self.num_queries:
            # NOTE: This is to make sure label_embeding can be involved to
            # produce loss even if there is no denoising query (no ground truth
            # target in this GPU), otherwise, this will raise runtime error in
            # distributed training.
            inter_states[0] += \
                self.dn_query_generator.label_embedding.weight[0, 0] * 0.0

        decoder_outputs_dict = dict(
            hidden_states=inter_states, references=list(references))
        return decoder_outputs_dict

    @staticmethod
    def get_valid_ratio(mask: Tensor) -> Tensor:
        """Get the valid ratios of feature map in a level.

        .. code:: text

                    |---> valid_W <---|
                 ---+-----------------+-----+---
                  A |                 |     | A
                  | |                 |     | |
                  | |                 |     | |
          valid_H | |                 |     | |
                  | |                 |     | H
                  | |                 |     | |
                  V |                 |     | |
                 ---+-----------------+     | |
                    |                       | V
                    +-----------------------+---
                    |---------> W <---------|

          The valid_ratios are defined as:
                r_h = valid_H / H,  r_w = valid_W / W
          They are the factors to re-normalize the relative coordinates of the
          image to the relative coordinates of the current level feature map.

        Args:
            mask (Tensor): Binary mask of a feature map, has shape (bs, H, W).
                True marks padded (invalid) positions.

        Returns:
            Tensor: valid ratios [r_w, r_h] of a feature map, has shape
            (bs, 2).
        """
        _, H, W = mask.shape
        # Valid extents are measured along the first column / first row.
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def gen_encoder_output_proposals(
            self, memory: Tensor, memory_mask: Tensor,
            spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]:
        """Generate proposals from encoded memory. The function will only be
        used when `as_two_stage` is `True`.

        One proposal is generated per memory position. Its centre is the
        position's grid cell centre normalized by the valid feature size of
        the level. Its width and height are
        ``self.candidate_bboxes_size * 2 ** lvl``.

        Args:
            memory (Tensor): The output embeddings of the Transformer encoder,
                has shape (bs, num_feat_points, dim).
            memory_mask (Tensor): ByteTensor, the padding mask of the memory,
                has shape (bs, num_feat_points).
            spatial_shapes (Tensor): Spatial shapes of features in all levels,
                has shape (num_levels, 2), last dimension represents (h, w).

        Returns:
            tuple: A tuple of transformed memory and proposals.

            - output_memory (Tensor): The transformed memory for obtaining
              top-k proposals, has shape (bs, num_feat_points, dim).
            - output_proposals (Tensor): The proposals in inverse-sigmoid
              (logit) space, has shape (bs, num_feat_points, 4) with the last
              dimension arranged as (cx, cy, w, h). Padded or out-of-range
              proposals are set to +inf.
        """
        bs = memory.size(0)
        proposals = []
        _cur = 0  # start index in the sequence of the current level
        for lvl, (H, W) in enumerate(spatial_shapes):
            mask_flatten_ = memory_mask[:,
                                        _cur:(_cur + H * W)].view(bs, H, W, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1).unsqueeze(-1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1).unsqueeze(-1)

            grid_y, grid_x = torch.meshgrid(
                torch.linspace(
                    0, H - 1, H, dtype=torch.float32, device=memory.device),
                torch.linspace(
                    0, W - 1, W, dtype=torch.float32, device=memory.device))
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            scale = torch.cat([valid_W, valid_H], 1).view(bs, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(bs, -1, -1, -1) + 0.5) / scale
            wh = torch.ones_like(grid) * self.candidate_bboxes_size * (
                2.0**lvl)
            proposal = torch.cat((grid, wh), -1).view(bs, -1, 4)
            proposals.append(proposal)
            _cur += (H * W)
        output_proposals = torch.cat(proposals, 1)

        # A proposal is valid only if all of (cx, cy, w, h) lie strictly
        # inside (low, high). `htd_2s` selects a much narrower margin.
        low, high = (0.0001, 0.9999) if self.htd_2s else (0.01, 0.99)
        output_proposals_valid = ((output_proposals > low) &
                                  (output_proposals < high)).all(
                                      -1, keepdim=True)

        # inverse_sigmoid
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(
            memory_mask.unsqueeze(-1), float('inf'))
        output_proposals = output_proposals.masked_fill(
            ~output_proposals_valid, float('inf'))

        output_memory = memory
        output_memory = output_memory.masked_fill(
            memory_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid,
                                                  float(0))
        output_memory = self.memory_trans_fc(output_memory)
        output_memory = self.memory_trans_norm(output_memory)
        # output_memory: [bs, sum(hw), dim]; output_proposals: [bs, sum(hw), 4]
        return output_memory, output_proposals

    @staticmethod
    def rescale_gt_bboxes(batch_data_samples: OptSampleList,
                          scale_gt_bboxes_size: float = 0.25
                          ) -> OptSampleList:
        """Shrink every ground-truth box by a fixed margin on each side.

        The top-left corner moves by ``+scale_gt_bboxes_size`` and the
        bottom-right corner by ``-scale_gt_bboxes_size``. The margin is in the
        same absolute units as ``gt_instances.bboxes`` (presumably pixels of
        the resized input; confirm against the data pipeline). Boxes narrower
        or shorter than ``2 * scale_gt_bboxes_size`` collapse to their centre
        along that axis instead of being inverted (x1 > x2 / y1 > y2), which
        would corrupt IoU-based losses. A negative margin enlarges boxes.

        The boxes are modified in place. Each sample's
        ``gt_instances.bboxes`` is assumed to be an (N, 4) ``(x1, y1, x2,
        y2)`` tensor.

        Args:
            batch_data_samples (list[:obj:`DetDataSample`]): Samples whose
                ``gt_instances.bboxes`` are modified.
            scale_gt_bboxes_size (float): Margin removed from each side.
                Defaults to 0.25.

        Returns:
            list[:obj:`DetDataSample`]: The same (mutated) samples.
        """
        for data_sample in batch_data_samples:
            gt_bboxes = data_sample.gt_instances.bboxes
            top_left = gt_bboxes[:, :2]
            bottom_right = gt_bboxes[:, 2:]
            centers = (top_left + bottom_right) / 2
            # Compute both corners from the original coordinates before
            # writing back, and never let a corner cross the box centre.
            new_top_left = torch.minimum(top_left + scale_gt_bboxes_size,
                                         centers)
            new_bottom_right = torch.maximum(
                bottom_right - scale_gt_bboxes_size, centers)
            gt_bboxes[:, :2] = new_top_left
            gt_bboxes[:, 2:] = new_bottom_right
        return batch_data_samples
|