# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ------------------------------------------------------------------------------------------------
# Modified from:
# https://github.com/open-mmlab/mmcv/blob/master/tests/test_cnn/test_transformer.py
# ------------------------------------------------------------------------------------------------
import pytest
import torch
import torch.nn as nn
from detrex.layers import FFN, BaseTransformerLayer, MultiheadAttention, TransformerLayerSequence


def test_ffn():
    """FFN requires at least two fully-connected layers and adds the input (or a given identity) as a residual."""
    # num_fcs < 2 is rejected at construction time.
    with pytest.raises(AssertionError):
        FFN(num_fcs=1)
    ffn = FFN(ffn_drop=0.0)
    input_tensor = torch.rand(2, 20, 256)
    input_tensor_nbc = input_tensor.transpose(0, 1)
    # The FFN acts on the last (embedding) dimension only, so transposing batch and
    # sequence dimensions must not change the summed output.
    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())
    # An explicit identity replaces the input in the residual connection:
    # ffn(x, identity).sum() == ffn(x).sum() + identity.sum() - x.sum()
    residual = torch.rand_like(input_tensor)
    assert torch.allclose(
        ffn(input_tensor, identity=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum(),
    )
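

# The sketch below is illustrative and not part of the original test suite: it constructs an
# FFN with the explicit (embed_dim, feedforward_dim) arguments used elsewhere in this file and
# assumes the (batch, seq, embed_dim) layout exercised above. The helper name
# `_ffn_usage_sketch` is ours, not a detrex API.
def _ffn_usage_sketch():
    ffn = FFN(256, 1024, num_fcs=2, ffn_drop=0.0)
    x = torch.rand(2, 20, 256)
    out = ffn(x)  # with no explicit identity, the input itself is used as the residual
    assert out.shape == x.shape  # the FFN projects back to embed_dim, so the shape is preserved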


@pytest.mark.parametrize("embed_dim", [256])
def test_basetransformerlayer(embed_dim):
    """Compose self-attention, LayerNorm, and FFN into a post-norm BaseTransformerLayer."""
    attn = MultiheadAttention(embed_dim=embed_dim, num_heads=8, batch_first=True)
    ffn = FFN(embed_dim, 1024, num_fcs=2, activation=nn.ReLU(inplace=True))
    base_layer = BaseTransformerLayer(
        attn=attn,
        ffn=ffn,
        norm=nn.LayerNorm(embed_dim),
        operation_order=("self_attn", "norm", "ffn", "norm"),
    )

    feedforward_dim = 1024
    assert attn.batch_first is True
    assert base_layer.ffns[0].feedforward_dim == feedforward_dim

    # Smoke-test the forward pass with a (batch, seq, embed_dim) input.
    in_tensor = torch.rand(2, 10, embed_dim)
    base_layer(in_tensor)
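

# Illustrative sketch (assumption: BaseTransformerLayer.forward returns a tensor with the same
# shape as its query input). It mirrors the layer built in test_basetransformerlayer and only
# checks shape preservation; `_base_layer_shape_sketch` is our name, not a detrex API.
def _base_layer_shape_sketch(embed_dim=256):
    layer = BaseTransformerLayer(
        attn=MultiheadAttention(embed_dim=embed_dim, num_heads=8, batch_first=True),
        ffn=FFN(embed_dim, 1024, num_fcs=2),
        norm=nn.LayerNorm(embed_dim),
        operation_order=("self_attn", "norm", "ffn", "norm"),
    )
    query = torch.rand(2, 10, embed_dim)
    out = layer(query)
    assert out.shape == query.shape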


def test_transformerlayersequence():
    """Stack a single BaseTransformerLayer num_layers times with TransformerLayerSequence."""
    sequence = TransformerLayerSequence(
        transformer_layers=BaseTransformerLayer(
            attn=[
                MultiheadAttention(256, 8, batch_first=True),
                MultiheadAttention(256, 8, batch_first=True),
            ],
            ffn=FFN(256, 1024, num_fcs=2),
            norm=nn.LayerNorm(256),
            operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
        ),
        num_layers=6,
    )
    assert sequence.num_layers == 6

    # A list of transformer layers whose length does not match num_layers
    # (one layer here versus num_layers=6) should be rejected.
    with pytest.raises(AssertionError):
        TransformerLayerSequence(
            transformer_layers=[
                BaseTransformerLayer(
                    attn=[
                        MultiheadAttention(256, 8, batch_first=True),
                        MultiheadAttention(256, 8, batch_first=True),
                    ],
                    ffn=FFN(256, 1024, num_fcs=2),
                    norm=nn.LayerNorm(256),
                    operation_order=("self_attn", "norm", "cross_attn", "norm", "ffn", "norm"),
                ),
            ],
            num_layers=6,
        )
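

# Illustrative sketch (assumptions: TransformerLayerSequence keeps its stacked copies in a
# `layers` ModuleList and deep-copies the template layer rather than sharing it). The helper
# name `_sequence_stacking_sketch` is ours, not a detrex API.
def _sequence_stacking_sketch():
    template = BaseTransformerLayer(
        attn=MultiheadAttention(256, 8, batch_first=True),
        ffn=FFN(256, 1024, num_fcs=2),
        norm=nn.LayerNorm(256),
        operation_order=("self_attn", "norm", "ffn", "norm"),
    )
    sequence = TransformerLayerSequence(transformer_layers=template, num_layers=2)
    assert len(sequence.layers) == sequence.num_layers
    assert sequence.layers[0] is not sequence.layers[1]  # independent copies, not shared modules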