# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/transformer.py
# ------------------------------------------------------------------------------------------------
import copy
import warnings
from typing import List

import torch
import torch.nn as nn


class BaseTransformerLayer(nn.Module):
    # TODO: add more tutorials about BaseTransformerLayer
    """The implementation of the base `TransformerLayer` used in Transformer. Modified
    from `mmcv <https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/bricks/transformer.py>`_.

    It can be built by directly passing the `Attentions`, `FFNs`, and `Norms`
    modules, which supports more flexible customization combined with the
    `LazyConfig` system. `BaseTransformerLayer` also supports `prenorm`
    when you specify `norm` as the first element of `operation_order`.
    More details about `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn (list[nn.Module] | nn.Module): nn.Module or a list of nn.Module
            containing the attention modules used in the TransformerLayer.
        ffn (nn.Module): FFN module used in the TransformerLayer.
        norm (nn.Module): Normalization layer used in the TransformerLayer.
        operation_order (tuple[str]): The execution order of operations in
            the transformer, such as ("self_attn", "norm", "ffn", "norm").
            Supports `prenorm` when the first element is "norm".
            Default: None.
    """
    def __init__(
        self,
        attn: List[nn.Module],
        ffn: nn.Module,
        norm: nn.Module,
        operation_order: tuple = None,
    ):
        super(BaseTransformerLayer, self).__init__()
        assert set(operation_order).issubset({"self_attn", "norm", "cross_attn", "ffn"})

        # Count the attention modules required by `operation_order`.
        num_attn = operation_order.count("self_attn") + operation_order.count("cross_attn")
        if isinstance(attn, nn.Module):
            attn = [copy.deepcopy(attn) for _ in range(num_attn)]
        else:
            assert len(attn) == num_attn, (
                f"The length of attn (nn.Module or List[nn.Module]) {len(attn)} "
                f"is not consistent with the number of attentions in "
                f"operation_order {operation_order}"
            )

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.pre_norm = operation_order[0] == "norm"
        self.attentions = nn.ModuleList()
        index = 0
        for operation_name in operation_order:
            if operation_name in ["self_attn", "cross_attn"]:
                self.attentions.append(attn[index])
                index += 1

        self.embed_dim = self.attentions[0].embed_dim

        # Count FFN modules.
        self.ffns = nn.ModuleList()
        num_ffns = operation_order.count("ffn")
        for _ in range(num_ffns):
            self.ffns.append(copy.deepcopy(ffn))

        # Count normalization layers.
        self.norms = nn.ModuleList()
        num_norms = operation_order.count("norm")
        for _ in range(num_norms):
            self.norms.append(copy.deepcopy(norm))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor = None,
        value: torch.Tensor = None,
        query_pos: torch.Tensor = None,
        key_pos: torch.Tensor = None,
        attn_masks: List[torch.Tensor] = None,
        query_key_padding_mask: torch.Tensor = None,
        key_padding_mask: torch.Tensor = None,
        **kwargs,
    ):
| """Forward function for `BaseTransformerLayer`. | |
| **kwargs contains the specific arguments of attentions. | |
| Args: | |
| query (torch.Tensor): Query embeddings with shape | |
| `(num_query, bs, embed_dim)` or `(bs, num_query, embed_dim)` | |
| which should be specified follows the attention module used in | |
| `BaseTransformerLayer`. | |
| key (torch.Tensor): Key embeddings used in `Attention`. | |
| value (torch.Tensor): Value embeddings with the same shape as `key`. | |
| query_pos (torch.Tensor): The position embedding for `query`. | |
| Default: None. | |
| key_pos (torch.Tensor): The position embedding for `key`. | |
| Default: None. | |
| attn_masks (List[Tensor] | None): A list of 2D ByteTensor used | |
| in calculation the corresponding attention. The length of | |
| `attn_masks` should be equal to the number of `attention` in | |
| `operation_order`. Default: None. | |
| query_key_padding_mask (torch.Tensor): ByteTensor for `query`, with | |
| shape `(bs, num_query)`. Only used in `self_attn` layer. | |
| Defaults to None. | |
| key_padding_mask (torch.Tensor): ByteTensor for `key`, with | |
| shape `(bs, num_key)`. Default: None. | |
| """ | |
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query

        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [copy.deepcopy(attn_masks) for _ in range(self.num_attn)]
            warnings.warn(f"Use the same attn_mask in all attentions in {self.__class__.__name__}")
        else:
            assert len(attn_masks) == self.num_attn, (
                f"The length of attn_masks {len(attn_masks)} must be equal "
                f"to the number of attentions in "
                f"operation_order {self.num_attn}"
            )

        for layer in self.operation_order:
            if layer == "self_attn":
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs,
                )
                attn_index += 1
                identity = query

            elif layer == "norm":
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == "cross_attn":
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs,
                )
                attn_index += 1
                identity = query

            elif layer == "ffn":
                query = self.ffns[ffn_index](query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
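

# Example configurations (an illustrative sketch, not exercised by this module):
# the same layer definition covers both post-norm and pre-norm variants purely
# through `operation_order`. With attention/FFN/norm modules that follow the call
# signatures used in `forward` above, a decoder-style layer could be built as
#
#     post_norm = BaseTransformerLayer(
#         attn=attention_module,  # a single module is deep-copied per attention slot
#         ffn=ffn_module,
#         norm=norm_module,
#         operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
#     )
#     pre_norm = BaseTransformerLayer(
#         attn=[self_attn_module, cross_attn_module],  # or a list, one per attention
#         ffn=ffn_module,
#         norm=norm_module,
#         operation_order=("norm", "self_attn", "norm", "cross_attn", "norm", "ffn"),
#     )
#
# `attention_module`, `ffn_module`, etc. are placeholders for the project's own modules.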


class TransformerLayerSequence(nn.Module):
    """Base class for TransformerEncoder and TransformerDecoder, which either copies
    the passed `transformer_layers` module `num_layers` times or saves the passed
    list of `transformer_layers` as ``self.layers``, an ``nn.ModuleList``.

    Users should inherit from `TransformerLayerSequence` and implement their
    own forward function.

    Args:
        transformer_layers (list[BaseTransformerLayer] | BaseTransformerLayer): A list
            of BaseTransformerLayer. If it is a :obj:`BaseTransformerLayer`, it
            will be repeated `num_layers` times into a list[BaseTransformerLayer].
        num_layers (int): The number of `TransformerLayer`. Default: None.
    """
    def __init__(
        self,
        transformer_layers=None,
        num_layers=None,
    ):
        super(TransformerLayerSequence, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        if isinstance(transformer_layers, nn.Module):
            for _ in range(num_layers):
                self.layers.append(copy.deepcopy(transformer_layers))
        else:
            assert isinstance(transformer_layers, list) and len(transformer_layers) == num_layers
            # Save the passed list of layers, as described in the class docstring.
            self.layers.extend(transformer_layers)

    def forward(self):
        """Forward function of `TransformerLayerSequence`. Users should inherit from
        `TransformerLayerSequence` and implement their own forward function.
        """
        raise NotImplementedError()
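

# A minimal, self-contained smoke test (an illustrative sketch, not part of the
# original file). `_SimpleAttention`, `_SimpleFFN`, and `_SimpleDecoder` are
# hypothetical stand-ins whose call signatures match what
# `BaseTransformerLayer.forward` expects; real projects would plug in their own
# attention/FFN wrappers instead.
if __name__ == "__main__":

    class _SimpleAttention(nn.Module):
        """Hypothetical attention wrapper: adds positional embeddings, applies
        multi-head attention, and returns a residual-connected output."""

        def __init__(self, embed_dim, num_heads):
            super().__init__()
            self.embed_dim = embed_dim
            self.attn = nn.MultiheadAttention(embed_dim, num_heads)

        def forward(self, query, key, value, identity=None, query_pos=None,
                    key_pos=None, attn_mask=None, key_padding_mask=None, **kwargs):
            identity = query if identity is None else identity
            q = query if query_pos is None else query + query_pos
            k = key if key_pos is None else key + key_pos
            out, _ = self.attn(q, k, value, attn_mask=attn_mask,
                               key_padding_mask=key_padding_mask)
            return identity + out

    class _SimpleFFN(nn.Module):
        """Hypothetical feed-forward wrapper with a residual connection."""

        def __init__(self, embed_dim, feedforward_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(embed_dim, feedforward_dim),
                nn.ReLU(),
                nn.Linear(feedforward_dim, embed_dim),
            )

        def forward(self, x, identity=None):
            identity = x if identity is None else identity
            return identity + self.net(x)

    class _SimpleDecoder(TransformerLayerSequence):
        """Example subclass: chains the stacked layers over (query, key, value)."""

        def forward(self, query, key, value, **kwargs):
            for layer in self.layers:
                query = layer(query, key=key, value=value, **kwargs)
            return query

    layer = BaseTransformerLayer(
        attn=_SimpleAttention(embed_dim=256, num_heads=8),
        ffn=_SimpleFFN(embed_dim=256, feedforward_dim=1024),
        norm=nn.LayerNorm(256),
        operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
    )
    decoder = _SimpleDecoder(transformer_layers=layer, num_layers=2)

    query = torch.rand(100, 2, 256)  # (num_query, bs, embed_dim)
    memory = torch.rand(50, 2, 256)  # (num_key, bs, embed_dim)
    print(decoder(query, key=memory, value=memory).shape)  # torch.Size([100, 2, 256])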