# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from detrex.layers import (
FFN,
MLP,
BaseTransformerLayer,
MultiheadAttention,
MultiScaleDeformableAttention,
TransformerLayerSequence,
get_sine_pos_embed,
)
from detrex.utils import inverse_sigmoid
from fairscale.nn.checkpoint import checkpoint_wrapper
class DINOTransformerEncoder(TransformerLayerSequence):
def __init__(
self,
embed_dim: int = 256,
num_heads: int = 8,
feedforward_dim: int = 1024,
attn_dropout: float = 0.1,
ffn_dropout: float = 0.1,
num_layers: int = 6,
post_norm: bool = False,
num_feature_levels: int = 4,
use_checkpoint: bool = False,
):
super(DINOTransformerEncoder, self).__init__(
transformer_layers=BaseTransformerLayer(
attn=MultiScaleDeformableAttention(
embed_dim=embed_dim,
num_heads=num_heads,
dropout=attn_dropout,
batch_first=True,
num_levels=num_feature_levels,
),
ffn=FFN(
embed_dim=embed_dim,
feedforward_dim=feedforward_dim,
output_dim=embed_dim,
num_fcs=2,
ffn_drop=ffn_dropout,
),
norm=nn.LayerNorm(embed_dim),
operation_order=("self_attn", "norm", "ffn", "norm"),
),
num_layers=num_layers,
)
self.embed_dim = self.layers[0].embed_dim
self.pre_norm = self.layers[0].pre_norm
if post_norm:
self.post_norm_layer = nn.LayerNorm(self.embed_dim)
else:
self.post_norm_layer = None
        # wrap each encoder layer with fairscale's activation checkpointing,
        # trading recomputation for activation memory; the wrapped layer must
        # be stored back into the ModuleList for the wrapper to take effect
        if use_checkpoint:
            for layer_idx, layer in enumerate(self.layers):
                self.layers[layer_idx] = checkpoint_wrapper(layer)
def forward(
self,
query,
key,
value,
query_pos=None,
key_pos=None,
attn_masks=None,
query_key_padding_mask=None,
key_padding_mask=None,
**kwargs,
):
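        # the encoder uses self-attention only; in that mode the transformer
        # layer falls back to the query for key/value, so None may be passed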
for layer in self.layers:
query = layer(
query,
key,
value,
query_pos=query_pos,
attn_masks=attn_masks,
query_key_padding_mask=query_key_padding_mask,
key_padding_mask=key_padding_mask,
**kwargs,
)
if self.post_norm_layer is not None:
query = self.post_norm_layer(query)
return query
class DINOTransformerDecoder(TransformerLayerSequence):
def __init__(
self,
embed_dim: int = 256,
num_heads: int = 8,
feedforward_dim: int = 1024,
attn_dropout: float = 0.1,
ffn_dropout: float = 0.1,
num_layers: int = 6,
return_intermediate: bool = True,
num_feature_levels: int = 4,
look_forward_twice: bool = True,
use_checkpoint: bool = True,
):
super(DINOTransformerDecoder, self).__init__(
transformer_layers=BaseTransformerLayer(
attn=[
MultiheadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
attn_drop=attn_dropout,
batch_first=True,
),
MultiScaleDeformableAttention(
embed_dim=embed_dim,
num_heads=num_heads,
dropout=attn_dropout,
batch_first=True,
num_levels=num_feature_levels,
),
],
ffn=FFN(
embed_dim=embed_dim,
feedforward_dim=feedforward_dim,
output_dim=embed_dim,
ffn_drop=ffn_dropout,
),
norm=nn.LayerNorm(embed_dim),
operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
),
num_layers=num_layers,
)
self.return_intermediate = return_intermediate
self.ref_point_head = MLP(2 * embed_dim, embed_dim, embed_dim, 2)
self.bbox_embed = None
self.class_embed = None
self.look_forward_twice = look_forward_twice
self.norm = nn.LayerNorm(embed_dim)
        # wrap each decoder layer with fairscale's activation checkpointing;
        # as in the encoder, replace the layer inside the ModuleList
        if use_checkpoint:
            for layer_idx, layer in enumerate(self.layers):
                self.layers[layer_idx] = checkpoint_wrapper(layer)
def forward(
self,
query,
key,
value,
query_pos=None,
key_pos=None,
attn_masks=None,
query_key_padding_mask=None,
key_padding_mask=None,
reference_points=None, # num_queries, 4. normalized.
valid_ratios=None,
**kwargs,
):
output = query
bs, num_queries, _ = output.size()
if reference_points.dim() == 2:
reference_points = reference_points.unsqueeze(0).repeat(bs, 1, 1) # bs, num_queries, 4
intermediate = []
intermediate_reference_points = []
for layer_idx, layer in enumerate(self.layers):
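            # scale the normalized reference points by the per-level valid
            # ratios so deformable attention samples within the un-padded area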
if reference_points.shape[-1] == 4:
reference_points_input = (
reference_points[:, :, None]
* torch.cat([valid_ratios, valid_ratios], -1)[:, None]
)
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * valid_ratios[:, None]
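            # regenerate the query positional embedding from the current
            # (iteratively refined) reference points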
query_sine_embed = get_sine_pos_embed(reference_points_input[:, :, 0, :])
query_pos = self.ref_point_head(query_sine_embed)
output = layer(
output,
key,
value,
query_pos=query_pos,
key_pos=key_pos,
query_sine_embed=query_sine_embed,
attn_masks=attn_masks,
query_key_padding_mask=query_key_padding_mask,
key_padding_mask=key_padding_mask,
reference_points=reference_points_input,
**kwargs,
)
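            # iterative bounding box refinement: the bbox head predicts a delta
            # in logit space on top of the inverse-sigmoid reference points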
if self.bbox_embed is not None:
tmp = self.bbox_embed[layer_idx](output)
if reference_points.shape[-1] == 4:
new_reference_points = tmp + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
else:
assert reference_points.shape[-1] == 2
new_reference_points = tmp
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
reference_points = new_reference_points.detach()
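            # "look forward twice": record the refined boxes from before the
            # detach for the loss, rather than the detached points fed onward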
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.look_forward_twice:
intermediate_reference_points.append(new_reference_points)
else:
intermediate_reference_points.append(reference_points)
if self.return_intermediate:
return torch.stack(intermediate), torch.stack(intermediate_reference_points)
return output, reference_points
class DINOTransformer(nn.Module):
"""Transformer module for DINO
Args:
encoder (nn.Module): encoder module.
decoder (nn.Module): decoder module.
as_two_stage (bool): whether to use two-stage transformer. Default False.
num_feature_levels (int): number of feature levels. Default 4.
two_stage_num_proposals (int): number of proposals in two-stage transformer. Default 900.
"""
def __init__(
self,
encoder=None,
decoder=None,
num_feature_levels=4,
two_stage_num_proposals=900,
learnt_init_query=True,
):
super(DINOTransformer, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.num_feature_levels = num_feature_levels
self.two_stage_num_proposals = two_stage_num_proposals
self.embed_dim = self.encoder.embed_dim
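        # learnt per-level embedding, added to the positional encodings of the
        # corresponding feature level in forward (initialized in init_weights)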
self.level_embeds = nn.Parameter(torch.Tensor(self.num_feature_levels, self.embed_dim))
self.learnt_init_query = learnt_init_query
if self.learnt_init_query:
self.tgt_embed = nn.Embedding(self.two_stage_num_proposals, self.embed_dim)
self.enc_output = nn.Linear(self.embed_dim, self.embed_dim)
self.enc_output_norm = nn.LayerNorm(self.embed_dim)
self.init_weights()
def init_weights(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for m in self.modules():
if isinstance(m, MultiScaleDeformableAttention):
m.init_weights()
nn.init.normal_(self.level_embeds)
def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
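        """Derive initial region proposals and projected memory features from
        the flattened encoder output.
        Args:
            memory (Tensor): encoder output, shape (bs, num_tokens, embed_dim).
            memory_padding_mask (Tensor): padding mask, shape (bs, num_tokens).
            spatial_shapes (Tensor): per-level feature map sizes, shape (num_levels, 2).
        Returns:
            tuple[Tensor, Tensor]: projected memory of shape
                (bs, num_tokens, embed_dim) and unsigmoided proposals of
                shape (bs, num_tokens, 4).
        """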
N, S, C = memory.shape
proposals = []
_cur = 0
for lvl, (H, W) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur : (_cur + H * W)].view(N, H, W, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H - 1, H, dtype=torch.float32, device=memory.device),
torch.linspace(0, W - 1, W, dtype=torch.float32, device=memory.device),
)
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
proposals.append(proposal)
_cur += H * W
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(
-1, keepdim=True
)
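        # map proposals into unsigmoided (logit) space to match the bbox head,
        # and invalidate padded or out-of-range positions with inf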
output_proposals = torch.log(output_proposals / (1 - output_proposals))
output_proposals = output_proposals.masked_fill(
memory_padding_mask.unsqueeze(-1), float("inf")
)
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals
@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
"""Get the reference points used in decoder.
Args:
spatial_shapes (Tensor): The shape of all
feature maps, has shape (num_level, 2).
valid_ratios (Tensor): The ratios of valid
points on the feature map, has shape
(bs, num_levels, 2)
            device (torch.device): The device where
                reference_points should be.
        Returns:
            Tensor: reference points used in decoder, has
                shape (bs, num_keys, num_levels, 2).
        """
reference_points_list = []
for lvl, (H, W) in enumerate(spatial_shapes):
# TODO check this 0.5
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H - 0.5, H, dtype=torch.float32, device=device),
torch.linspace(0.5, W - 0.5, W, dtype=torch.float32, device=device),
)
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points
def get_valid_ratio(self, mask):
"""Get the valid ratios of feature maps of all levels."""
_, H, W = mask.shape
valid_H = torch.sum(~mask[:, :, 0], 1)
valid_W = torch.sum(~mask[:, 0, :], 1)
valid_ratio_h = valid_H.float() / H
valid_ratio_w = valid_W.float() / W
valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
return valid_ratio
def forward(
self,
multi_level_feats,
multi_level_masks,
multi_level_pos_embeds,
query_embed,
attn_masks,
**kwargs,
):
feat_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
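        # flatten each level to (bs, h*w, c) and concatenate all levels into a
        # single token sequence; the level embedding keeps levels distinguishable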
for lvl, (feat, mask, pos_embed) in enumerate(
zip(multi_level_feats, multi_level_masks, multi_level_pos_embeds)
):
bs, c, h, w = feat.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
feat = feat.flatten(2).transpose(1, 2) # bs, hw, c
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
feat_flatten.append(feat)
mask_flatten.append(mask)
feat_flatten = torch.cat(feat_flatten, 1)
mask_flatten = torch.cat(mask_flatten, 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
spatial_shapes = torch.as_tensor(
spatial_shapes, dtype=torch.long, device=feat_flatten.device
)
level_start_index = torch.cat(
(spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in multi_level_masks], 1)
reference_points = self.get_reference_points(
spatial_shapes, valid_ratios, device=feat.device
)
memory = self.encoder(
query=feat_flatten,
key=None,
value=None,
query_pos=lvl_pos_embed_flatten,
query_key_padding_mask=mask_flatten,
spatial_shapes=spatial_shapes,
reference_points=reference_points, # bs, num_token, num_level, 2
level_start_index=level_start_index,
valid_ratios=valid_ratios,
**kwargs,
)
output_memory, output_proposals = self.gen_encoder_output_proposals(
memory, mask_flatten, spatial_shapes
)
# output_memory: bs, num_tokens, c
# output_proposals: bs, num_tokens, 4. unsigmoided.
enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
enc_outputs_coord_unact = (
self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals
) # unsigmoided.
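        # two-stage selection: rank all encoder tokens by their best class
        # score (using the extra head at index num_layers) and keep the top-k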
topk = self.two_stage_num_proposals
topk_proposals = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]
# extract region proposal boxes
topk_coords_unact = torch.gather(
enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
) # unsigmoided.
reference_points = topk_coords_unact.detach().sigmoid()
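        # query_embed carries the denoising (DN) queries: index 0 holds the
        # content part, index 1 the unsigmoided boxes; prepend them if present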
if query_embed[1] is not None:
reference_points = torch.cat([query_embed[1].sigmoid(), reference_points], 1)
init_reference_out = reference_points
# extract region features
target_unact = torch.gather(
output_memory, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1])
)
if self.learnt_init_query:
target = self.tgt_embed.weight[None].repeat(bs, 1, 1)
else:
target = target_unact.detach()
if query_embed[0] is not None:
target = torch.cat([query_embed[0], target], 1)
# decoder
inter_states, inter_references = self.decoder(
query=target, # bs, num_queries, embed_dims
key=memory, # bs, num_tokens, embed_dims
value=memory, # bs, num_tokens, embed_dims
query_pos=None,
key_padding_mask=mask_flatten, # bs, num_tokens
reference_points=reference_points, # num_queries, 4
spatial_shapes=spatial_shapes, # nlvl, 2
level_start_index=level_start_index, # nlvl
valid_ratios=valid_ratios, # bs, nlvl, 2
attn_masks=attn_masks,
**kwargs,
)
inter_references_out = inter_references
return (
inter_states,
init_reference_out,
inter_references_out,
target_unact,
topk_coords_unact.sigmoid(),
)
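# Minimal construction sketch (an illustration with the default sizes above;
# assumes detrex and fairscale are importable). It only checks that the
# modules assemble: a full forward pass additionally requires the detector to
# attach `decoder.bbox_embed` and `decoder.class_embed` (num_layers + 1 heads
# each), which the two-stage scoring in `DINOTransformer.forward` indexes at
# position `num_layers`.
if __name__ == "__main__":
    transformer = DINOTransformer(
        encoder=DINOTransformerEncoder(embed_dim=256, num_layers=6),
        decoder=DINOTransformerDecoder(embed_dim=256, num_layers=6),
        num_feature_levels=4,
        two_stage_num_proposals=900,
    )
    print(f"embed_dim={transformer.embed_dim}")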