# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/transformer.py
# ------------------------------------------------------------------------------------------------
import warnings
from typing import Optional
import torch
import torch.nn as nn
class MultiheadAttention(nn.Module):
"""A wrapper for ``torch.nn.MultiheadAttention``
    Implements MultiheadAttention with an identity (residual) connection,
    where position embeddings can also be passed in as inputs.
Args:
embed_dim (int): The embedding dimension for attention.
num_heads (int): The number of attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `MultiheadAttention`.
Default: 0.0.
        batch_first (bool): If `True`, the input and output tensors are provided
            as `(bs, n, embed_dim)`; otherwise as `(n, bs, embed_dim)`.
            Default: False.
"""
def __init__(
self,
embed_dim: int,
num_heads: int,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
batch_first: bool = False,
**kwargs,
):
super(MultiheadAttention, self).__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.batch_first = batch_first
self.attn = nn.MultiheadAttention(
embed_dim=embed_dim,
num_heads=num_heads,
dropout=attn_drop,
batch_first=batch_first,
**kwargs,
)
self.proj_drop = nn.Dropout(proj_drop)
def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
identity: Optional[torch.Tensor] = None,
query_pos: Optional[torch.Tensor] = None,
key_pos: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""Forward function for `MultiheadAttention`
**kwargs allow passing a more general data flow when combining
with other operations in `transformerlayer`.
Args:
query (torch.Tensor): Query embeddings with shape
`(num_query, bs, embed_dim)` if self.batch_first is False,
else `(bs, num_query, embed_dim)`
key (torch.Tensor): Key embeddings with shape
`(num_key, bs, embed_dim)` if self.batch_first is False,
else `(bs, num_key, embed_dim)`
            value (torch.Tensor): Value embeddings with the same shape as `key`.
                Same as in `torch.nn.MultiheadAttention.forward`. Default: None.
                If None, the `key` will be used.
            identity (torch.Tensor): The tensor, with the same shape as `query`,
                which will be used for the identity addition. Default: None.
                If None, `query` will be used.
query_pos (torch.Tensor): The position embedding for query, with the
same shape as `query`. Default: None.
key_pos (torch.Tensor): The position embedding for key. Default: None.
If None, and `query_pos` has the same shape as `key`, then `query_pos`
will be used for `key_pos`.
attn_mask (torch.Tensor): ByteTensor mask with shape `(num_query, num_key)`.
Same as `torch.nn.MultiheadAttention.forward`. Default: None.
            key_padding_mask (torch.Tensor): ByteTensor with shape `(bs, num_key)` that
                indicates which elements within `key` should be ignored in attention.
Default: None.
"""
if key is None:
key = query
if value is None:
value = key
if identity is None:
identity = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
else:
warnings.warn(
f"position encoding of key is" f"missing in {self.__class__.__name__}."
)
if query_pos is not None:
query = query + query_pos
if key_pos is not None:
key = key + key_pos
out = self.attn(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
)[0]
return identity + self.proj_drop(out)
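

# A minimal usage sketch (added for illustration; the helper name and the concrete
# shapes below are assumptions, not part of the original API). It shows the expected
# sequence-first layout when ``batch_first=False`` and the built-in residual addition.
def _example_multihead_attention():
    attn = MultiheadAttention(embed_dim=256, num_heads=8)
    query = torch.rand(100, 2, 256)    # (num_query, bs, embed_dim)
    memory = torch.rand(1350, 2, 256)  # (num_key, bs, embed_dim), e.g. flattened features
    query_pos = torch.rand(100, 2, 256)
    key_pos = torch.rand(1350, 2, 256)
    # Position embeddings are added to query/key inside forward; the identity
    # (residual) branch defaults to `query`, so the output keeps the query shape.
    out = attn(query, key=memory, query_pos=query_pos, key_pos=key_pos)
    assert out.shape == query.shape
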
class ConditionalSelfAttention(nn.Module):
"""Conditional Self-Attention Module used in Conditional-DETR
`Conditional DETR for Fast Training Convergence.
<https://arxiv.org/pdf/2108.06152.pdf>`_
Args:
embed_dim (int): The embedding dimension for attention.
num_heads (int): The number of attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `MultiheadAttention`.
Default: 0.0.
        batch_first (bool): If `True`, the input and output tensors are provided
            as `(bs, n, embed_dim)`; otherwise as `(n, bs, embed_dim)`.
            Default: False.
"""
def __init__(
self,
embed_dim,
num_heads,
attn_drop=0.0,
proj_drop=0.0,
batch_first=False,
**kwargs,
):
super(ConditionalSelfAttention, self).__init__()
self.query_content_proj = nn.Linear(embed_dim, embed_dim)
self.query_pos_proj = nn.Linear(embed_dim, embed_dim)
self.key_content_proj = nn.Linear(embed_dim, embed_dim)
self.key_pos_proj = nn.Linear(embed_dim, embed_dim)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
self.num_heads = num_heads
self.embed_dim = embed_dim
head_dim = embed_dim // num_heads
self.scale = head_dim**-0.5
self.batch_first = batch_first
def forward(
self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_pos=None,
attn_mask=None,
key_padding_mask=None,
**kwargs,
):
"""Forward function for `ConditionalSelfAttention`
**kwargs allow passing a more general data flow when combining
with other operations in `transformerlayer`.
Args:
query (torch.Tensor): Query embeddings with shape
`(num_query, bs, embed_dim)` if self.batch_first is False,
else `(bs, num_query, embed_dim)`
key (torch.Tensor): Key embeddings with shape
`(num_key, bs, embed_dim)` if self.batch_first is False,
else `(bs, num_key, embed_dim)`
            value (torch.Tensor): Value embeddings with the same shape as `key`.
                Same as in `torch.nn.MultiheadAttention.forward`. Default: None.
                If None, the `key` will be used.
            identity (torch.Tensor): The tensor, with the same shape as `query`,
                which will be used for the identity addition. Default: None.
                If None, `query` will be used.
query_pos (torch.Tensor): The position embedding for query, with the
same shape as `query`. Default: None.
key_pos (torch.Tensor): The position embedding for key. Default: None.
If None, and `query_pos` has the same shape as `key`, then `query_pos`
will be used for `key_pos`.
attn_mask (torch.Tensor): ByteTensor mask with shape `(num_query, num_key)`.
Same as `torch.nn.MultiheadAttention.forward`. Default: None.
            key_padding_mask (torch.Tensor): ByteTensor with shape `(bs, num_key)` that
                indicates which elements within `key` should be ignored in attention.
Default: None.
"""
if key is None:
key = query
if value is None:
value = key
if identity is None:
identity = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
else:
warnings.warn(
f"position encoding of key is" f"missing in {self.__class__.__name__}."
)
assert (
query_pos is not None and key_pos is not None
), "query_pos and key_pos must be passed into ConditionalAttention Module"
        # transpose (b n c) to (n b c) for attention calculation; `identity` is kept
        # in the input layout because `out` is also returned in the input layout below
        if self.batch_first:
            query = query.transpose(0, 1)  # (n b c)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)
            query_pos = query_pos.transpose(0, 1)
            key_pos = key_pos.transpose(0, 1)
# query/key/value content and position embedding projection
query_content = self.query_content_proj(query)
query_pos = self.query_pos_proj(query_pos)
key_content = self.key_content_proj(key)
key_pos = self.key_pos_proj(key_pos)
value = self.value_proj(value)
# attention calculation
N, B, C = query_content.shape
q = query_content + query_pos
k = key_content + key_pos
v = value
q = q.reshape(N, B, self.num_heads, C // self.num_heads).permute(
1, 2, 0, 3
) # (B, num_heads, N, head_dim)
k = k.reshape(N, B, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)
v = v.reshape(N, B, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
# add attention mask
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn.masked_fill_(attn_mask, float("-inf"))
else:
attn += attn_mask
if key_padding_mask is not None:
attn = attn.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
out = self.out_proj(out)
if not self.batch_first:
out = out.transpose(0, 1)
return identity + self.proj_drop(out)
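

# A minimal self-attention usage sketch (illustrative helper name and shapes, assuming
# ``batch_first=False``). For self-attention, `key`, `value`, and `key_pos` default to
# `query`/`query_pos` inside forward.
def _example_conditional_self_attention():
    self_attn = ConditionalSelfAttention(embed_dim=256, num_heads=8)
    query = torch.rand(300, 2, 256)      # (num_query, bs, embed_dim), e.g. 300 object queries
    query_pos = torch.rand(300, 2, 256)  # learned positional embedding of the queries
    out = self_attn(query, query_pos=query_pos)
    assert out.shape == query.shape
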
class ConditionalCrossAttention(nn.Module):
"""Conditional Cross-Attention Module used in Conditional-DETR
`Conditional DETR for Fast Training Convergence.
<https://arxiv.org/pdf/2108.06152.pdf>`_
Args:
embed_dim (int): The embedding dimension for attention.
num_heads (int): The number of attention heads.
attn_drop (float): A Dropout layer on attn_output_weights.
Default: 0.0.
proj_drop (float): A Dropout layer after `MultiheadAttention`.
Default: 0.0.
        batch_first (bool): If `True`, the input and output tensors are provided
            as `(bs, n, embed_dim)`; otherwise as `(n, bs, embed_dim)`.
            Default: False.
"""
def __init__(
self,
embed_dim,
num_heads,
attn_drop=0.0,
proj_drop=0.0,
batch_first=False,
**kwargs,
):
super(ConditionalCrossAttention, self).__init__()
self.query_content_proj = nn.Linear(embed_dim, embed_dim)
self.query_pos_proj = nn.Linear(embed_dim, embed_dim)
self.query_pos_sine_proj = nn.Linear(embed_dim, embed_dim)
self.key_content_proj = nn.Linear(embed_dim, embed_dim)
self.key_pos_proj = nn.Linear(embed_dim, embed_dim)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.attn_drop = nn.Dropout(attn_drop)
self.proj_drop = nn.Dropout(proj_drop)
self.num_heads = num_heads
self.batch_first = batch_first
def forward(
self,
query,
key=None,
value=None,
identity=None,
query_pos=None,
key_pos=None,
query_sine_embed=None,
is_first_layer=False,
attn_mask=None,
key_padding_mask=None,
**kwargs,
):
"""Forward function for `ConditionalCrossAttention`
**kwargs allow passing a more general data flow when combining
with other operations in `transformerlayer`.
Args:
query (torch.Tensor): Query embeddings with shape
`(num_query, bs, embed_dim)` if self.batch_first is False,
else `(bs, num_query, embed_dim)`
key (torch.Tensor): Key embeddings with shape
`(num_key, bs, embed_dim)` if self.batch_first is False,
else `(bs, num_key, embed_dim)`
            value (torch.Tensor): Value embeddings with the same shape as `key`.
                Same as in `torch.nn.MultiheadAttention.forward`. Default: None.
                If None, the `key` will be used.
            identity (torch.Tensor): The tensor, with the same shape as `query`,
                which will be used for the identity addition. Default: None.
                If None, `query` will be used.
query_pos (torch.Tensor): The position embedding for query, with the
same shape as `query`. Default: None.
key_pos (torch.Tensor): The position embedding for key. Default: None.
If None, and `query_pos` has the same shape as `key`, then `query_pos`
will be used for `key_pos`.
            query_sine_embed (torch.Tensor): Sine positional embedding of the query
                reference points; it is projected and concatenated with the query
                content to form the spatial part of the attention. Default: None.
            is_first_layer (bool): Whether this is the first decoder layer. If `True`,
                the projected `query_pos`/`key_pos` are also added to the query/key
                content. Default: False.
attn_mask (torch.Tensor): ByteTensor mask with shape `(num_query, num_key)`.
Same as `torch.nn.MultiheadAttention.forward`. Default: None.
            key_padding_mask (torch.Tensor): ByteTensor with shape `(bs, num_key)` that
                indicates which elements within `key` should be ignored in attention.
Default: None.
"""
if key is None:
key = query
if value is None:
value = key
if identity is None:
identity = query
if key_pos is None:
if query_pos is not None:
# use query_pos if key_pos is not available
if query_pos.shape == key.shape:
key_pos = query_pos
else:
warnings.warn(
f"position encoding of key is" f"missing in {self.__class__.__name__}."
)
assert (
query_pos is not None and key_pos is not None
), "query_pos and key_pos must be passed into ConditionalAttention Module"
        # transpose (b n c) to (n b c) for attention calculation; `identity` is kept
        # in the input layout because `out` is also returned in the input layout below
        if self.batch_first:
            query = query.transpose(0, 1)  # (n b c)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)
            query_pos = query_pos.transpose(0, 1)
            key_pos = key_pos.transpose(0, 1)
            # `query_sine_embed` follows the same layout as `query` and is used below,
            # so it is transposed here as well
            query_sine_embed = query_sine_embed.transpose(0, 1)
# content projection
query_content = self.query_content_proj(query)
key_content = self.key_content_proj(key)
value = self.value_proj(value)
# shape info
N, B, C = query_content.shape
HW, _, _ = key_content.shape
# position projection
key_pos = self.key_pos_proj(key_pos)
if is_first_layer:
query_pos = self.query_pos_proj(query_pos)
q = query_content + query_pos
k = key_content + key_pos
else:
q = query_content
k = key_content
v = value
        # concatenate content and spatial (positional) embeddings along the head
        # dimension, so the attention logits decompose into a content term and a
        # spatial term, as in Conditional DETR
q = q.view(N, B, self.num_heads, C // self.num_heads)
query_sine_embed = self.query_pos_sine_proj(query_sine_embed).view(
N, B, self.num_heads, C // self.num_heads
)
q = torch.cat([q, query_sine_embed], dim=3).view(N, B, C * 2)
        k = k.view(HW, B, self.num_heads, C // self.num_heads)  # (HW, B, num_heads, head_dim)
key_pos = key_pos.view(HW, B, self.num_heads, C // self.num_heads)
k = torch.cat([k, key_pos], dim=3).view(HW, B, C * 2)
# attention calculation
q = q.reshape(N, B, self.num_heads, C * 2 // self.num_heads).permute(
1, 2, 0, 3
) # (B, num_heads, N, head_dim)
k = k.reshape(HW, B, self.num_heads, C * 2 // self.num_heads).permute(1, 2, 0, 3)
v = v.reshape(HW, B, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)
scale = (C * 2 // self.num_heads) ** -0.5
q = q * scale
attn = q @ k.transpose(-2, -1)
# add attention mask
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn.masked_fill_(attn_mask, float("-inf"))
else:
attn += attn_mask
if key_padding_mask is not None:
attn = attn.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
out = (attn @ v).transpose(1, 2).reshape(B, N, C)
out = self.out_proj(out)
if not self.batch_first:
out = out.transpose(0, 1)
return identity + self.proj_drop(out)
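

# A minimal cross-attention usage sketch for a decoder layer (illustrative helper
# name and shapes, assuming ``batch_first=False``): object queries attend to flattened
# image features. `query_sine_embed` is the sine embedding of the query reference
# points and is required by forward.
def _example_conditional_cross_attention():
    cross_attn = ConditionalCrossAttention(embed_dim=256, num_heads=8)
    query = torch.rand(300, 2, 256)             # (num_query, bs, embed_dim)
    memory = torch.rand(1350, 2, 256)           # (num_key, bs, embed_dim), HW = 1350
    query_pos = torch.rand(300, 2, 256)         # learned query embedding
    key_pos = torch.rand(1350, 2, 256)          # spatial positional encoding of memory
    query_sine_embed = torch.rand(300, 2, 256)  # sine embedding of reference points
    out = cross_attn(
        query,
        key=memory,
        query_pos=query_pos,
        key_pos=key_pos,
        query_sine_embed=query_sine_embed,
        is_first_layer=True,
    )
    assert out.shape == query.shape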