import unittest

import torch

from fairseq.modules.sparse_multihead_attention import SparseMultiheadAttention


class TestSparseMultiheadAttention(unittest.TestCase):
    def test_sparse_multihead_attention(self):
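        # Compare the masks produced by SparseMultiheadAttention against
        # hand-written 8x8 expectations, for both the bidirectional
        # (encoder) and unidirectional (decoder) fixed sparse patterns.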
        attn_weights = torch.randn(1, 8, 8)
        ninf = float("-inf")
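        # Expected mask for the bidirectional pattern with stride=4 and
        # expressivity=1 over a length-8 sequence: each row is a query
        # position, and -inf marks key positions excluded from attention.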
        bidirectional_sparse_mask = torch.tensor(
            [
                [0, 0, 0, 0, 0, ninf, ninf, 0],
                [0, 0, 0, 0, 0, ninf, ninf, 0],
                [0, 0, 0, 0, 0, ninf, ninf, 0],
                [0, 0, 0, 0, 0, ninf, ninf, 0],
                [ninf, ninf, ninf, 0, 0, 0, 0, 0],
                [ninf, ninf, ninf, 0, 0, 0, 0, 0],
                [ninf, ninf, ninf, 0, 0, 0, 0, 0],
                [ninf, ninf, ninf, 0, 0, 0, 0, 0],
            ]
        )

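        # A single sparse attention head (embed_dim=16, num_heads=1) in
        # bidirectional mode; buffered_sparse_mask returns the additive
        # sparse mask for the given target and source lengths.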
        bidirectional_attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=True
        )
        bidirectional_attention_sparse_mask = (
            bidirectional_attention.buffered_sparse_mask(attn_weights, 8, 8)
        )
        self.assertTrue(
            torch.all(
                torch.eq(
                    bidirectional_attention_sparse_mask, bidirectional_sparse_mask
                )
            )
        )

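        # Expected mask for the unidirectional (decoder) variant: attention
        # is causal, and each query keeps only a bounded window of previous
        # positions determined by stride and expressivity.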
        sparse_mask = torch.tensor(
            [
                [0, ninf, ninf, ninf, ninf, ninf, ninf, ninf],
                [0, 0, ninf, ninf, ninf, ninf, ninf, ninf],
                [0, 0, 0, ninf, ninf, ninf, ninf, ninf],
                [0, 0, 0, 0, ninf, ninf, ninf, ninf],
                [0, 0, 0, 0, 0, ninf, ninf, ninf],
                [ninf, ninf, ninf, 0, 0, 0, ninf, ninf],
                [ninf, ninf, ninf, 0, 0, 0, 0, ninf],
                [ninf, ninf, ninf, 0, 0, 0, 0, 0],
            ]
        )

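        # Same head configuration, but unidirectional.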
        attention = SparseMultiheadAttention(
            16, 1, stride=4, expressivity=1, is_bidirectional=False
        )
        attention_sparse_mask = attention.buffered_sparse_mask(attn_weights, 8, 8)

        self.assertTrue(torch.all(torch.eq(attention_sparse_mask, sparse_mask)))


if __name__ == "__main__":
    unittest.main()