scfive
Resolve README.md conflict and continue rebase
e8f2571
# Copyright (c) OpenMMLab. All rights reserved.
# Modified version: adds optional box scaling, which is disabled by default.
from typing import Dict, Optional, Tuple,List, Union
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from torch.nn.init import normal_
from mmdet.registry import MODELS
from mmdet.structures import OptSampleList, SampleList
from mmdet.utils import OptConfigType
# from ..layers import (CdnQueryGenerator, DeformableDetrTransformerEncoder,
# DinoTransformerDecoder, SinePositionalEncoding)
from ..layers import SinePositionalEncoding
from ..layers.transformer.dino_layers import (CdnQueryGenerator, DeformableDetrTransformerEncoder,
DinoTransformerDecoder)
from .deformable_detr import DeformableDETR, MultiScaleDeformableAttention
@MODELS.register_module()
class DINO(DeformableDETR):
r"""Implementation of `DINO: DETR with Improved DeNoising Anchor Boxes
for End-to-End Object Detection <https://arxiv.org/abs/2203.03605>`_
Code is modified from the `official github repo
<https://github.com/IDEA-Research/DINO>`_.
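Args:
dn_cfg (:obj:`ConfigDict` or dict, optional): Config of denoising
query generator. Defaults to `None`.
candidate_bboxes_size (float): Width and height given to the encoder
proposals of the first feature level; the value is doubled for each
subsequent level. Defaults to 0.05.
scale_gt_bboxes_size (float): When greater than 0, `loss` moves the
first two coordinates of every ground-truth box up and the last two
down by this amount before computing losses. Defaults to 0 (disabled).
htd_2s (bool): When true, encoder proposals are kept if all of their
coordinates lie in (0.0001, 0.9999) instead of (0.01, 0.99).
Defaults to False.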
"""
def __init__(self,
             *args,
             dn_cfg: OptConfigType = None,
             candidate_bboxes_size: float = 0.05,
             scale_gt_bboxes_size: float = 0,
             htd_2s: bool = False,
             **kwargs) -> None:
    """Build the detector and, when configured, its denoising generator.

    Args:
        dn_cfg (:obj:`ConfigDict` or dict, optional): Config of the
            denoising query generator. ``num_classes``, ``embed_dims`` and
            ``num_matching_queries`` are filled in here from the detector.
            Defaults to None, in which case no generator is created.
        candidate_bboxes_size (float): Width/height used for the encoder
            proposals of the first feature level in
            ``gen_encoder_output_proposals``. Defaults to 0.05.
        scale_gt_bboxes_size (float): When greater than 0, ``loss`` passes
            the ground-truth boxes through ``rescale_gt_bboxes`` with this
            value. Defaults to 0 (disabled).
        htd_2s (bool): When true, ``gen_encoder_output_proposals`` keeps
            proposals whose coordinates lie in (0.0001, 0.9999) instead of
            (0.01, 0.99). Defaults to False.
    """
    super().__init__(*args, **kwargs)
    assert self.as_two_stage, 'as_two_stage must be True for DINO'
    assert self.with_box_refine, 'with_box_refine must be True for DINO'

    if dn_cfg is not None:
        # Same keys are rejected as before; only the message was corrected
        # so that it describes what is actually checked and overwritten.
        rejected_keys = ('num_classes', 'num_queries', 'hidden_dim')
        assert not any(key in dn_cfg for key in rejected_keys), \
            'The keyword args `num_classes`, `num_queries` and ' \
            '`hidden_dim` must not be set in `dn_cfg`: `num_classes`, ' \
            '`embed_dims` and `num_matching_queries` are set in ' \
            '`detector.__init__()`.'
        # Work on a shallow copy so the caller's config is not modified.
        generator_kwargs = dict(dn_cfg)
        generator_kwargs['num_classes'] = self.bbox_head.num_classes
        generator_kwargs['embed_dims'] = self.embed_dims
        generator_kwargs['num_matching_queries'] = self.num_queries
        self.dn_query_generator = CdnQueryGenerator(**generator_kwargs)

    self.scale_gt_bboxes_size = scale_gt_bboxes_size
    self.candidate_bboxes_size = candidate_bboxes_size
    self.htd_2s = htd_2s
def _init_layers(self) -> None:
    """Build the transformer-side layers.

    Backbone, neck and bbox_head are not touched here. The attributes
    ``positional_encoding``, ``encoder`` and ``decoder`` are unpacked as
    keyword arguments and then replaced by the modules built from them.
    """
    self.positional_encoding = SinePositionalEncoding(
        **self.positional_encoding)
    self.encoder = DeformableDetrTransformerEncoder(**self.encoder)
    self.decoder = DinoTransformerDecoder(**self.decoder)

    embed_dims = self.encoder.embed_dims
    self.embed_dims = embed_dims

    # NOTE In DINO, the query_embedding only contains content queries,
    # while in Deformable DETR it contains both content and spatial
    # queries, and in DETR it only contains spatial queries.
    self.query_embedding = nn.Embedding(self.num_queries, embed_dims)

    num_feats = self.positional_encoding.num_feats
    assert embed_dims == num_feats * 2, (
        f'embed_dims should be exactly 2 times of num_feats. '
        f'Found {self.embed_dims} and {num_feats}.')

    # Allocated uninitialised here; filled by ``init_weights``.
    self.level_embed = nn.Parameter(
        torch.Tensor(self.num_feature_levels, embed_dims))
    self.memory_trans_fc = nn.Linear(embed_dims, embed_dims)
    self.memory_trans_norm = nn.LayerNorm(embed_dims)
def init_weights(self) -> None:
    """Initialise the transformer and the extra DINO parameters."""
    # Resolve ``init_weights`` starting after ``DeformableDETR`` in the MRO,
    # i.e. ``DeformableDETR.init_weights`` itself is bypassed.
    super(DeformableDETR, self).init_weights()

    # Xavier init for every weight with more than one dimension.
    for transformer_part in (self.encoder, self.decoder):
        for param in transformer_part.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    # Deformable attention modules apply their own scheme afterwards.
    for module in self.modules():
        if isinstance(module, MultiScaleDeformableAttention):
            module.init_weights()

    for weight in (self.memory_trans_fc.weight,
                   self.query_embedding.weight):
        nn.init.xavier_uniform_(weight)
    normal_(self.level_embed)
def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]:
    """Run the backbone and, if present, the neck.

    Args:
        batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W).

    Returns:
        tuple[Tensor]: Feature maps produced by the neck when
        ``self.with_neck`` is true, otherwise the backbone output.
    """
    backbone_feats = self.backbone(batch_inputs)
    return self.neck(backbone_feats) if self.with_neck else backbone_feats
def loss(self, batch_inputs: Tensor,
         batch_data_samples: SampleList) -> Union[dict, list]:
    """Compute the training losses for a batch.

    Args:
        batch_inputs (Tensor): Input images of shape (bs, dim, H, W).
            These should usually be mean centered and std scaled.
        batch_data_samples (List[:obj:`DetDataSample`]): The batch
            data samples. It usually includes information such
            as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.

    Returns:
        dict: A dictionary of loss components, as returned by
        ``self.bbox_head.loss``.
    """
    # Optional ground-truth box rescaling, applied only when the configured
    # size is positive.
    if self.scale_gt_bboxes_size > 0:
        batch_data_samples = self.rescale_gt_bboxes(
            batch_data_samples, self.scale_gt_bboxes_size)

    head_inputs_dict = self.forward_transformer(
        self.extract_feat(batch_inputs), batch_data_samples)
    return self.bbox_head.loss(
        **head_inputs_dict, batch_data_samples=batch_data_samples)
def predict(self,
            batch_inputs: Tensor,
            batch_data_samples: SampleList,
            rescale: bool = True) -> SampleList:
    """Run inference with post-processing and attach the predictions.

    Args:
        batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
        batch_data_samples (List[:obj:`DetDataSample`]): The batch
            data samples. It usually includes information such
            as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
        rescale (bool): Whether to rescale the results; forwarded to
            ``self.bbox_head.predict``. Defaults to True.

    Returns:
        list[:obj:`DetDataSample`]: Detection results of the input images.
        Each DetDataSample usually contain 'pred_instances'. And the
        `pred_instances` usually contains following keys.

        - scores (Tensor): Classification scores, has a shape
          (num_instance, )
        - labels (Tensor): Labels of bboxes, has a shape
          (num_instances, ).
        - bboxes (Tensor): Has a shape (num_instances, 4),
          the last dimension 4 arrange as (x1, y1, x2, y2).
    """
    features = self.extract_feat(batch_inputs)
    head_inputs_dict = self.forward_transformer(features,
                                                batch_data_samples)
    results_list = self.bbox_head.predict(
        **head_inputs_dict,
        rescale=rescale,
        batch_data_samples=batch_data_samples)
    return self.add_pred_to_datasample(batch_data_samples, results_list)
def _forward(self,
             batch_inputs: Tensor,
             batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
    """Raw network forward: backbone, neck, transformer and head.

    No post-processing is applied to the head output.

    Args:
        batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
        batch_data_samples (List[:obj:`DetDataSample`], optional): The
            batch data samples. It usually includes information such
            as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
            Defaults to None.

    Returns:
        tuple[Tensor]: A tuple of features from ``bbox_head`` forward.
    """
    head_inputs_dict = self.forward_transformer(
        self.extract_feat(batch_inputs), batch_data_samples)
    return self.bbox_head.forward(**head_inputs_dict)
def forward_transformer(
    self,
    img_feats: Tuple[Tensor],
    batch_data_samples: OptSampleList = None,
) -> Dict:
    """Run the four transformer stages and collect the head inputs.

    The stages are chained as
    'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'.
    More details can be found at `TransformerDetector.forward_transformer`
    in `mmdet/detector/base_detr.py`. Unlike the base procedure,
    ``batch_data_samples`` is also handed to ``pre_decoder``, which uses it
    to prepare the DINO queries.

    Args:
        img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
            feature map has shape (bs, dim, H, W).
        batch_data_samples (list[:obj:`DetDataSample`]): The batch
            data samples. It usually includes information such
            as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
            Defaults to None.

    Returns:
        dict: The head inputs produced by ``pre_decoder`` merged with the
        decoder outputs (``hidden_states`` and ``references``).
    """
    encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
        img_feats, batch_data_samples)
    encoder_outputs_dict = self.forward_encoder(**encoder_inputs_dict)

    extra_decoder_inputs, head_inputs_dict = self.pre_decoder(
        **encoder_outputs_dict, batch_data_samples=batch_data_samples)
    decoder_inputs_dict.update(extra_decoder_inputs)

    head_inputs_dict.update(self.forward_decoder(**decoder_inputs_dict))
    return head_inputs_dict
def pre_transformer(
        self,
        mlvl_feats: Tuple[Tensor],
        batch_data_samples: OptSampleList = None) -> Tuple[Dict]:
    """Flatten the multi-level features into transformer inputs.

    The forward procedure of the transformer is defined as:
    'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
    More details can be found at `TransformerDetector.forward_transformer`
    in `mmdet/detector/base_detr.py`.

    Args:
        mlvl_feats (tuple[Tensor]): Multi-level features that may have
            different resolutions, output from neck. Each feature has
            shape (bs, dim, h_lvl, w_lvl), where 'lvl' means 'layer'.
        batch_data_samples (list[:obj:`DetDataSample`], optional): The
            batch data samples; ``batch_input_shape`` of the first sample
            and ``img_shape`` of every sample are read. Although the
            default is None, a value is required (asserted).

    Returns:
        tuple[dict]: The first dict contains the inputs of encoder and the
        second dict contains the inputs of decoder.

        - encoder_inputs_dict (dict): The keyword args dictionary of
          `self.forward_encoder()`: 'feat', 'feat_mask', 'feat_pos',
          'spatial_shapes', 'level_start_index' and 'valid_ratios'.
        - decoder_inputs_dict (dict): The keyword args dictionary of
          `self.forward_decoder()`: 'memory_mask', 'spatial_shapes',
          'level_start_index' and 'valid_ratios'.
    """
    num_imgs = mlvl_feats[0].size(0)
    assert batch_data_samples is not None

    # Build the image-level padding mask. Following the official DETR repo,
    # non-zero values mark ignored (padded) positions and zeros mark valid
    # ones.
    padded_h, padded_w = batch_data_samples[0].batch_input_shape
    img_shapes = [sample.img_shape for sample in batch_data_samples]
    padding_mask = mlvl_feats[0].new_ones((num_imgs, padded_h, padded_w))
    for img_id in range(num_imgs):
        valid_h, valid_w = img_shapes[img_id]
        padding_mask[img_id, :valid_h, :valid_w] = 0

    # Resize the mask to every level and derive the positional encodings.
    mlvl_masks = [
        F.interpolate(padding_mask[None],
                      size=feat.shape[-2:]).to(torch.bool).squeeze(0)
        for feat in mlvl_feats
    ]
    mlvl_pos_embeds = [
        self.positional_encoding(level_mask) for level_mask in mlvl_masks
    ]

    flat_feats, flat_pos_embeds, flat_masks, level_shapes = [], [], [], []
    for lvl, (feat, level_mask, pos_embed) in enumerate(
            zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)):
        n, c, h, w = feat.shape
        # [bs, c, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl, c]
        flat_feats.append(feat.view(n, c, -1).permute(0, 2, 1))
        flat_pos = pos_embed.view(n, c, -1).permute(0, 2, 1)
        # Each level gets its own learned offset on top of the position.
        flat_pos_embeds.append(flat_pos +
                               self.level_embed[lvl].view(1, 1, -1))
        # [bs, h_lvl, w_lvl] -> [bs, h_lvl*w_lvl]
        flat_masks.append(level_mask.flatten(1))
        level_shapes.append((h, w))

    # (bs, num_feat_points, dim)
    feat_flatten = torch.cat(flat_feats, 1)
    lvl_pos_embed_flatten = torch.cat(flat_pos_embeds, 1)
    # (bs, num_feat_points), where num_feat_points = sum_lvl(h_lvl*w_lvl)
    mask_flatten = torch.cat(flat_masks, 1)

    # (num_level, 2)
    spatial_shapes = torch.as_tensor(
        level_shapes, dtype=torch.long, device=feat_flatten.device)
    # (num_level, ): offset of each level inside the flattened sequence.
    points_per_level = spatial_shapes.prod(1)
    level_start_index = torch.cat(
        (spatial_shapes.new_zeros((1, )), points_per_level.cumsum(0)[:-1]))
    # (bs, num_level, 2)
    valid_ratios = torch.stack(
        [self.get_valid_ratio(level_mask) for level_mask in mlvl_masks], 1)

    encoder_inputs_dict = dict(
        feat=feat_flatten,
        feat_mask=mask_flatten,
        feat_pos=lvl_pos_embed_flatten,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios)
    decoder_inputs_dict = dict(
        memory_mask=mask_flatten,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios)
    return encoder_inputs_dict, decoder_inputs_dict
def forward_encoder(self, feat: Tensor, feat_mask: Tensor,
                    feat_pos: Tensor, spatial_shapes: Tensor,
                    level_start_index: Tensor,
                    valid_ratios: Tensor) -> Dict:
    """Run the Transformer encoder on the flattened features.

    The forward procedure of the transformer is defined as:
    'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
    More details can be found at `TransformerDetector.forward_transformer`
    in `mmdet/detector/base_detr.py`.

    Args:
        feat (Tensor): Sequential features, has shape (bs, num_feat_points,
            dim).
        feat_mask (Tensor): ByteTensor, the padding mask of the features,
            has shape (bs, num_feat_points).
        feat_pos (Tensor): The positional embeddings of the features, has
            shape (bs, num_feat_points, dim).
        spatial_shapes (Tensor): Spatial shapes of features in all levels,
            has shape (num_levels, 2), last dimension represents (h, w).
        level_start_index (Tensor): The start index of each level.
            A tensor has shape (num_levels, ) and can be represented
            as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        valid_ratios (Tensor): The ratios of the valid width and the valid
            height relative to the width and the height of features in all
            levels, has shape (bs, num_levels, 2).

    Returns:
        dict: ``memory`` (the encoder output) together with the
        ``memory_mask`` and ``spatial_shapes`` that were passed in.
    """
    encoder_kwargs = dict(
        query=feat,
        query_pos=feat_pos,
        key_padding_mask=feat_mask,  # for self_attn
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios)
    return dict(
        memory=self.encoder(**encoder_kwargs),
        memory_mask=feat_mask,
        spatial_shapes=spatial_shapes)
def pre_decoder(
    self,
    memory: Tensor,
    memory_mask: Tensor,
    spatial_shapes: Tensor,
    batch_data_samples: OptSampleList = None,
) -> Tuple[Dict]:
    """Select the top-k encoder proposals and assemble the decoder inputs.

    Args:
        memory (Tensor): The output embeddings of the Transformer encoder,
            has shape (bs, num_feat_points, dim).
        memory_mask (Tensor): ByteTensor, the padding mask of the memory,
            has shape (bs, num_feat_points).
        spatial_shapes (Tensor): Spatial shapes of features in all levels.
            With shape (num_levels, 2), last dimension represents (h, w).
        batch_data_samples (list[:obj:`DetDataSample`]): The batch
            data samples, forwarded to the denoising query generator when
            ``self.training`` is true. Defaults to None.

    Returns:
        tuple[dict]: The decoder_inputs_dict and head_inputs_dict.

        - decoder_inputs_dict (dict): The keyword dictionary args of
          `self.forward_decoder()`, which includes 'query', 'memory',
          `reference_points`, and `dn_mask`. The reference points of
          decoder input here are 4D boxes, although it has `points`
          in its name.
        - head_inputs_dict (dict): `enc_outputs_class`,
          `enc_outputs_coord` and `dn_meta` when `self.training` is
          `True`, else empty.
    """
    batch_size, _, _ = memory.shape

    # The branches stored right after the per-decoder-layer ones are used
    # to score and regress the encoder proposals.
    enc_branch_idx = self.decoder.num_layers
    cls_branch = self.bbox_head.cls_branches[enc_branch_idx]
    reg_branch = self.bbox_head.reg_branches[enc_branch_idx]
    num_cls_outputs = cls_branch.out_features

    output_memory, output_proposals = self.gen_encoder_output_proposals(
        memory, memory_mask, spatial_shapes)
    # NOTE(review): the last feature point is excluded from proposal
    # selection here; the reason is not visible in this file -- confirm.
    output_memory = output_memory[:, :-1, :]
    output_proposals = output_proposals[:, :-1, :]

    enc_outputs_class = cls_branch(output_memory)
    enc_outputs_coord_unact = reg_branch(output_memory) + output_proposals

    # NOTE The DINO selects top-k proposals according to scores of
    # multi-class classification, while DeformDETR, where the input
    # is `enc_outputs_class[..., 0]` selects according to scores of
    # binary classification.
    best_class_score = enc_outputs_class.max(-1)[0]
    _, topk_indices = torch.topk(best_class_score, k=self.num_queries, dim=1)
    gather_index = topk_indices.unsqueeze(-1)
    topk_score = torch.gather(enc_outputs_class, 1,
                              gather_index.repeat(1, 1, num_cls_outputs))
    topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1,
                                     gather_index.repeat(1, 1, 4))
    topk_coords = topk_coords_unact.sigmoid()
    # The decoder's initial references do not back-propagate into the
    # encoder regression branch.
    topk_coords_unact = topk_coords_unact.detach()

    # (num_queries, dim) -> (bs, num_queries, dim)
    content_query = self.query_embedding.weight[:, None, :]
    query = content_query.repeat(1, batch_size, 1).transpose(0, 1)

    if not self.training:
        reference_points = topk_coords_unact
        dn_mask, dn_meta = None, None
    else:
        dn_label_query, dn_bbox_query, dn_mask, dn_meta = \
            self.dn_query_generator(batch_data_samples)
        # Denoising queries are placed in front of the matching queries.
        query = torch.cat([dn_label_query, query], dim=1)
        reference_points = torch.cat([dn_bbox_query, topk_coords_unact],
                                     dim=1)
    reference_points = reference_points.sigmoid()

    decoder_inputs_dict = dict(
        query=query,
        memory=memory,
        reference_points=reference_points,
        dn_mask=dn_mask)
    # NOTE DINO calculates encoder losses on scores and coordinates
    # of selected top-k encoder queries, while DeformDETR is of all
    # encoder queries.
    if self.training:
        head_inputs_dict = dict(
            enc_outputs_class=topk_score,
            enc_outputs_coord=topk_coords,
            dn_meta=dn_meta)
    else:
        head_inputs_dict = dict()
    return decoder_inputs_dict, head_inputs_dict
def forward_decoder(self,
                    query: Tensor,
                    memory: Tensor,
                    memory_mask: Tensor,
                    reference_points: Tensor,
                    spatial_shapes: Tensor,
                    level_start_index: Tensor,
                    valid_ratios: Tensor,
                    dn_mask: Optional[Tensor] = None) -> Dict:
    """Forward with Transformer decoder.

    The forward procedure of the transformer is defined as:
    'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
    More details can be found at `TransformerDetector.forward_transformer`
    in `mmdet/detector/base_detr.py`.

    Args:
        query (Tensor): The queries of decoder inputs, has shape
            (bs, num_queries_total, dim), where `num_queries_total` is the
            sum of `num_denoising_queries` and `num_matching_queries` when
            `self.training` is `True`, else `num_matching_queries`.
        memory (Tensor): The output embeddings of the Transformer encoder,
            has shape (bs, num_feat_points, dim).
        memory_mask (Tensor): ByteTensor, the padding mask of the memory,
            has shape (bs, num_feat_points).
        reference_points (Tensor): The initial reference, has shape
            (bs, num_queries_total, 4) with the last dimension arranged as
            (cx, cy, w, h).
        spatial_shapes (Tensor): Spatial shapes of features in all levels,
            has shape (num_levels, 2), last dimension represents (h, w).
        level_start_index (Tensor): The start index of each level.
            A tensor has shape (num_levels, ) and can be represented
            as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
        valid_ratios (Tensor): The ratios of the valid width and the valid
            height relative to the width and the height of features in all
            levels, has shape (bs, num_levels, 2).
        dn_mask (Tensor, optional): The attention mask to prevent
            information leakage from different denoising groups and
            matching parts, will be used as `self_attn_mask` of the
            `self.decoder`, has shape (num_queries_total,
            num_queries_total).
            It is `None` when `self.training` is `False`.

    Returns:
        dict: The dictionary of decoder outputs, which includes the
        `hidden_states` of the decoder output and `references` including
        the initial and intermediate reference_points (as a list).
    """
    inter_states, references = self.decoder(
        query=query,
        value=memory,
        key_padding_mask=memory_mask,
        self_attn_mask=dn_mask,
        reference_points=reference_points,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        valid_ratios=valid_ratios,
        reg_branches=self.bbox_head.reg_branches)

    # `query` is (bs, num_queries_total, dim), so the number of queries is
    # dimension 1; `len(query)` would be the batch size. The check is only
    # needed while training, which also keeps inference independent of
    # `dn_query_generator`.
    if self.training and query.size(1) == self.num_queries:
        # NOTE: This is to make sure label_embeding can be involved to
        # produce loss even if there is no denoising query (no ground truth
        # target in this GPU), otherwise, this will raise runtime error in
        # distributed training.
        inter_states[0] += \
            self.dn_query_generator.label_embedding.weight[0, 0] * 0.0

    decoder_outputs_dict = dict(
        hidden_states=inter_states, references=list(references))
    return decoder_outputs_dict
@staticmethod
def get_valid_ratio(mask: Tensor) -> Tensor:
    """Get the valid ratios of the feature map of one level.

    .. code:: text

                  |--> valid_W <--|
               ---+---------------+-----+---
                A |               |     | A
        valid_H | |     valid     |     | |
                V |               |     | H
               ---+---------------+     | |
                  |       padding       | V
                  +---------------------+---
                  |--------> W <--------|

    The valid ratios are defined as:
        r_h = valid_H / H,  r_w = valid_W / W
    The valid extent is counted along the first column (height) and the
    first row (width) of the mask, where False marks a valid position.

    Args:
        mask (Tensor): Binary mask of a feature map, has shape (bs, H, W).

    Returns:
        Tensor: valid ratios [r_w, r_h] of every sample, has shape (bs, 2).
    """
    _, height, width = mask.shape
    is_valid = ~mask
    valid_rows = is_valid[:, :, 0].sum(1)
    valid_cols = is_valid[:, 0, :].sum(1)
    ratio_h = valid_rows.float() / height
    ratio_w = valid_cols.float() / width
    return torch.stack([ratio_w, ratio_h], -1)
def gen_encoder_output_proposals(
        self, memory: Tensor, memory_mask: Tensor,
        spatial_shapes: Tensor) -> Tuple[Tensor, Tensor]:
    """Generate one box proposal per encoder feature point.

    Args:
        memory (Tensor): The output embeddings of the Transformer encoder,
            has shape (bs, num_feat_points, dim).
        memory_mask (Tensor): ByteTensor, the padding mask of the memory,
            has shape (bs, num_feat_points).
        spatial_shapes (Tensor): Spatial shapes of features in all levels,
            has shape (num_levels, 2), last dimension represents (h, w).

    Returns:
        tuple: A tuple of transformed memory and proposals.

        - output_memory (Tensor): The memory with padded/invalid points
          zeroed, passed through ``memory_trans_fc`` and
          ``memory_trans_norm``; has shape (bs, num_feat_points, dim).
        - output_proposals (Tensor): The inverse-sigmoid proposals, has
          shape (bs, num_feat_points, 4) with the last dimension arranged
          as (cx, cy, w, h). Padded or invalid points are set to ``inf``.
    """
    batch_size = memory.size(0)
    level_proposals = []
    offset = 0  # start index in the sequence of the current level
    for lvl, (H, W) in enumerate(spatial_shapes):
        level_mask = memory_mask[:, offset:(offset + H * W)].view(
            batch_size, H, W, 1)
        valid_H = torch.sum(~level_mask[:, :, 0, 0], 1).unsqueeze(-1)
        valid_W = torch.sum(~level_mask[:, 0, :, 0], 1).unsqueeze(-1)

        ys = torch.linspace(
            0, H - 1, H, dtype=torch.float32, device=memory.device)
        xs = torch.linspace(
            0, W - 1, W, dtype=torch.float32, device=memory.device)
        grid_y, grid_x = torch.meshgrid(ys, xs)
        # (H, W, 2) cell indices, ordered as (x, y)
        grid = torch.stack([grid_x, grid_y], -1)

        # Cell centres normalised by the valid (un-padded) extent.
        scale = torch.cat([valid_W, valid_H], 1).view(batch_size, 1, 1, 2)
        grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) +
                0.5) / scale
        # Base size from the config, doubled at each further level.
        wh = torch.ones_like(grid) * self.candidate_bboxes_size * (2.0**lvl)

        level_proposals.append(
            torch.cat((grid, wh), -1).view(batch_size, -1, 4))
        offset += (H * W)
    output_proposals = torch.cat(level_proposals, 1)

    # A proposal is kept only if all four values lie strictly inside the
    # range; `htd_2s` selects the wider range.
    if self.htd_2s:
        lower, upper = 0.0001, 0.9999
    else:
        lower, upper = 0.01, 0.99
    output_proposals_valid = ((output_proposals > lower) &
                              (output_proposals < upper)).all(
                                  -1, keepdim=True)
    padded = memory_mask.unsqueeze(-1)

    # inverse_sigmoid
    output_proposals = torch.log(output_proposals / (1 - output_proposals))
    output_proposals = output_proposals.masked_fill(
        padded, float('inf')).masked_fill(~output_proposals_valid,
                                          float('inf'))

    output_memory = memory.masked_fill(padded, float(0)).masked_fill(
        ~output_proposals_valid, float(0))
    output_memory = self.memory_trans_norm(
        self.memory_trans_fc(output_memory))
    return output_memory, output_proposals
@staticmethod
def rescale_gt_bboxes(batch_data_samples: OptSampleList,
                      scale_gt_bboxes_size: float = 0.25) -> OptSampleList:
    """Shrink every ground-truth box by a fixed margin, in place.

    The first two coordinates of each box are increased and the last two
    decreased by ``scale_gt_bboxes_size`` (assumes boxes are laid out as
    (x1, y1, x2, y2) -- confirm against the data pipeline). The margin is
    capped at half the box extent so that a box smaller than twice the
    margin collapses onto its centre instead of becoming inverted
    (x1 > x2 or y1 > y2).

    The ``bboxes`` tensors are modified in place, so calling this twice on
    the same samples shrinks the boxes twice.

    Args:
        batch_data_samples (list[:obj:`DetDataSample`], optional): Samples
            whose ``gt_instances.bboxes`` are rescaled. ``None`` or an
            empty list is returned unchanged.
        scale_gt_bboxes_size (float): Margin removed from each side of a
            box, in the same units as the box coordinates. A negative value
            enlarges the boxes. Defaults to 0.25.

    Returns:
        list[:obj:`DetDataSample`] or None: The same ``batch_data_samples``
        object that was passed in.
    """
    if not batch_data_samples:
        return batch_data_samples

    for data_sample in batch_data_samples:
        gt_bboxes = data_sample.gt_instances.bboxes
        top_left = gt_bboxes[:, :2]
        bottom_right = gt_bboxes[:, 2:]
        # Never move a corner past the box centre.
        half_extent = ((bottom_right - top_left) / 2).clamp(min=0)
        margin = half_extent.clamp(max=scale_gt_bboxes_size)
        new_top_left = top_left + margin
        new_bottom_right = bottom_right - margin
        if scale_gt_bboxes_size > 0:
            # Guard against rounding pushing the corners past each other.
            new_bottom_right = torch.max(new_bottom_right, new_top_left)
        gt_bboxes[:, :2] = new_top_left
        gt_bboxes[:, 2:] = new_bottom_right
    return batch_data_samples