Spaces:
Runtime error
Runtime error
File size: 6,426 Bytes
3e99b05 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from detrex.layers import ConditionalCrossAttention, ConditionalSelfAttention
from utils import OriginalConditionalAttentionDecoder, OriginalConditionalAttentionEncoder
def test_cond_self_attention():
# hyper-parameters
d_model = 256
nhead = 8
# module definition
cond_attention_original = OriginalConditionalAttentionEncoder(d_model=d_model, nhead=nhead)
cond_attention_detrex = ConditionalSelfAttention(embed_dim=d_model, num_heads=nhead)
# weight transfer
cond_attention_detrex.query_content_proj.weight = (
cond_attention_original.sa_qcontent_proj.weight
)
cond_attention_detrex.query_content_proj.bias = cond_attention_original.sa_qcontent_proj.bias
cond_attention_detrex.query_pos_proj.weight = cond_attention_original.sa_qpos_proj.weight
cond_attention_detrex.query_pos_proj.bias = cond_attention_original.sa_qpos_proj.bias
cond_attention_detrex.key_content_proj.weight = cond_attention_original.sa_kcontent_proj.weight
cond_attention_detrex.key_content_proj.bias = cond_attention_original.sa_kcontent_proj.bias
cond_attention_detrex.key_pos_proj.weight = cond_attention_original.sa_kpos_proj.weight
cond_attention_detrex.key_pos_proj.bias = cond_attention_original.sa_kpos_proj.bias
cond_attention_detrex.value_proj.weight = cond_attention_original.sa_v_proj.weight
cond_attention_detrex.value_proj.bias = cond_attention_original.sa_v_proj.bias
cond_attention_detrex.out_proj.weight = cond_attention_original.self_attn.out_proj.weight
cond_attention_detrex.out_proj.bias = cond_attention_original.self_attn.out_proj.bias
# test output
input = torch.randn(16, 1, 256) # (n, b, c)
query_pos = torch.randn(16, 1, 256)
# self-attention with short-cut
original_output = cond_attention_original(tgt=input, query_pos=query_pos)[0] + input
detrex_output = cond_attention_detrex(
query=input, key=input, value=input, query_pos=query_pos, key_pos=query_pos
)
torch.allclose(original_output.sum(), detrex_output.sum())
def test_cond_decoder():
# hyper-parameters
d_model = 256
nhead = 8
# original conditional decoder
cond_decoder_original = OriginalConditionalAttentionDecoder(
d_model=d_model,
nhead=nhead,
)
# detrex self-attn + cross-attn
detrex_cond_self_attn = ConditionalSelfAttention(
embed_dim=d_model,
num_heads=nhead,
)
detrex_cond_cross_attn = ConditionalCrossAttention(embed_dim=d_model, num_heads=nhead)
# weight transfer
detrex_cond_self_attn.query_content_proj.weight = cond_decoder_original.sa_qcontent_proj.weight
detrex_cond_self_attn.query_content_proj.bias = cond_decoder_original.sa_qcontent_proj.bias
detrex_cond_self_attn.query_pos_proj.weight = cond_decoder_original.sa_qpos_proj.weight
detrex_cond_self_attn.query_pos_proj.bias = cond_decoder_original.sa_qpos_proj.bias
detrex_cond_self_attn.key_content_proj.weight = cond_decoder_original.sa_kcontent_proj.weight
detrex_cond_self_attn.key_content_proj.bias = cond_decoder_original.sa_kcontent_proj.bias
detrex_cond_self_attn.key_pos_proj.weight = cond_decoder_original.sa_kpos_proj.weight
detrex_cond_self_attn.key_pos_proj.bias = cond_decoder_original.sa_kpos_proj.bias
detrex_cond_self_attn.value_proj.weight = cond_decoder_original.sa_v_proj.weight
detrex_cond_self_attn.value_proj.bias = cond_decoder_original.sa_v_proj.bias
detrex_cond_self_attn.out_proj.weight = cond_decoder_original.self_attn.out_proj.weight
detrex_cond_self_attn.out_proj.bias = cond_decoder_original.self_attn.out_proj.bias
detrex_cond_cross_attn.query_content_proj.weight = cond_decoder_original.ca_qcontent_proj.weight
detrex_cond_cross_attn.query_content_proj.bias = cond_decoder_original.ca_qcontent_proj.bias
detrex_cond_cross_attn.query_pos_proj.weight = cond_decoder_original.ca_qpos_proj.weight
detrex_cond_cross_attn.query_pos_proj.bias = cond_decoder_original.ca_qpos_proj.bias
detrex_cond_cross_attn.key_content_proj.weight = cond_decoder_original.ca_kcontent_proj.weight
detrex_cond_cross_attn.key_content_proj.bias = cond_decoder_original.ca_kcontent_proj.bias
detrex_cond_cross_attn.key_pos_proj.weight = cond_decoder_original.ca_kpos_proj.weight
detrex_cond_cross_attn.key_pos_proj.bias = cond_decoder_original.ca_kpos_proj.bias
detrex_cond_cross_attn.value_proj.weight = cond_decoder_original.ca_v_proj.weight
detrex_cond_cross_attn.value_proj.bias = cond_decoder_original.ca_v_proj.bias
detrex_cond_cross_attn.out_proj.weight = cond_decoder_original.cross_attn.out_proj.weight
detrex_cond_cross_attn.out_proj.bias = cond_decoder_original.cross_attn.out_proj.bias
detrex_cond_cross_attn.query_pos_sine_proj.weight = (
cond_decoder_original.ca_qpos_sine_proj.weight
)
detrex_cond_cross_attn.query_pos_sine_proj.bias = cond_decoder_original.ca_qpos_sine_proj.bias
# test output
input = torch.randn(16, 1, 256) # (n, b, c)
query_pos = torch.randn(16, 1, 256)
key_pos = torch.randn(16, 1, 256)
query_sine_pos = torch.randn(16, 1, 256)
original_output = cond_decoder_original(
tgt=input, memory=input, query_pos=query_pos, pos=key_pos, query_sine_embed=query_sine_pos
)
# detrex cond attn output
temp = detrex_cond_self_attn(
query=input,
key=input,
value=input,
query_pos=query_pos,
key_pos=query_pos,
)
detrex_output = detrex_cond_cross_attn(
query=temp,
key=input,
value=input,
query_pos=query_pos,
key_pos=key_pos,
query_sine_embed=query_sine_pos,
is_first_layer=True,
)
torch.allclose(original_output.sum(), detrex_output.sum())
|