# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/transformer.py
# ------------------------------------------------------------------------------------------------
import copy
import warnings
from typing import List
import torch
import torch.nn as nn
class BaseTransformerLayer(nn.Module):
# TODO: add more tutorials about BaseTransformerLayer
"""The implementation of Base `TransformerLayer` used in Transformer. Modified
from `mmcv <https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/transformer.py>`_.
    It can be built by directly passing the `Attention`, `FFN`, and `Norm`
    modules, which supports more flexible customization combined with the
    `LazyConfig` system. `BaseTransformerLayer` also supports `prenorm`
    when you specify `norm` as the first element of `operation_order`.
More details about the `prenorm`: `On Layer Normalization in the
Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
Args:
        attn (list[nn.Module] | nn.Module): nn.Module or a list of
            nn.Module containing the attention modules used in the TransformerLayer.
ffn (nn.Module): FFN module used in TransformerLayer.
norm (nn.Module): Normalization layer used in TransformerLayer.
operation_order (tuple[str]): The execution order of operation in
transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Supports `prenorm` when you specify the first element as `norm`.
            Default: None.
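
    Example:
        A minimal construction sketch. It assumes detrex-style ``MultiheadAttention``
        and ``FFN`` modules from ``detrex.layers`` whose forward signatures match
        this layer's calling convention; adapt the names if your modules differ::

            >>> import torch.nn as nn
            >>> from detrex.layers import FFN, MultiheadAttention
            >>> layer = BaseTransformerLayer(
            ...     attn=MultiheadAttention(embed_dim=256, num_heads=8),
            ...     ffn=FFN(embed_dim=256, feedforward_dim=1024),
            ...     norm=nn.LayerNorm(256),
            ...     operation_order=("self_attn", "norm", "ffn", "norm"),
            ... )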
"""
def __init__(
self,
attn: List[nn.Module],
ffn: nn.Module,
norm: nn.Module,
operation_order: tuple = None,
):
super(BaseTransformerLayer, self).__init__()
assert set(operation_order).issubset({"self_attn", "norm", "cross_attn", "ffn"})
# count attention nums
num_attn = operation_order.count("self_attn") + operation_order.count("cross_attn")
if isinstance(attn, nn.Module):
attn = [copy.deepcopy(attn) for _ in range(num_attn)]
else:
assert len(attn) == num_attn, (
f"The length of attn (nn.Module or List[nn.Module]) {num_attn}"
f"is not consistent with the number of attention in "
f"operation_order {operation_order}"
)
self.num_attn = num_attn
self.operation_order = operation_order
self.pre_norm = operation_order[0] == "norm"
self.attentions = nn.ModuleList()
index = 0
for operation_name in operation_order:
if operation_name in ["self_attn", "cross_attn"]:
self.attentions.append(attn[index])
index += 1
self.embed_dim = self.attentions[0].embed_dim
# count ffn nums
self.ffns = nn.ModuleList()
num_ffns = operation_order.count("ffn")
for _ in range(num_ffns):
self.ffns.append(copy.deepcopy(ffn))
# count norm nums
self.norms = nn.ModuleList()
num_norms = operation_order.count("norm")
for _ in range(num_norms):
self.norms.append(copy.deepcopy(norm))
def forward(
self,
query: torch.Tensor,
key: torch.Tensor = None,
value: torch.Tensor = None,
query_pos: torch.Tensor = None,
key_pos: torch.Tensor = None,
attn_masks: List[torch.Tensor] = None,
query_key_padding_mask: torch.Tensor = None,
key_padding_mask: torch.Tensor = None,
**kwargs,
):
"""Forward function for `BaseTransformerLayer`.
**kwargs contains the specific arguments of attentions.
Args:
            query (torch.Tensor): Query embeddings with shape
                `(num_query, bs, embed_dim)` or `(bs, num_query, embed_dim)`,
                which should match the layout expected by the attention modules
                used in `BaseTransformerLayer`.
key (torch.Tensor): Key embeddings used in `Attention`.
value (torch.Tensor): Value embeddings with the same shape as `key`.
query_pos (torch.Tensor): The position embedding for `query`.
Default: None.
key_pos (torch.Tensor): The position embedding for `key`.
Default: None.
            attn_masks (List[Tensor] | None): A list of 2D ByteTensors used
                in the calculation of the corresponding attention. The length of
                `attn_masks` should be equal to the number of attentions in
                `operation_order`. Default: None.
            query_key_padding_mask (torch.Tensor): ByteTensor for `query`, with
                shape `(bs, num_query)`. Only used in the `self_attn` layers.
                Default: None.
key_padding_mask (torch.Tensor): ByteTensor for `key`, with
shape `(bs, num_key)`. Default: None.
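
        Example:
            A minimal call sketch, assuming the layer built in the class docstring
            example and the default `(num_query, bs, embed_dim)` layout (shapes are
            illustrative only)::

                >>> query = torch.randn(100, 2, 256)
                >>> query_pos = torch.randn(100, 2, 256)
                >>> out = layer(query, query_pos=query_pos)
                >>> out.shape
                torch.Size([100, 2, 256])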
"""
norm_index = 0
attn_index = 0
ffn_index = 0
identity = query
if attn_masks is None:
attn_masks = [None for _ in range(self.num_attn)]
elif isinstance(attn_masks, torch.Tensor):
attn_masks = [copy.deepcopy(attn_masks) for _ in range(self.num_attn)]
warnings.warn(f"Use same attn_mask in all attentions in " f"{self.__class__.__name__} ")
else:
assert len(attn_masks) == self.num_attn, (
f"The length of "
f"attn_masks {len(attn_masks)} must be equal "
f"to the number of attention in "
f"operation_order {self.num_attn}"
)
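        # Run the sub-layers in the order given by `operation_order`. `identity`
        # tracks the residual branch: in pre-norm mode the un-normalized tensor is
        # passed explicitly as the residual, while in post-norm mode `None` is
        # passed and the attention/FFN modules fall back to their own input.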
for layer in self.operation_order:
if layer == "self_attn":
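                # Self-attention: keys/values are the (possibly pre-normalized)
                # queries themselves, and `query_pos` is reused as `key_pos`.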
temp_key = temp_value = query
query = self.attentions[attn_index](
query,
temp_key,
temp_value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=query_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=query_key_padding_mask,
**kwargs,
)
attn_index += 1
identity = query
elif layer == "norm":
query = self.norms[norm_index](query)
norm_index += 1
elif layer == "cross_attn":
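                # Cross-attention: keys/values come from the externally provided
                # `key`/`value` (e.g. encoder memory), with their own `key_pos`.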
query = self.attentions[attn_index](
query,
key,
value,
identity if self.pre_norm else None,
query_pos=query_pos,
key_pos=key_pos,
attn_mask=attn_masks[attn_index],
key_padding_mask=key_padding_mask,
**kwargs,
)
attn_index += 1
identity = query
elif layer == "ffn":
query = self.ffns[ffn_index](query, identity if self.pre_norm else None)
ffn_index += 1
return query
class TransformerLayerSequence(nn.Module):
"""Base class for TransformerEncoder and TransformerDecoder, which will copy
the passed `transformer_layers` module `num_layers` time or save the passed
list of `transformer_layers` as parameters named ``self.layers``
which is the type of ``nn.ModuleList``.
The users should inherit `TransformerLayerSequence` and implemente their
own forward function.
Args:
        transformer_layers (list[BaseTransformerLayer] | BaseTransformerLayer): A list
            of BaseTransformerLayer. If it is an obj:`BaseTransformerLayer`, it
            will be repeated `num_layers` times into a list[BaseTransformerLayer].
        num_layers (int): The number of `TransformerLayer`. Default: None.
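
    Example:
        A minimal subclass sketch (``MyEncoder`` is a hypothetical name; the usual
        pattern is to apply the stacked layers in sequence inside ``forward``)::

            >>> class MyEncoder(TransformerLayerSequence):
            ...     def forward(self, query, *args, **kwargs):
            ...         for layer in self.layers:
            ...             query = layer(query, *args, **kwargs)
            ...         return query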
"""
def __init__(
self,
transformer_layers=None,
num_layers=None,
):
super(TransformerLayerSequence, self).__init__()
self.num_layers = num_layers
self.layers = nn.ModuleList()
if isinstance(transformer_layers, nn.Module):
for _ in range(num_layers):
self.layers.append(copy.deepcopy(transformer_layers))
        else:
            assert isinstance(transformer_layers, list) and len(transformer_layers) == num_layers
            # Save the passed list of layers, as documented above.
            self.layers.extend(transformer_layers)
def forward(self):
"""Forward function of `TransformerLayerSequence`. The users should inherit
`TransformerLayerSequence` and implemente their own forward function.
"""
raise NotImplementedError()