# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/open-mmlab/mmcv/blob/master/tests/test_cnn/test_transformer.py
# ------------------------------------------------------------------------------------------------

import pytest
import torch
import torch.nn as nn

from detrex.layers import FFN, BaseTransformerLayer, MultiheadAttention, TransformerLayerSequence


def test_ffn():
    # num_fcs must be no less than 2
    with pytest.raises(AssertionError):
        FFN(num_fcs=1)
    ffn = FFN(ffn_drop=0.0)
    input_tensor = torch.rand(2, 20, 256)
    input_tensor_nbc = input_tensor.transpose(0, 1)
    # The FFN acts on the channel dimension only, so the summed output is
    # invariant to swapping the batch and sequence dimensions.
    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
    # With the identity path enabled, ffn(x, identity=residual) = residual + layers(x)
    # while ffn(x) = x + layers(x), so the sums differ by residual.sum() - x.sum().
    residual = torch.rand_like(input_tensor)
    assert torch.allclose(
        ffn(input_tensor, identity=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum(),
    )


@pytest.mark.parametrize("embed_dim", [256])
def test_basetransformerlayer(embed_dim):
    feedforward_dim = 1024
    attn = MultiheadAttention(embed_dim=embed_dim, num_heads=8, batch_first=True)
    ffn = FFN(embed_dim, feedforward_dim, num_fcs=2, activation=nn.ReLU(inplace=True))
    base_layer = BaseTransformerLayer(
        attn=attn,
        ffn=ffn,
        norm=nn.LayerNorm(embed_dim),
        operation_order=("self_attn", "norm", "ffn", "norm"),
    )
    assert attn.batch_first is True
    assert base_layer.ffns[0].feedforward_dim == feedforward_dim
    # Smoke-test the forward pass: a self-attention-only layer should preserve
    # the input shape.
    in_tensor = torch.rand(2, 10, embed_dim)
    out_tensor = base_layer(in_tensor)
    assert out_tensor.shape == in_tensor.shape


def test_transformerlayersequence():
    sequence = TransformerLayerSequence(
        transformer_layers=BaseTransformerLayer(
            attn=[
                MultiheadAttention(256, 8, batch_first=True),
                MultiheadAttention(256, 8, batch_first=True),
            ],
            ffn=FFN(256, 1024, num_fcs=2),
            norm=nn.LayerNorm(256),
            operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
        ),
        num_layers=6,
    )
    assert sequence.num_layers == 6

    # A list of transformer layers must match num_layers, so a length-1 list
    # with num_layers=6 is expected to raise.
    with pytest.raises(AssertionError):
        TransformerLayerSequence(
            transformer_layers=[
                BaseTransformerLayer(
                    attn=[
                        MultiheadAttention(256, 8, batch_first=True),
                        MultiheadAttention(256, 8, batch_first=True),
                    ],
                    ffn=FFN(256, 1024, num_fcs=2),
                    norm=nn.LayerNorm(256),
                    operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
                ),
            ],
            num_layers=6,
        )
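

# The tests above only run a self-attention-only forward pass; the sketch below
# additionally exercises a decoder-style layer with a cross-attention step.
# It is a minimal sketch, assuming BaseTransformerLayer.forward follows the usual
# (query, key, value, ...) calling convention and preserves the query shape; the
# test name and tensor sizes are illustrative and not part of the original suite.
def test_basetransformerlayer_cross_attn_sketch():
    decoder_layer = BaseTransformerLayer(
        attn=[
            MultiheadAttention(256, 8, batch_first=True),
            MultiheadAttention(256, 8, batch_first=True),
        ],
        ffn=FFN(256, 1024, num_fcs=2),
        norm=nn.LayerNorm(256),
        operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
    )
    query = torch.rand(2, 10, 256)   # (batch, num_queries, embed_dim)
    memory = torch.rand(2, 20, 256)  # (batch, num_keys, embed_dim)
    out = decoder_layer(query, key=memory, value=memory)
    assert out.shape == query.shape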