# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class TestSparseMultiheadAttention(unittest.TestCase):
    def test_sparse_multihead_attention(self):
        ninf = float("-inf")
        attn_weights = torch.randn(1, 8, 8)

        # Expected mask for the bidirectional fixed sparse pattern
        # (stride=4, expressivity=1). In these masks, 0 means the position
        # may attend and -inf means the connection is masked out.
        bidirectional_sparse_mask = torch.tensor(
            [
                [0, 0, 0, 0, 0, ninf, ninf, 0],
                [0, 0, 0, 0, 0, ninf, ninf, 0],
                [0, 0, 0, 0, 0, ninf, ninf, 0],
                [0, 0, 0, 0, 0, ninf, ninf, 0],
                [ninf, ninf, ninf, 0, 0, 0, 0, 0],
                [ninf, ninf, ninf, 0, 0, 0, 0, 0],
                [ninf, ninf, ninf, 0, 0, 0, 0, 0],
                [ninf, ninf, ninf, 0, 0, 0, 0, 0],
            ]
        )

        bidirectional_attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=True
        )
        bidirectional_attention_sparse_mask = (
            bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
        )
        # Assert the comparison; without assertTrue the test would pass
        # vacuously even if the masks differed.
        self.assertTrue(
            torch.all(
                torch.eq(bidirectional_attention_sparse_mask, bidirectional_sparse_mask)
            )
        )

        # Expected mask for the unidirectional (causal) fixed sparse pattern:
        # each position attends causally within its block, plus the tail of
        # the previous block.
        sparse_mask = torch.tensor(
            [
                [0, ninf, ninf, ninf, ninf, ninf, ninf, ninf],
                [0, 0, ninf, ninf, ninf, ninf, ninf, ninf],
                [0, 0, 0, ninf, ninf, ninf, ninf, ninf],
                [0, 0, 0, 0, ninf, ninf, ninf, ninf],
                [0, 0, 0, 0, 0, ninf, ninf, ninf],
                [ninf, ninf, ninf, 0, 0, 0, ninf, ninf],
                [ninf, ninf, ninf, 0, 0, 0, 0, ninf],
                [ninf, ninf, ninf, 0, 0, 0, 0, 0],
            ]
        )

        attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=False
        )
        attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)
        self.assertTrue(torch.all(torch.eq(attention_sparse_mask, sparse_mask)))


if __name__ == "__main__":
    unittest.main()
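

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original test suite).
# It assumes SparseMultiheadAttention inherits fairseq MultiheadAttention's
# forward signature -- query/key/value tensors of shape
# (seq_len, batch, embed_dim), returning (output, attn_weights):
#
#     attn = SparseMultiheadAttention(
#         16, 1, stride=4, expressivity=1, is_bidirectional=True
#     )
#     x = torch.randn(8, 2, 16)  # (seq_len, batch, embed_dim)
#     out, _ = attn(x, x, x)
#     assert out.shape == (8, 2, 16)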