# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/open-mmlab/mmcv/blob/master/tests/test_cnn/test_transformer.py
# ------------------------------------------------------------------------------------------------
import pytest
import torch
import torch.nn as nn

from detrex.layers import FFN, BaseTransformerLayer, MultiheadAttention, TransformerLayerSequence


def test_ffn():
    # An FFN needs at least two fully-connected layers, so num_fcs=1 must fail.
    with pytest.raises(AssertionError):
        FFN(num_fcs=1)
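    # The FFN is applied position-wise, so transposing the batch and sequence
    # dims must leave the summed output unchanged (dropout is disabled so the
    # two forward passes are deterministic).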
    ffn = FFN(ffn_drop=0.0)
    input_tensor = torch.rand(2, 20, 256)
    input_tensor_nbc = input_tensor.transpose(0, 1)
    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
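    # When an explicit identity tensor is passed, the residual branch adds
    # `identity` instead of the input, so the sums differ by exactly
    # (residual - input_tensor).sum().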
    residual = torch.rand_like(input_tensor)
    # The original check discarded the allclose result; assert it so the test
    # actually fails on a mismatch.
    assert torch.allclose(
        ffn(input_tensor, identity=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum(),
    )
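

# BaseTransformerLayer wires the attention, FFN, and norm modules together in
# the order given by operation_order; the layer below is a standard post-norm
# self-attention encoder layer.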
@pytest.mark.parametrize("embed_dim", [256])
def test_basetransformerlayer(embed_dim):
    attn = MultiheadAttention(embed_dim=embed_dim, num_heads=8, batch_first=True)
    feedforward_dim = 1024
    ffn = FFN(embed_dim, feedforward_dim, num_fcs=2, activation=nn.ReLU(inplace=True))
    base_layer = BaseTransformerLayer(
        attn=attn,
        ffn=ffn,
        norm=nn.LayerNorm(embed_dim),
        operation_order=("self_attn", "norm", "ffn", "norm"),
    )
    assert attn.batch_first is True
    assert base_layer.ffns[0].feedforward_dim == feedforward_dim
    in_tensor = torch.rand(2, 10, embed_dim)
    base_layer(in_tensor)
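

# TransformerLayerSequence stacks num_layers transformer layers: a single
# BaseTransformerLayer is replicated num_layers times, while a list is used
# as-is and must already contain num_layers entries.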
def test_transformerlayersequence():
    sequence = TransformerLayerSequence(
        transformer_layers=BaseTransformerLayer(
            attn=[
                MultiheadAttention(256, 8, batch_first=True),
                MultiheadAttention(256, 8, batch_first=True),
            ],
            ffn=FFN(256, 1024, num_fcs=2),
            norm=nn.LayerNorm(256),
            operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
        ),
        num_layers=6,
    )
    assert sequence.num_layers == 6
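    # Passing a list whose length does not match num_layers (one layer here,
    # num_layers=6) should raise an AssertionError.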
    with pytest.raises(AssertionError):
        TransformerLayerSequence(
            transformer_layers=[
                BaseTransformerLayer(
                    attn=[
                        MultiheadAttention(256, 8, batch_first=True),
                        MultiheadAttention(256, 8, batch_first=True),
                    ],
                    ffn=FFN(256, 1024, num_fcs=2),
                    norm=nn.LayerNorm(256),
                    operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
                ),
            ],
            num_layers=6,
        )