"""Proto round-trip and validation tests for the domain objects in domain.entity_d."""

import logging
import unittest

import proto.entity_pb2 as entity_pb2

import domain.entity_d as entity_d


class EntityDTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.entity_d = entity_d.EntityD(entity_id='demoid', entity_name='demo entity')

    def test_proto_roundtrip(self):
        proto = self.entity_d.to_proto()
        domain = entity_d.EntityD.from_proto(proto)
        self.assertEqual(self.entity_d.to_proto(), domain.to_proto())


class RelationshipDTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.relationship_d = entity_d.RelationshipD(
            relationship_id="demoid",
            start_date="2024-06-01",
            end_date="2024-06-02",
            source_text="source text",
            predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)

    def test_proto_roundtrip(self):
        proto = self.relationship_d.to_proto()
        domain = entity_d.RelationshipD.from_proto(proto)
        self.assertEqual(self.relationship_d.to_proto(), domain.to_proto())

    def test_end_date_before_start_date_raises(self):
        """An end_date earlier than start_date should be rejected at construction."""
        with self.assertRaises(ValueError):
            _ = entity_d.RelationshipD(
                relationship_id="demoid",
                start_date="2024-06-01",
                end_date="2024-05-02",
                source_text="source text",
                predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)


class EntityRelationshipDTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.from_entity_d = entity_d.EntityD(entity_id='from_id', entity_name='from entity')
        cls.relationship_d = entity_d.RelationshipD(
            relationship_id='relationship_id',
            start_date='2024-06-01',
            end_date='2024-06-02',
            source_text='source text',
            predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)
        cls.to_entity_d = entity_d.EntityD(entity_id='to_id', entity_name='to entity')
        cls.entity_relationship_d = entity_d.EntityRelationshipD(from_entity=cls.from_entity_d,
                                                                 relationship=cls.relationship_d,
                                                                 to_entity=cls.to_entity_d)

    def test_proto_roundtrip(self):
        proto = self.entity_relationship_d.to_proto()
        domain = entity_d.EntityRelationshipD.from_proto(proto)
        self.assertEqual(self.entity_relationship_d.to_proto(), domain.to_proto())


class EntityKnowledgeGraphDTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.from_entity_d = entity_d.EntityD(entity_id='from_id', entity_name='from entity')
        cls.relationship_d = entity_d.RelationshipD(
            relationship_id='relationship_id',
            start_date='2024-06-01',
            end_date='2024-06-02',
            source_text='source text',
            predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)
        cls.to_entity_d = entity_d.EntityD(entity_id='to_id', entity_name='to entity')
        cls.entity_relationship_d = entity_d.EntityRelationshipD(from_entity=cls.from_entity_d,
                                                                 relationship=cls.relationship_d,
                                                                 to_entity=cls.to_entity_d)

        cls.entity_knowledge_graph_d = entity_d.EntityKnowledgeGraphD(
            entity_relationships=[cls.entity_relationship_d for _ in range(2)])

    def test_proto_roundtrip(self):
        proto = self.entity_knowledge_graph_d.to_proto()
        domain = entity_d.EntityKnowledgeGraphD.from_proto(proto)
        self.assertEqual(self.entity_knowledge_graph_d.to_proto(), domain.to_proto())


if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    unittest.main()