# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/transformer.py
# ------------------------------------------------------------------------------------------------
import warnings
from typing import Optional

import torch
import torch.nn as nn


class MultiheadAttention(nn.Module):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    Implements multi-head attention with an identity (residual) connection,
    and additionally accepts position embeddings as input.

    Args:
        embed_dim (int): The embedding dimension for attention.
        num_heads (int): The number of attention heads.
        attn_drop (float): Dropout probability applied to attn_output_weights.
            Default: 0.0.
        proj_drop (float): Dropout probability applied after `MultiheadAttention`.
            Default: 0.0.
        batch_first (bool): If `True`, the input and output tensors are
            provided as `(bs, n, embed_dim)`; otherwise as `(n, bs, embed_dim)`.
            Default: False.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        batch_first: bool = False,
        **kwargs,
    ):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=attn_drop,
            batch_first=batch_first,
            **kwargs,
        )
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        identity: Optional[torch.Tensor] = None,
        query_pos: Optional[torch.Tensor] = None,
        key_pos: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
"""Forward function for `MultiheadAttention` | |
**kwargs allow passing a more general data flow when combining | |
with other operations in `transformerlayer`. | |
Args: | |
query (torch.Tensor): Query embeddings with shape | |
`(num_query, bs, embed_dim)` if self.batch_first is False, | |
else `(bs, num_query, embed_dim)` | |
key (torch.Tensor): Key embeddings with shape | |
`(num_key, bs, embed_dim)` if self.batch_first is False, | |
else `(bs, num_key, embed_dim)` | |
value (torch.Tensor): Value embeddings with the same shape as `key`. | |
Same in `torch.nn.MultiheadAttention.forward`. Default: None. | |
If None, the `key` will be used. | |
identity (torch.Tensor): The tensor, with the same shape as x, will | |
be used for identity addition. Default: None. | |
If None, `query` will be used. | |
query_pos (torch.Tensor): The position embedding for query, with the | |
same shape as `query`. Default: None. | |
key_pos (torch.Tensor): The position embedding for key. Default: None. | |
If None, and `query_pos` has the same shape as `key`, then `query_pos` | |
will be used for `key_pos`. | |
attn_mask (torch.Tensor): ByteTensor mask with shape `(num_query, num_key)`. | |
Same as `torch.nn.MultiheadAttention.forward`. Default: None. | |
key_padding_mask (torch.Tensor): ByteTensor with shape `(bs, num_key)` which | |
indicates which elements within `key` to be ignored in attention. | |
Default: None. | |
""" | |
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(
                        f"position encoding of key is missing in {self.__class__.__name__}."
                    )
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
        )[0]

        return identity + self.proj_drop(out)
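

# A minimal usage sketch of the wrapper above (illustrative, not part of the original
# module): the layer sizes and tensor shapes below are assumptions, following the
# default ``batch_first=False`` layout `(num_query, bs, embed_dim)`.
def _example_multihead_attention():
    self_attn = MultiheadAttention(embed_dim=256, num_heads=8)
    x = torch.rand(100, 2, 256)    # content queries: (num_query, bs, embed_dim)
    pos = torch.rand(100, 2, 256)  # position embedding, same shape as the queries
    # key/value default to `query` and identity defaults to `query` as well,
    # so this performs self-attention followed by the identity addition.
    out = self_attn(query=x, query_pos=pos)
    assert out.shape == x.shape
    return out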


class ConditionalSelfAttention(nn.Module):
    """Conditional Self-Attention module used in Conditional DETR.

    `Conditional DETR for Fast Training Convergence.
    <https://arxiv.org/pdf/2108.06152.pdf>`_

    Args:
        embed_dim (int): The embedding dimension for attention.
        num_heads (int): The number of attention heads.
        attn_drop (float): Dropout probability applied to attn_output_weights.
            Default: 0.0.
        proj_drop (float): Dropout probability applied after the attention output projection.
            Default: 0.0.
        batch_first (bool): If `True`, the input and output tensors are
            provided as `(bs, n, embed_dim)`; otherwise as `(n, bs, embed_dim)`.
            Default: False.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        attn_drop=0.0,
        proj_drop=0.0,
        batch_first=False,
        **kwargs,
    ):
        super(ConditionalSelfAttention, self).__init__()
        self.query_content_proj = nn.Linear(embed_dim, embed_dim)
        self.query_pos_proj = nn.Linear(embed_dim, embed_dim)
        self.key_content_proj = nn.Linear(embed_dim, embed_dim)
        self.key_pos_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        head_dim = embed_dim // num_heads
        self.scale = head_dim**-0.5
        self.batch_first = batch_first

    def forward(
        self,
        query,
        key=None,
        value=None,
        identity=None,
        query_pos=None,
        key_pos=None,
        attn_mask=None,
        key_padding_mask=None,
        **kwargs,
    ):
"""Forward function for `ConditionalSelfAttention` | |
**kwargs allow passing a more general data flow when combining | |
with other operations in `transformerlayer`. | |
Args: | |
query (torch.Tensor): Query embeddings with shape | |
`(num_query, bs, embed_dim)` if self.batch_first is False, | |
else `(bs, num_query, embed_dim)` | |
key (torch.Tensor): Key embeddings with shape | |
`(num_key, bs, embed_dim)` if self.batch_first is False, | |
else `(bs, num_key, embed_dim)` | |
value (torch.Tensor): Value embeddings with the same shape as `key`. | |
Same in `torch.nn.MultiheadAttention.forward`. Default: None. | |
If None, the `key` will be used. | |
identity (torch.Tensor): The tensor, with the same shape as `query``, | |
which will be used for identity addition. Default: None. | |
If None, `query` will be used. | |
query_pos (torch.Tensor): The position embedding for query, with the | |
same shape as `query`. Default: None. | |
key_pos (torch.Tensor): The position embedding for key. Default: None. | |
If None, and `query_pos` has the same shape as `key`, then `query_pos` | |
will be used for `key_pos`. | |
attn_mask (torch.Tensor): ByteTensor mask with shape `(num_query, num_key)`. | |
Same as `torch.nn.MultiheadAttention.forward`. Default: None. | |
key_padding_mask (torch.Tensor): ByteTensor with shape `(bs, num_key)` which | |
indicates which elements within `key` to be ignored in attention. | |
Default: None. | |
""" | |
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(
                        f"position encoding of key is missing in {self.__class__.__name__}."
                    )
        assert (
            query_pos is not None and key_pos is not None
        ), "query_pos and key_pos must be passed into ConditionalSelfAttention"
        # transpose (b, n, c) to (n, b, c) for the attention calculation below
        if self.batch_first:
            query = query.transpose(0, 1)  # (n, b, c)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)
            query_pos = query_pos.transpose(0, 1)
            key_pos = key_pos.transpose(0, 1)
            identity = identity.transpose(0, 1)

        # separate projections for the content and position embeddings of query/key
        query_content = self.query_content_proj(query)
        query_pos = self.query_pos_proj(query_pos)
        key_content = self.key_content_proj(key)
        key_pos = self.key_pos_proj(key_pos)
        value = self.value_proj(value)

        # attention calculation
        N, B, C = query_content.shape
        q = query_content + query_pos
        k = key_content + key_pos
        v = value

        # split heads: (N, B, C) -> (B, num_heads, N, head_dim)
        q = q.reshape(N, B, self.num_heads, C // self.num_heads).permute(
            1, 2, 0, 3
        )  # (B, num_heads, N, head_dim)
        k = k.reshape(N, B, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)
        v = v.reshape(N, B, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)

        # scaled dot-product attention logits
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        # apply attention mask and key padding mask, if given
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn.masked_fill_(attn_mask, float("-inf"))
            else:
                attn += attn_mask
        if key_padding_mask is not None:
            attn = attn.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # merge heads back: (B, num_heads, N, head_dim) -> (B, N, C)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.out_proj(out)
        if not self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.proj_drop(out)
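

# A minimal usage sketch of ConditionalSelfAttention (illustrative, not part of the
# original module): the sizes below are assumptions, using the default
# ``batch_first=False`` layout `(num_query, bs, embed_dim)`. Note that `query_pos`
# is required here, unlike in the plain MultiheadAttention wrapper above.
def _example_conditional_self_attention():
    self_attn = ConditionalSelfAttention(embed_dim=256, num_heads=8)
    queries = torch.rand(300, 2, 256)    # decoder object queries: (num_query, bs, embed_dim)
    query_pos = torch.rand(300, 2, 256)  # query position embedding, same shape as the queries
    # key/value/identity default to `query`, and key_pos falls back to query_pos here.
    out = self_attn(query=queries, query_pos=query_pos)
    assert out.shape == queries.shape
    return out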


class ConditionalCrossAttention(nn.Module):
    """Conditional Cross-Attention module used in Conditional DETR.

    `Conditional DETR for Fast Training Convergence.
    <https://arxiv.org/pdf/2108.06152.pdf>`_

    Args:
        embed_dim (int): The embedding dimension for attention.
        num_heads (int): The number of attention heads.
        attn_drop (float): Dropout probability applied to attn_output_weights.
            Default: 0.0.
        proj_drop (float): Dropout probability applied after the attention output projection.
            Default: 0.0.
        batch_first (bool): If `True`, the input and output tensors are
            provided as `(bs, n, embed_dim)`; otherwise as `(n, bs, embed_dim)`.
            Default: False.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        attn_drop=0.0,
        proj_drop=0.0,
        batch_first=False,
        **kwargs,
    ):
        super(ConditionalCrossAttention, self).__init__()
        self.query_content_proj = nn.Linear(embed_dim, embed_dim)
        self.query_pos_proj = nn.Linear(embed_dim, embed_dim)
        self.query_pos_sine_proj = nn.Linear(embed_dim, embed_dim)
        self.key_content_proj = nn.Linear(embed_dim, embed_dim)
        self.key_pos_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)
        self.num_heads = num_heads
        self.batch_first = batch_first

    def forward(
        self,
        query,
        key=None,
        value=None,
        identity=None,
        query_pos=None,
        key_pos=None,
        query_sine_embed=None,
        is_first_layer=False,
        attn_mask=None,
        key_padding_mask=None,
        **kwargs,
    ):
"""Forward function for `ConditionalCrossAttention` | |
**kwargs allow passing a more general data flow when combining | |
with other operations in `transformerlayer`. | |
Args: | |
query (torch.Tensor): Query embeddings with shape | |
`(num_query, bs, embed_dim)` if self.batch_first is False, | |
else `(bs, num_query, embed_dim)` | |
key (torch.Tensor): Key embeddings with shape | |
`(num_key, bs, embed_dim)` if self.batch_first is False, | |
else `(bs, num_key, embed_dim)` | |
value (torch.Tensor): Value embeddings with the same shape as `key`. | |
Same in `torch.nn.MultiheadAttention.forward`. Default: None. | |
If None, the `key` will be used. | |
identity (torch.Tensor): The tensor, with the same shape as x, will | |
be used for identity addition. Default: None. | |
If None, `query` will be used. | |
query_pos (torch.Tensor): The position embedding for query, with the | |
same shape as `query`. Default: None. | |
key_pos (torch.Tensor): The position embedding for key. Default: None. | |
If None, and `query_pos` has the same shape as `key`, then `query_pos` | |
will be used for `key_pos`. | |
query_sine_embed (torch.Tensor): None | |
is_first_layer (bool): None | |
attn_mask (torch.Tensor): ByteTensor mask with shape `(num_query, num_key)`. | |
Same as `torch.nn.MultiheadAttention.forward`. Default: None. | |
key_padding_mask (torch.Tensor): ByteTensor with shape `(bs, num_key)` which | |
indicates which elements within `key` to be ignored in attention. | |
Default: None. | |
""" | |
        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(
                        f"position encoding of key is missing in {self.__class__.__name__}."
                    )
        assert (
            query_pos is not None and key_pos is not None
        ), "query_pos and key_pos must be passed into ConditionalCrossAttention"
        # transpose (b, n, c) to (n, b, c) for the attention calculation below
        if self.batch_first:
            query = query.transpose(0, 1)  # (n, b, c)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)
            query_pos = query_pos.transpose(0, 1)
            key_pos = key_pos.transpose(0, 1)
            identity = identity.transpose(0, 1)

        # content projection
        query_content = self.query_content_proj(query)
        key_content = self.key_content_proj(key)
        value = self.value_proj(value)

        # shape info
        N, B, C = query_content.shape
        HW, _, _ = key_content.shape

        # position projection; the projected query position embedding is only added
        # to the content query in the first decoder layer
        key_pos = self.key_pos_proj(key_pos)
        if is_first_layer:
            query_pos = self.query_pos_proj(query_pos)
            q = query_content + query_pos
            k = key_content + key_pos
        else:
            q = query_content
            k = key_content
        v = value

        # concatenate the content and spatial (sine) embeddings per head, decoupling
        # content attention from spatial attention; the per-head dimension doubles
        # from C // num_heads to 2 * C // num_heads
        q = q.view(N, B, self.num_heads, C // self.num_heads)
        query_sine_embed = self.query_pos_sine_proj(query_sine_embed).view(
            N, B, self.num_heads, C // self.num_heads
        )
        q = torch.cat([q, query_sine_embed], dim=3).view(N, B, C * 2)
        k = k.view(HW, B, self.num_heads, C // self.num_heads)
        key_pos = key_pos.view(HW, B, self.num_heads, C // self.num_heads)
        k = torch.cat([k, key_pos], dim=3).view(HW, B, C * 2)

        # attention calculation with the doubled per-head dimension
        q = q.reshape(N, B, self.num_heads, C * 2 // self.num_heads).permute(
            1, 2, 0, 3
        )  # (B, num_heads, N, 2 * head_dim)
        k = k.reshape(HW, B, self.num_heads, C * 2 // self.num_heads).permute(1, 2, 0, 3)
        v = v.reshape(HW, B, self.num_heads, C // self.num_heads).permute(1, 2, 0, 3)

        scale = (C * 2 // self.num_heads) ** -0.5
        q = q * scale
        attn = q @ k.transpose(-2, -1)

        # apply attention mask and key padding mask, if given
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn.masked_fill_(attn_mask, float("-inf"))
            else:
                attn += attn_mask
        if key_padding_mask is not None:
            attn = attn.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), float("-inf"))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # merge heads back: (B, num_heads, N, head_dim) -> (B, N, C)
        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = self.out_proj(out)
        if not self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.proj_drop(out)
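

# A minimal usage sketch of ConditionalCrossAttention (illustrative, not part of the
# original module): the sizes below are assumptions, using the default
# ``batch_first=False`` layouts `(num_query, bs, embed_dim)` for the queries and
# `(num_key, bs, embed_dim)` for the encoder memory. `query_sine_embed` stands in for
# the sinusoidal embedding of the query reference points.
def _example_conditional_cross_attention():
    cross_attn = ConditionalCrossAttention(embed_dim=256, num_heads=8)
    queries = torch.rand(300, 2, 256)           # decoder queries: (num_query, bs, embed_dim)
    query_pos = torch.rand(300, 2, 256)         # query position embedding
    query_sine_embed = torch.rand(300, 2, 256)  # conditional spatial query embedding
    memory = torch.rand(1064, 2, 256)           # flattened encoder features: (num_key, bs, embed_dim)
    key_pos = torch.rand(1064, 2, 256)          # position embedding of the encoder features
    out = cross_attn(
        query=queries,
        key=memory,
        value=memory,
        query_pos=query_pos,
        key_pos=key_pos,
        query_sine_embed=query_sine_embed,
        is_first_layer=True,
    )
    assert out.shape == queries.shape
    return out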