odor-detection / tests /test_cond_attn.py
mathiaszinnen's picture
Initialize app
3e99b05
# coding=utf-8
# Copyright 2022 The IDEA Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from detrex.layers import ConditionalCrossAttention, ConditionalSelfAttention
from utils import OriginalConditionalAttentionDecoder, OriginalConditionalAttentionEncoder
def test_cond_self_attention():
    """Check detrex's ConditionalSelfAttention against the reference encoder attention.

    Both modules are built with the same hyper-parameters, the reference
    projection parameters are shared with the detrex module, and the summed
    outputs for one random input are required to match.
    """
    # hyper-parameters
    d_model = 256
    nhead = 8

    # module definition
    cond_attention_original = OriginalConditionalAttentionEncoder(d_model=d_model, nhead=nhead)
    cond_attention_detrex = ConditionalSelfAttention(embed_dim=d_model, num_heads=nhead)

    # Evaluation mode keeps any dropout layers (if present) inactive, so the
    # two forward passes are deterministic and comparable.
    cond_attention_original.eval()
    cond_attention_detrex.eval()

    # weight transfer: (detrex attribute, reference attribute)
    projection_pairs = (
        ("query_content_proj", "sa_qcontent_proj"),
        ("query_pos_proj", "sa_qpos_proj"),
        ("key_content_proj", "sa_kcontent_proj"),
        ("key_pos_proj", "sa_kpos_proj"),
        ("value_proj", "sa_v_proj"),
    )
    for detrex_name, original_name in projection_pairs:
        target = getattr(cond_attention_detrex, detrex_name)
        source = getattr(cond_attention_original, original_name)
        target.weight = source.weight
        target.bias = source.bias
    # The reference output projection lives inside its ``self_attn`` submodule.
    cond_attention_detrex.out_proj.weight = cond_attention_original.self_attn.out_proj.weight
    cond_attention_detrex.out_proj.bias = cond_attention_original.self_attn.out_proj.bias

    # test output
    tgt = torch.randn(16, 1, d_model)  # (n, b, c)
    query_pos = torch.randn(16, 1, d_model)

    # self-attention with short-cut: the residual is added to the reference
    # output by hand (presumably the detrex module adds it internally).
    original_output = cond_attention_original(tgt=tgt, query_pos=query_pos)[0] + tgt
    detrex_output = cond_attention_detrex(
        query=tgt, key=tgt, value=tgt, query_pos=query_pos, key_pos=query_pos
    )

    # The comparison result must be asserted; a bare ``torch.allclose`` call
    # discards the boolean and the test could never fail.
    assert torch.allclose(
        original_output.sum(), detrex_output.sum()
    ), "ConditionalSelfAttention output differs from the reference implementation"
def test_cond_decoder():
    """Check detrex self-attn + cross-attn against the reference conditional decoder.

    The reference decoder's projection parameters are shared with a detrex
    ``ConditionalSelfAttention`` followed by a ``ConditionalCrossAttention``,
    and the summed outputs for one random input are required to match.
    """
    # hyper-parameters
    d_model = 256
    nhead = 8

    # original conditional decoder
    cond_decoder_original = OriginalConditionalAttentionDecoder(
        d_model=d_model,
        nhead=nhead,
    )

    # detrex self-attn + cross-attn
    detrex_cond_self_attn = ConditionalSelfAttention(
        embed_dim=d_model,
        num_heads=nhead,
    )
    detrex_cond_cross_attn = ConditionalCrossAttention(embed_dim=d_model, num_heads=nhead)

    # Evaluation mode keeps any dropout layers (if present) inactive, so the
    # reference and detrex forward passes are deterministic and comparable.
    cond_decoder_original.eval()
    detrex_cond_self_attn.eval()
    detrex_cond_cross_attn.eval()

    def share_parameters(target, source):
        # Make ``target`` use the very same weight and bias as ``source``.
        target.weight = source.weight
        target.bias = source.bias

    # weight transfer: projections that exist in both attention modules,
    # as (detrex attribute, reference attribute without its sa_/ca_ prefix)
    common_projections = (
        ("query_content_proj", "qcontent_proj"),
        ("query_pos_proj", "qpos_proj"),
        ("key_content_proj", "kcontent_proj"),
        ("key_pos_proj", "kpos_proj"),
        ("value_proj", "v_proj"),
    )
    for detrex_module, prefix in (
        (detrex_cond_self_attn, "sa_"),
        (detrex_cond_cross_attn, "ca_"),
    ):
        for detrex_name, original_suffix in common_projections:
            share_parameters(
                getattr(detrex_module, detrex_name),
                getattr(cond_decoder_original, prefix + original_suffix),
            )

    # Output projections live inside the reference attention submodules.
    share_parameters(detrex_cond_self_attn.out_proj, cond_decoder_original.self_attn.out_proj)
    share_parameters(detrex_cond_cross_attn.out_proj, cond_decoder_original.cross_attn.out_proj)

    # Only the cross-attention has a projection for the sine query embedding.
    share_parameters(
        detrex_cond_cross_attn.query_pos_sine_proj,
        cond_decoder_original.ca_qpos_sine_proj,
    )

    # test output
    tgt = torch.randn(16, 1, d_model)  # (n, b, c)
    query_pos = torch.randn(16, 1, d_model)
    key_pos = torch.randn(16, 1, d_model)
    query_sine_pos = torch.randn(16, 1, d_model)

    original_output = cond_decoder_original(
        tgt=tgt, memory=tgt, query_pos=query_pos, pos=key_pos, query_sine_embed=query_sine_pos
    )

    # detrex cond attn output: self-attention feeds the cross-attention query
    self_attn_output = detrex_cond_self_attn(
        query=tgt,
        key=tgt,
        value=tgt,
        query_pos=query_pos,
        key_pos=query_pos,
    )
    detrex_output = detrex_cond_cross_attn(
        query=self_attn_output,
        key=tgt,
        value=tgt,
        query_pos=query_pos,
        key_pos=key_pos,
        query_sine_embed=query_sine_pos,
        is_first_layer=True,
    )

    # The comparison result must be asserted; a bare ``torch.allclose`` call
    # discards the boolean and the test could never fail.
    assert torch.allclose(
        original_output.sum(), detrex_output.sum()
    ), "detrex conditional decoder attention output differs from the reference implementation"