# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

from detrex.layers import (
    FFN,
    MLP,
    BaseTransformerLayer,
    MultiheadAttention,
    MultiScaleDeformableAttention,
    TransformerLayerSequence,
    get_sine_pos_embed,
)
from detrex.utils import inverse_sigmoid
from fairscale.nn.checkpoint import checkpoint_wrapper


class DINOTransformerEncoder(TransformerLayerSequence):
    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        feedforward_dim: int = 1024,
        attn_dropout: float = 0.1,
        ffn_dropout: float = 0.1,
        num_layers: int = 6,
        post_norm: bool = False,
        num_feature_levels: int = 4,
        use_checkpoint: bool = False,
    ):
        super(DINOTransformerEncoder, self).__init__(
            transformer_layers=BaseTransformerLayer(
                attn=MultiScaleDeformableAttention(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    dropout=attn_dropout,
                    batch_first=True,
                    num_levels=num_feature_levels,
                ),
                ffn=FFN(
                    embed_dim=embed_dim,
                    feedforward_dim=feedforward_dim,
                    output_dim=embed_dim,
                    num_fcs=2,
                    ffn_drop=ffn_dropout,
                ),
                norm=nn.LayerNorm(embed_dim),
                operation_order=("self_attn", "norm", "ffn", "norm"),
            ),
            num_layers=num_layers,
        )
        self.embed_dim = self.layers[0].embed_dim
        self.pre_norm = self.layers[0].pre_norm

        if post_norm:
            self.post_norm_layer = nn.LayerNorm(self.embed_dim)
        else:
            self.post_norm_layer = None
        # wrap each encoder layer with activation checkpointing to trade
        # compute for memory; wrap in place so the wrapped modules are the
        # ones actually used in forward
        if use_checkpoint:
            for i in range(len(self.layers)):
                self.layers[i] = checkpoint_wrapper(self.layers[i])

    def forward(
        self,
        query,
        key,
        value,
        query_pos=None,
        key_pos=None,
        attn_masks=None,
        query_key_padding_mask=None,
        key_padding_mask=None,
        **kwargs,
    ):
        for layer in self.layers:
            query = layer(
                query,
                key,
                value,
                query_pos=query_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask,
                **kwargs,
            )

        if self.post_norm_layer is not None:
            query = self.post_norm_layer(query)
        return query


class DINOTransformerDecoder(TransformerLayerSequence):
    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        feedforward_dim: int = 1024,
        attn_dropout: float = 0.1,
        ffn_dropout: float = 0.1,
        num_layers: int = 6,
        return_intermediate: bool = True,
        num_feature_levels: int = 4,
        look_forward_twice: bool = True,
        use_checkpoint: bool = True,
    ):
        super(DINOTransformerDecoder, self).__init__(
            transformer_layers=BaseTransformerLayer(
                attn=[
                    MultiheadAttention(
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        attn_drop=attn_dropout,
                        batch_first=True,
                    ),
                    MultiScaleDeformableAttention(
                        embed_dim=embed_dim,
                        num_heads=num_heads,
                        dropout=attn_dropout,
                        batch_first=True,
                        num_levels=num_feature_levels,
                    ),
                ],
                ffn=FFN(
                    embed_dim=embed_dim,
                    feedforward_dim=feedforward_dim,
                    output_dim=embed_dim,
                    ffn_drop=ffn_dropout,
                ),
                norm=nn.LayerNorm(embed_dim),
                operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
            ),
            num_layers=num_layers,
        )
        self.return_intermediate = return_intermediate
        self.ref_point_head = MLP(2 * embed_dim, embed_dim, embed_dim, 2)
        self.bbox_embed = None
        self.class_embed = None
        self.look_forward_twice = look_forward_twice
        self.norm = nn.LayerNorm(embed_dim)
        # wrap each decoder layer with activation checkpointing to trade
        # compute for memory; wrap in place so the wrapped modules are the
        # ones actually used in forward
        if use_checkpoint:
            for i in range(len(self.layers)):
                self.layers[i] = checkpoint_wrapper(self.layers[i])

    def forward(
        self,
        query,
        key,
        value,
        query_pos=None,
        key_pos=None,
        attn_masks=None,
        query_key_padding_mask=None,
        key_padding_mask=None,
        reference_points=None,  # num_queries, 4. normalized.
        valid_ratios=None,
        **kwargs,
    ):
        output = query
        bs, num_queries, _ = output.size()
        if reference_points.dim() == 2:
            reference_points = reference_points.unsqueeze(0).repeat(bs, 1, 1)  # bs, num_queries, 4

        intermediate = []
        intermediate_reference_points = []
        for layer_idx, layer in enumerate(self.layers):
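            # rescale the normalized reference points by each level's valid
            # ratio so deformable attention samples inside the unpadded region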
            if reference_points.shape[-1] == 4:
                reference_points_input = (
                    reference_points[:, :, None]
                    * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
                )
            else:
                assert reference_points.shape[-1] == 2
                reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
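            # positional queries are regenerated every layer from the current
            # reference points: a sine embedding (taken at the first feature
            # level) mapped through the ref_point_head MLP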
            query_sine_embed = get_sine_pos_embed(reference_points_input[:, :, 0, :])
            query_pos = self.ref_point_head(query_sine_embed)

            output = layer(
                output,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                query_sine_embed=query_sine_embed,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask,
                reference_points=reference_points_input,
                **kwargs,
            )
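            # iterative box refinement: predict a delta in inverse-sigmoid
            # space on top of the current reference points, then detach so a
            # layer's refinement does not backpropagate through the whole
            # chain of reference points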
            if self.bbox_embed is not None:
                tmp = self.bbox_embed[layer_idx](output)
                if reference_points.shape[-1] == 4:
                    new_reference_points = tmp + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                else:
                    assert reference_points.shape[-1] == 2
                    new_reference_points = tmp
                    new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
                    new_reference_points = new_reference_points.sigmoid()
                reference_points = new_reference_points.detach()

            if self.return_intermediate:
                intermediate.append(self.norm(output))
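                # "look forward twice": record the undetached refined boxes so
                # a layer's refinement is also supervised by the next layer's
                # box loss; otherwise record the detached reference points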
                if self.look_forward_twice:
                    intermediate_reference_points.append(new_reference_points)
                else:
                    intermediate_reference_points.append(reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)

        return output, reference_points


class DINOTransformer(nn.Module):
    """Transformer module for DINO.

    Args:
        encoder (nn.Module): encoder module.
        decoder (nn.Module): decoder module.
        num_feature_levels (int): number of feature levels. Default 4.
        two_stage_num_proposals (int): number of proposals in the two-stage
            transformer. Default 900.
        learnt_init_query (bool): whether to learn the content part of the
            initial decoder queries instead of taking it from the selected
            encoder features. Default True.
    """

    def __init__(
        self,
        encoder=None,
        decoder=None,
        num_feature_levels=4,
        two_stage_num_proposals=900,
        learnt_init_query=True,
    ):
        super(DINOTransformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.num_feature_levels = num_feature_levels
        self.two_stage_num_proposals = two_stage_num_proposals
        self.embed_dim = self.encoder.embed_dim

        self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dim))
        self.learnt_init_query = learnt_init_query
        if self.learnt_init_query:
            self.tgt_embed = nn.Embedding(self.two_stage_num_proposals, self.embed_dim)
        self.enc_output = nn.Linear(self.embed_dim, self.embed_dim)
        self.enc_output_norm = nn.LayerNorm(self.embed_dim)

        self.init_weights()

    def init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MultiScaleDeformableAttention):
                m.init_weights()
        nn.init.normal_(self.level_embeds)

    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
        N, S, C = memory.shape
        proposals = []
        _cur = 0
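        # for every encoder token, build an initial box proposal centered on
        # its normalized grid position, with a level-dependent size 0.05 * 2^lvl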
        for lvl, (H, W) in enumerate(spatial_shapes):
            mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1)
            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

            grid_y, grid_x = torch.meshgrid(
                torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device),
                torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device),
            )
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

            scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
            grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
            proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
            proposals.append(proposal)
            _cur += H * W
        output_proposals = torch.cat(proposals, 1)
        output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
            -1, keepdim=True
        )
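        # inverse-sigmoid the proposals; padded tokens and proposals too close
        # to the border are pushed to +inf so they are discarded downstream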
        output_proposals = torch.log(output_proposals / (1 - output_proposals))
        output_proposals = output_proposals.masked_fill(
            memory_padding_mask.unsqueeze(-1), float("inf")
        )
        output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))

        output_memory = memory
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
        output_memory = self.enc_output_norm(self.enc_output(output_memory))
        return output_memory, output_proposals

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device):
        """Get the reference points used in decoder.

        Args:
            spatial_shapes (Tensor): The shape of all feature maps,
                has shape (num_levels, 2).
            valid_ratios (Tensor): The ratios of valid points on the
                feature map, has shape (bs, num_levels, 2).
            device (obj:`device`): The device where reference_points
                should be.

        Returns:
            Tensor: reference points used in decoder, has
                shape (bs, num_keys, num_levels, 2).
        """
        reference_points_list = []
        for lvl, (H, W) in enumerate(spatial_shapes):
            # TODO check this 0.5
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
                torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    def get_valid_ratio(self, mask):
        """Get the valid ratios of feature maps of all levels."""
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(
        self,
        multi_level_feats,
        multi_level_masks,
        multi_level_pos_embeds,
        query_embed,
        attn_masks,
        **kwargs,
    ):
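        # flatten the multi-scale feature maps into one token sequence,
        # recording each level's spatial shape so attention can index back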
        feat_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (feat, mask, pos_embed) in enumerate(
            zip(multi_level_feats, multi_level_masks, multi_level_pos_embeds)
        ):
            bs, c, h, w = feat.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)

            feat = feat.flatten(2).transpose(1, 2)  # bs, hw, c
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)  # bs, hw, c
            lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            feat_flatten.append(feat)
            mask_flatten.append(mask)
        feat_flatten = torch.cat(feat_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(
            spatial_shapes, dtype=torch.long, device=feat_flatten.device
        )
        level_start_index = torch.cat(
            (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
        )
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in multi_level_masks], 1)

        reference_points = self.get_reference_points(
            spatial_shapes, valid_ratios, device=feat.device
        )

        memory = self.encoder(
            query=feat_flatten,
            key=None,
            value=None,
            query_pos=lvl_pos_embed_flatten,
            query_key_padding_mask=mask_flatten,
            spatial_shapes=spatial_shapes,
            reference_points=reference_points,  # bs, num_token, num_level, 2
            level_start_index=level_start_index,
            valid_ratios=valid_ratios,
            **kwargs,
        )

        output_memory, output_proposals = self.gen_encoder_output_proposals(
            memory, mask_flatten, spatial_shapes
        )
        # output_memory: bs, num_tokens, c
        # output_proposals: bs, num_tokens, 4. unsigmoided.
        enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
        enc_outputs_coord_unact = (
            self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals
        )  # unsigmoided.
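        # two-stage selection: keep the top-k encoder tokens, ranked by their
        # maximum classification logit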
        topk = self.two_stage_num_proposals
        topk_proposals = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]

        # extract region proposal boxes
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
        )  # unsigmoided.
        reference_points = topk_coords_unact.detach().sigmoid()
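        # prepend the denoising (CDN) queries, when provided; query_embed is
        # a (content queries, box queries) pair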
        if query_embed[1] is not None:
            reference_points = torch.cat([query_embed[1].sigmoid(), reference_points], 1)
        init_reference_out = reference_points

        # extract region features
        target_unact = torch.gather(
            output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])
        )
        if self.learnt_init_query:
            target = self.tgt_embed.weight[None].repeat(bs, 1, 1)
        else:
            target = target_unact.detach()
        if query_embed[0] is not None:
            target = torch.cat([query_embed[0], target], 1)

        # decoder
        inter_states, inter_references = self.decoder(
            query=target,  # bs, num_queries, embed_dims
            key=memory,  # bs, num_tokens, embed_dims
            value=memory,  # bs, num_tokens, embed_dims
            query_pos=None,
            key_padding_mask=mask_flatten,  # bs, num_tokens
            reference_points=reference_points,  # num_queries, 4
            spatial_shapes=spatial_shapes,  # nlvl, 2
            level_start_index=level_start_index,  # nlvl
            valid_ratios=valid_ratios,  # bs, nlvl, 2
            attn_masks=attn_masks,
            **kwargs,
        )
        inter_references_out = inter_references
        return (
            inter_states,
            init_reference_out,
            inter_references_out,
            target_unact,
            topk_coords_unact.sigmoid(),
        )
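

if __name__ == "__main__":
    # Minimal construction sketch (illustrative, not part of the original
    # file): shows how the three modules are typically wired together. The
    # hyper-parameters below are assumptions for demonstration only; a full
    # forward pass additionally needs the `decoder.class_embed` and
    # `decoder.bbox_embed` prediction heads, which the surrounding DINO
    # detector attaches after construction.
    transformer = DINOTransformer(
        encoder=DINOTransformerEncoder(embed_dim=256, num_layers=6),
        decoder=DINOTransformerDecoder(embed_dim=256, num_layers=6, use_checkpoint=False),
        num_feature_levels=4,
        two_stage_num_proposals=900,
    )
    print(f"parameters: {sum(p.numel() for p in transformer.parameters()):,}")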