"""Unit tests for the domain wrappers in ``domain.entity_d``.

Each domain class is checked for a lossless protobuf round trip
(``to_proto`` -> ``from_proto`` -> ``to_proto``); ``RelationshipD`` is also
checked for rejecting an end date that precedes its start date.
"""

import logging
import unittest

import proto.entity_pb2 as entity_pb2
import domain.entity_d as entity_d

# Shared fixture values. Dates are 'YYYY-MM-DD' strings; the default end date
# falls one day after the default start date.
_START_DATE = '2024-06-01'
_END_DATE = '2024-06-02'
_SOURCE_TEXT = 'source text'
_NEUTRAL_MOVEMENT = entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL


def _make_relationship(relationship_id, *, start_date=_START_DATE, end_date=_END_DATE):
    """Return a RelationshipD with fixed source text and a neutral prediction.

    Args:
        relationship_id: Identifier passed through to ``RelationshipD``.
        start_date: Start date string; defaults to ``_START_DATE``.
        end_date: End date string; defaults to ``_END_DATE``.
    """
    return entity_d.RelationshipD(
        relationship_id=relationship_id,
        start_date=start_date,
        end_date=end_date,
        source_text=_SOURCE_TEXT,
        predicted_movement=_NEUTRAL_MOVEMENT,
    )


def _install_entity_relationship_fixture(cls):
    """Attach a from-entity, a relationship, a to-entity and their triple to *cls*.

    Sets the class attributes ``from_entity_d``, ``relationship_d``,
    ``to_entity_d`` and ``entity_relationship_d`` so that several test classes
    share a single fixture definition.
    """
    cls.from_entity_d = entity_d.EntityD(entity_id='from_id', entity_name='from entity')
    cls.relationship_d = _make_relationship('relationship_id')
    cls.to_entity_d = entity_d.EntityD(entity_id='to_id', entity_name='to entity')
    cls.entity_relationship_d = entity_d.EntityRelationshipD(
        from_entity=cls.from_entity_d,
        relationship=cls.relationship_d,
        to_entity=cls.to_entity_d,
    )


def _assert_proto_roundtrip(test_case, domain_cls, domain_obj):
    """Assert *domain_obj* serializes identically after a ``from_proto`` round trip.

    Args:
        test_case: The running ``unittest.TestCase`` used for the assertion.
        domain_cls: Class whose ``from_proto`` rebuilds the domain object.
        domain_obj: Instance under test; must provide ``to_proto()``.
    """
    restored = domain_cls.from_proto(domain_obj.to_proto())
    # Re-serialize the original instead of reusing the proto handed to
    # from_proto, so a from_proto that mutates its argument cannot hide a diff.
    test_case.assertEqual(domain_obj.to_proto(), restored.to_proto())


class EntityDTest(unittest.TestCase):
    """Tests for ``EntityD``."""

    @classmethod
    def setUpClass(cls):
        cls.entity_d = entity_d.EntityD(entity_id='demoid', entity_name='demo entity')

    def test_proto_roundtrip(self):
        _assert_proto_roundtrip(self, entity_d.EntityD, self.entity_d)


class RelationshipDTest(unittest.TestCase):
    """Tests for ``RelationshipD``."""

    @classmethod
    def setUpClass(cls):
        cls.relationship_d = _make_relationship('demoid')

    def test_proto_roundtrip(self):
        _assert_proto_roundtrip(self, entity_d.RelationshipD, self.relationship_d)

    def test_end_date_after_start_date(self):
        """An end date earlier than the start date is rejected with ValueError."""
        with self.assertRaises(ValueError):
            _make_relationship('demoid', end_date='2024-05-02')


class EntityRelationshipDTest(unittest.TestCase):
    """Tests for ``EntityRelationshipD``."""

    @classmethod
    def setUpClass(cls):
        _install_entity_relationship_fixture(cls)

    def test_proto_roundtrip(self):
        _assert_proto_roundtrip(
            self, entity_d.EntityRelationshipD, self.entity_relationship_d)


class EntityKnowledgeGraphDTest(unittest.TestCase):
    """Tests for ``EntityKnowledgeGraphD``."""

    @classmethod
    def setUpClass(cls):
        _install_entity_relationship_fixture(cls)
        # NOTE(review): both entries are the same object, so this fixture
        # cannot detect a serializer that reorders relationships; consider
        # distinct entries if ordering matters.
        cls.entity_knowledge_graph_d = entity_d.EntityKnowledgeGraphD(
            entity_relationships=[cls.entity_relationship_d] * 2)

    def test_proto_roundtrip(self):
        _assert_proto_roundtrip(
            self, entity_d.EntityKnowledgeGraphD, self.entity_knowledge_graph_d)


if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    unittest.main()