# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from fairseq.modules.multihead_attention import MultiheadAttention


class TestMultiheadAttention(unittest.TestCase):
    """Tests for ``MultiheadAttention`` helpers that are invoked on the class itself."""

    def test_append_prev_key_padding_mask(self):
        """Check ``_append_prev_key_padding_mask`` for each current/previous mask combination.

        Every case supplies a current ``key_padding_mask``, a
        ``prev_key_padding_mask`` and the expected result.  The helper is
        called with ``batch_size=1``, ``src_len=4`` and ``static_kv=False``;
        whenever a mask is expected back it must have shape
        ``(bsz, src_len)`` and match the expected mask element-wise.
        """
        bsz = 1
        src_len = 4

        # Each entry: (label, key_padding_mask, prev_key_padding_mask, expected).
        cases = [
            ("no padding mask", None, None, None),
            (
                "current padding mask only",
                torch.tensor([[1]]).bool(),
                None,
                torch.tensor([[0, 0, 0, 1]]).bool(),
            ),
            (
                "previous padding mask only",
                None,
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 0]]).bool(),
            ),
            (
                "both padding masks",
                torch.tensor([[1]]).bool(),
                torch.tensor([[0, 1, 0]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            (
                "key_padding_mask already spans src_len",
                torch.tensor([[0, 1, 0, 1]]).bool(),
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
            (
                "prev_key_padding_mask already spans src_len",
                None,
                torch.tensor([[0, 1, 0, 1]]).bool(),
                torch.tensor([[0, 1, 0, 1]]).bool(),
            ),
        ]

        for label, current, previous, expected in cases:
            # subTest keeps the remaining cases running after a failure and
            # names the failing scenario in the report.
            with self.subTest(case=label):
                result = MultiheadAttention._append_prev_key_padding_mask(
                    current,
                    previous,
                    batch_size=bsz,
                    src_len=src_len,
                    static_kv=False,
                )

                if expected is None:
                    self.assertIsNone(result)
                    continue

                self.assertIsNotNone(result)
                # Validate the shape first: torch.eq broadcasts, so comparing
                # values on a wrongly shaped result would give a misleading
                # failure (or raise) instead of a clear shape mismatch.
                self.assertEqual(tuple(result.shape), (bsz, src_len))
                self.assertTrue(
                    torch.all(torch.eq(result, expected)),
                    f"Unexpected resultant key padding mask: {result}"
                    f" given current: {current} and previous: {previous}",
                )


if __name__ == "__main__":
    unittest.main()