File size: 348 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
import unittest
from onmt.modules.structured_attention import MatrixTree

import torch


class TestStructuredAttention(unittest.TestCase):
    """Unit tests for the ``MatrixTree`` structured-attention module."""

    def test_matrix_tree_marg_pdfs_sum_to_1(self):
        """Check that ``MatrixTree`` output sums to 1 along dim 1.

        A random ``(1, 5, 5)`` tensor is passed through
        ``MatrixTree.forward`` and the result, summed over dim 1, must be
        element-wise close to 1.0 (default ``allclose`` tolerances).
        """
        # Private, seeded generator: the input is identical on every run
        # (so a failure can be reproduced) and the global torch RNG state
        # used by other tests is left untouched.
        rng = torch.Generator().manual_seed(0)
        # NOTE(review): presumably (batch, n, n) pairwise scores -- confirm
        # against MatrixTree's expected input layout.
        scores = torch.rand(1, 5, 5, generator=rng)

        dtree = MatrixTree()
        marg = dtree.forward(scores)

        totals = marg.sum(1)
        # Include the actual sums so a failure is diagnosable from the
        # test report instead of a bare "False is not true".
        self.assertTrue(
            totals.allclose(torch.tensor(1.0)),
            msg="sums along dim 1 should all be 1.0, got {}".format(totals),
        )