Gateston Johns committed
Commit 9041389
1 Parent(s): fc4bc08

first real commit

Files changed (36)
  1. app.py +61 -0
  2. domain/chunk_d.py +89 -0
  3. domain/chunk_d_test.py +71 -0
  4. domain/domain_protocol.py +89 -0
  5. domain/domain_protocol_test.py +69 -0
  6. domain/entity_d.py +169 -0
  7. domain/entity_d_test.py +96 -0
  8. extraction_pipeline/base_stage.py +63 -0
  9. extraction_pipeline/document_metadata_extractor/document_metadata_extractor.py +9 -0
  10. extraction_pipeline/document_metadata_extractor/openai_document_metadata_extractor.py +73 -0
  11. extraction_pipeline/document_metadata_extractor/openai_document_metadata_extractor_test.py +74 -0
  12. extraction_pipeline/document_metadata_extractor/prompts.py +22 -0
  13. extraction_pipeline/pdf_process_stage.py +131 -0
  14. extraction_pipeline/pdf_process_stage_test.py +103 -0
  15. extraction_pipeline/pdf_to_knowledge_graph_transform.py +153 -0
  16. extraction_pipeline/relationship_extractor/entity_relationship_extractor.py +10 -0
  17. extraction_pipeline/relationship_extractor/openai_relationship_extractor.py +108 -0
  18. extraction_pipeline/relationship_extractor/openai_relationship_extractor_test.py +87 -0
  19. extraction_pipeline/relationship_extractor/prompts.py +481 -0
  20. llm_handler/llm_interface.py +21 -0
  21. llm_handler/mock_llm_handler.py +34 -0
  22. llm_handler/openai_handler.py +64 -0
  23. llm_handler/openai_handler_test.py +69 -0
  24. proto/chunk_pb2.py +30 -0
  25. proto/chunk_pb2.pyi +43 -0
  26. proto/entity_pb2.py +34 -0
  27. proto/entity_pb2.pyi +56 -0
  28. proto/pdf_document_pb2.py +26 -0
  29. proto/pdf_document_pb2.pyi +11 -0
  30. query_pipeline/evaluation_engine.py +219 -0
  31. query_pipeline/thesis_extractor.py +18 -0
  32. requirements.txt +39 -0
  33. storage/domain_dao.py +128 -0
  34. storage/domain_dao_test.py +71 -0
  35. storage/neo4j_dao.py +102 -0
  36. utils/dates.py +7 -0
app.py ADDED
@@ -0,0 +1,61 @@
1
+ import gradio as gr
2
+ from typing import Dict, Union, List
3
+ import json
4
+ import uuid
5
+ from storage.neo4j_dao import Neo4jDomainDAO
6
+ from domain.chunk_d import DocumentD, ChunkD
7
+ from query_pipeline.thesis_extractor import ThesisExtractor
8
+ from query_pipeline.evaluation_engine import EvaluationEngine
9
+ from proto.chunk_pb2 import Chunk, ChunkType
10
+ from datetime import datetime
11
+ import os
12
+
13
+ os.environ['OPENAI_API_KEY'] = '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'
14
+
15
+ os.environ['NEO4J_USER'] = "neo4j"
16
+ os.environ['NEO4J_PASSWORD'] = "dOIZwzF_GImgwjF-smChe60QGQgicq8ER8RUlZvornU"
17
+ os.environ['NEO4J_URI'] = "neo4j+s://2317ae21.databases.neo4j.io"
18
+
19
+ def thesis_evaluation(thesis: str) -> str:
20
+ thesis_document_d = DocumentD(file_path="",
21
+ authors="user",
22
+ publish_date="2024-06-18")
23
+
24
+ thesis_chunk_d = ChunkD(chunk_text=thesis,
25
+ chunk_type=ChunkType.CHUNK_TYPE_PAGE,
26
+ chunk_index=0,
27
+ parent_reference=thesis_document_d)
28
+
29
+ thesis_entity_kg = ThesisExtractor().extract_relationships(thesis_chunk_d)
30
+
31
+ with Neo4jDomainDAO() as neo4j_dao:
32
+ full_db_kg = neo4j_dao.get_knowledge_graph()
33
+
34
+ return EvaluationEngine(full_db_kg).evaluate_and_display_thesis(thesis_entity_kg)
35
+
36
+
37
+ theme = gr.themes.Soft(
38
+ primary_hue="cyan",
39
+ secondary_hue="green",
40
+ neutral_hue="slate",
41
+ )
42
+
43
+ with gr.Blocks(theme=theme) as demo:
44
+ gr.HTML("""
45
+ <div style="background-color: rgb(171, 188, 251);">
46
+ <div style="background-color: rgb(151, 168, 231); padding-top: 2px; padding-bottom: 7px; text-align: center;">
47
+ <h1>AthenaAIC MetisLLM Thesis Evaluation Tool</h1>
48
+ </div>
49
+ </div>""")
50
+ with gr.Row():
51
+ inp = gr.Textbox(placeholder="What is your thesis?", scale=3, lines=3, show_label=False)
52
+ submit = gr.Button(scale=1, value="Evaluate")
53
+ with gr.Row():
54
+ out = gr.HTML()
55
+
56
+ submit.click(thesis_evaluation, inputs=inp, outputs=out)
57
+
58
+ demo.launch(
59
+ auth=('athena-admin', 'athena'),
60
+ share=True
61
+ )
domain/chunk_d.py ADDED
@@ -0,0 +1,89 @@
1
+ from __future__ import annotations
2
+ import dataclasses
3
+ import uuid
4
+ from typing import Union
5
+ import hashlib
6
+
7
+ import proto.chunk_pb2 as chunk_pb2
8
+ from domain.domain_protocol import DomainProtocol
9
+
10
+
11
+ @dataclasses.dataclass(frozen=True)
12
+ class DocumentD(DomainProtocol[chunk_pb2.Document]):
13
+ file_path: str
14
+ authors: str
15
+ publish_date: str
16
+
17
+ @property
18
+ def id(self) -> str:
19
+ return hashlib.sha256(self.to_proto().SerializeToString()).hexdigest()
20
+
21
+ @classmethod
22
+ def _from_proto(cls, proto: chunk_pb2.Document) -> DocumentD:
23
+ return cls(file_path=proto.file_path,
24
+ authors=proto.authors,
25
+ publish_date=proto.publish_date)
26
+
27
+ def to_proto(self) -> chunk_pb2.Document:
28
+ return chunk_pb2.Document(file_path=self.file_path,
29
+ authors=self.authors,
30
+ publish_date=self.publish_date)
31
+
32
+
33
+ @dataclasses.dataclass(frozen=True)
34
+ class ChunkD(DomainProtocol[chunk_pb2.Chunk]):
35
+
36
+ @property
37
+ def id(self) -> str:
38
+ return str(self.chunk_id)
39
+
40
+ chunk_text: str
41
+ chunk_type: chunk_pb2.ChunkType
42
+ chunk_index: int
43
+ parent_reference: Union[uuid.UUID, DocumentD]
44
+ chunk_id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)
45
+
46
+ def __post_init__(self):
47
+ if self.chunk_type == chunk_pb2.ChunkType.CHUNK_TYPE_PAGE:
48
+ if not isinstance(self.parent_reference, DocumentD):
49
+ raise ValueError(
50
+ f"Chunk (id: {self.chunk_id}) with type {self.chunk_type} must have a DocumentD parent_reference."
51
+ )
52
+ elif not isinstance(self.parent_reference, uuid.UUID):
53
+ raise ValueError(
54
+ f"Chunk (id: {self.chunk_id}) with type {self.chunk_type} must have a uuid.UUID parent_reference."
55
+ )
56
+
57
+ @classmethod
58
+ def _from_proto(cls, proto: chunk_pb2.Chunk) -> ChunkD:
59
+ if proto.HasField('parent_chunk_id'):
60
+ return cls(chunk_id=uuid.UUID(proto.chunk_id),
61
+ parent_reference=uuid.UUID(proto.parent_chunk_id),
62
+ chunk_text=proto.chunk_text,
63
+ chunk_type=proto.chunk_type,
64
+ chunk_index=proto.chunk_index)
65
+ elif proto.HasField('document'):
66
+ return cls(chunk_id=uuid.UUID(proto.chunk_id),
67
+ parent_reference=DocumentD._from_proto(proto.document),
68
+ chunk_text=proto.chunk_text,
69
+ chunk_type=proto.chunk_type,
70
+ chunk_index=proto.chunk_index)
71
+ else:
72
+ raise ValueError(
73
+ f"Chunk proto (id: {proto.chunk_id}) has no 'parent' or 'document' field.")
74
+
75
+ def to_proto(self) -> chunk_pb2.Chunk:
76
+ chunk_proto = chunk_pb2.Chunk()
77
+ chunk_proto.chunk_id = str(self.chunk_id)
78
+ chunk_proto.chunk_text = self.chunk_text
79
+ chunk_proto.chunk_type = self.chunk_type
80
+ chunk_proto.chunk_index = self.chunk_index
81
+ if isinstance(self.parent_reference, uuid.UUID):
82
+ chunk_proto.parent_chunk_id = str(self.parent_reference)
83
+ elif isinstance(self.parent_reference, DocumentD):
84
+ chunk_proto.document.CopyFrom(self.parent_reference.to_proto())
85
+ else:
86
+ raise ValueError(
87
+ f"Chunk (id: {self.chunk_id}) parent_reference is of unknown type: {type(self.parent_reference)}"
88
+ )
89
+ return chunk_proto
domain/chunk_d_test.py ADDED
@@ -0,0 +1,71 @@
1
+ import unittest
2
+ import logging
3
+ import uuid
4
+
5
+ from domain.chunk_d import DocumentD, ChunkD
6
+ import proto.chunk_pb2 as chunk_pb2
7
+
8
+
9
+ class DocumentDTest(unittest.TestCase):
10
+
11
+ @classmethod
12
+ def setUpClass(cls):
13
+ cls.document_d = DocumentD(file_path='test_file',
14
+ authors='BofA John Doe Herb Johnson Taylor Mason',
15
+ publish_date='2021-01-01')
16
+
17
+ def test_proto_roundtrip(self):
18
+ document_proto = self.document_d.to_proto()
19
+ roundtrip_document_d = DocumentD.from_proto(document_proto)
20
+ self.assertEqual(self.document_d, roundtrip_document_d)
21
+
22
+
23
+ class ChunkDTest(unittest.TestCase):
24
+
25
+ @classmethod
26
+ def setUpClass(cls):
27
+ cls.parent_reference_document_d = DocumentD(
28
+ file_path='test_file',
29
+ authors='BofA John Doe Herb Johnson Taylor Mason',
30
+ publish_date='2021-01-01')
31
+
32
+ def test_validate(self):
33
+ with self.assertRaises(ValueError,
34
+ msg="CHUNK_TYPE_PAGE must have a DocumentD parent_reference"):
35
+ ChunkD(chunk_text='test chunk text',
36
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
37
+ chunk_index=9,
38
+ parent_reference=uuid.uuid4(),
39
+ chunk_id=uuid.uuid4())
40
+ with self.assertRaises(ValueError,
41
+ msg="CHUNK_TYPE_SENTENCE must have a uuid.UUID parent_reference"):
42
+ ChunkD(chunk_text='test chunk text',
43
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
44
+ chunk_index=0,
45
+ parent_reference=self.parent_reference_document_d,
46
+ chunk_id=uuid.uuid4())
47
+ ChunkD(chunk_text='test chunk text',
48
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
49
+ chunk_index=9,
50
+ parent_reference=self.parent_reference_document_d,
51
+ chunk_id=uuid.uuid4())
52
+ ChunkD(chunk_text='test chunk text',
53
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
54
+ chunk_index=0,
55
+ parent_reference=uuid.uuid4(),
56
+ chunk_id=uuid.uuid4())
57
+
58
+ def test_proto_roundtrip(self):
59
+ test_chunk_d = ChunkD(chunk_id=uuid.uuid4(),
60
+ parent_reference=self.parent_reference_document_d,
61
+ chunk_text='test chunk text',
62
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
63
+ chunk_index=9)
64
+ chunk_proto = test_chunk_d.to_proto()
65
+ roundtrip_chunk_d = ChunkD.from_proto(chunk_proto)
66
+ self.assertEqual(test_chunk_d, roundtrip_chunk_d)
67
+
68
+
69
+ if __name__ == '__main__':
70
+ logging.basicConfig(level=logging.INFO)
71
+ unittest.main()
domain/domain_protocol.py ADDED
@@ -0,0 +1,89 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Protocol, Tuple, Type, TypeVar, get_args
4
+
5
+ from google.protobuf import json_format, message
6
+
7
+ MessageType = TypeVar("MessageType", bound=message.Message)
8
+ DomainProtocolType = TypeVar("DomainProtocolType", bound='DomainProtocol')
9
+
10
+
11
+ class ProtoDeserializationError(Exception):
12
+ ...
13
+
14
+
15
+ class DomainProtocol(Protocol[MessageType]):
16
+
17
+ @property
18
+ def id(self) -> str:
19
+ ...
20
+
21
+ @classmethod
22
+ def _from_proto(cls: Type[DomainProtocolType], proto: MessageType) -> DomainProtocolType:
23
+ ...
24
+
25
+ def to_proto(self) -> MessageType:
26
+ ...
27
+
28
+ @classmethod
29
+ def message_cls(cls: Type[DomainProtocolType]) -> Type[MessageType]:
30
+ orig_bases: Optional[Tuple[Type[MessageType], ...]] = getattr(cls, "__orig_bases__", None)
31
+ if not orig_bases:
32
+ raise ValueError(f"Class {cls} does not have __orig_bases__")
33
+ if len(orig_bases) != 1:
34
+ raise ValueError(f"Class {cls} has unexpected number of bases: {orig_bases}")
35
+ return get_args(orig_bases[0])[0]
36
+
37
+ @classmethod
38
+ def from_proto(cls: Type[DomainProtocolType],
39
+ proto: MessageType,
40
+ allow_empty: bool = False) -> DomainProtocolType:
41
+ try:
42
+ if not allow_empty:
43
+ cls.validate_proto_not_empty(proto)
44
+ return cls._from_proto(proto)
45
+ except Exception as e:
46
+ error_str = f"Failed to convert {cls} - {e}"
47
+ raise ProtoDeserializationError(error_str) from e
48
+
49
+ @classmethod
50
+ def from_json(cls: Type[DomainProtocolType], json_str: str) -> DomainProtocolType:
51
+ try:
52
+ proto_cls = cls.message_cls()
53
+ proto = proto_cls()
54
+ json_format.Parse(json_str, proto)
55
+ return cls.from_proto(proto)
56
+ except json_format.ParseError as e:
57
+ error_str = f"{cls} failed to parse json string: {json_str} - {e}"
58
+ raise ProtoDeserializationError(error_str) from e
59
+
60
+ def to_json(self) -> str:
61
+ return json_format.MessageToJson(self.to_proto()).replace("\n", " ")
62
+
63
+ @classmethod
64
+ def validate_proto_not_empty(cls, proto: message.Message):
65
+ if cls.is_empty(proto):
66
+ raise ValueError("Proto is empty")
67
+
68
+ @classmethod
69
+ def is_empty(cls, proto: message.Message) -> bool:
70
+ descriptor = getattr(proto, 'DESCRIPTOR', None)
71
+ fields = list(descriptor.fields) if descriptor else []
72
+ while fields:
73
+ field = fields.pop()
74
+
75
+ if field.label == field.LABEL_REPEATED:
76
+ eval_func = lambda x: x == field.default_value
77
+ if field.type == field.TYPE_MESSAGE:
78
+ eval_func = cls.is_empty
79
+ if not all([eval_func(item) for item in getattr(proto, field.name)]):
80
+ return False
81
+
82
+ elif field.type == field.TYPE_MESSAGE:
83
+ if not cls.is_empty(getattr(proto, field.name)):
84
+ return False
85
+ else:
86
+ field_value = getattr(proto, field.name)
87
+ if field_value != field.default_value:
88
+ return False
89
+ return True
domain/domain_protocol_test.py ADDED
@@ -0,0 +1,69 @@
1
+ from __future__ import annotations
2
+ import logging
3
+ import unittest
4
+ from google.protobuf import timestamp_pb2
5
+ import dataclasses
6
+
7
+ from domain.domain_protocol import DomainProtocol, ProtoDeserializationError
8
+
9
+
10
+ @dataclasses.dataclass(frozen=True)
11
+ class TimestampTestD(DomainProtocol[timestamp_pb2.Timestamp]):
12
+
13
+ nanos: int
14
+
15
+ @property
16
+ def id(self) -> str:
17
+ return str(self.nanos)
18
+
19
+ @classmethod
20
+ def _from_proto(cls, proto: timestamp_pb2.Timestamp) -> TimestampTestD:
21
+ return cls(nanos=proto.nanos)
22
+
23
+ def to_proto(self) -> timestamp_pb2.Timestamp:
24
+ return timestamp_pb2.Timestamp(nanos=self.nanos)
25
+
26
+
27
+ class DomainProtocolTest(unittest.TestCase):
28
+
29
+ @classmethod
30
+ def setUpClass(cls) -> None:
31
+ cls.timestamp_d = TimestampTestD(nanos=1)
32
+ cls.timestamp_proto = timestamp_pb2.Timestamp(nanos=1)
33
+
34
+ def test_proto_roundtrip(self):
35
+ proto = self.timestamp_d.to_proto()
36
+ domain_from_proto = TimestampTestD.from_proto(proto)
37
+ self.assertEqual(self.timestamp_d, domain_from_proto)
38
+
39
+ def test_json_roundtrip(self):
40
+ json_str = self.timestamp_d.to_json()
41
+ domain_from_json = TimestampTestD.from_json(json_str)
42
+ self.assertEqual(self.timestamp_d, domain_from_json)
43
+
44
+ def test_from_proto_empty_fail(self):
45
+ empty_proto = timestamp_pb2.Timestamp()
46
+ with self.assertRaises(ProtoDeserializationError):
47
+ TimestampTestD.from_proto(empty_proto)
48
+
49
+ def test_from_proto_empty_allowed_flag(self):
50
+ empty_proto = timestamp_pb2.Timestamp()
51
+ domain_from_proto = TimestampTestD.from_proto(empty_proto, allow_empty=True)
52
+ self.assertEqual(TimestampTestD(nanos=0), domain_from_proto)
53
+
54
+ def test_validate_proto_not_empty(self):
55
+ empty_proto = timestamp_pb2.Timestamp()
56
+ with self.assertRaises(ValueError):
57
+ TimestampTestD.validate_proto_not_empty(empty_proto)
58
+
59
+ def test_is_empty(self):
60
+ empty_proto = timestamp_pb2.Timestamp()
61
+ self.assertTrue(TimestampTestD.is_empty(empty_proto))
62
+
63
+ def test_message_cls(self):
64
+ self.assertEqual(timestamp_pb2.Timestamp, TimestampTestD.message_cls())
65
+
66
+
67
+ if __name__ == "__main__":
68
+ logging.basicConfig(level=logging.INFO)
69
+ unittest.main()
domain/entity_d.py ADDED
@@ -0,0 +1,169 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import hashlib
5
+ import logging
6
+ from typing import TypeAlias, Union
7
+
8
+ import proto.entity_pb2 as entity_pb2
9
+
10
+ from domain.domain_protocol import DomainProtocol
11
+ from utils.dates import parse_date
12
+
13
+ Neo4jDict: TypeAlias = dict[str, Union[str, int, bool, list[str], list[int], list[bool]]]
14
+
15
+
16
+ @dataclasses.dataclass(frozen=True)
17
+ class EntityD(DomainProtocol[entity_pb2.Entity]):
18
+ entity_id: str
19
+ entity_name: str
20
+
21
+ @property
22
+ def id(self) -> str:
23
+ return self.entity_id
24
+
25
+ @classmethod
26
+ def _from_proto(cls, proto: entity_pb2.Entity) -> EntityD:
27
+ return EntityD(entity_id=proto.entity_id, entity_name=proto.entity_name)
28
+
29
+ def to_proto(self) -> entity_pb2.Entity:
30
+ return entity_pb2.Entity(entity_id=self.entity_id, entity_name=self.entity_name)
31
+
32
+ @property
33
+ def neo4j_create_cmd(self):
34
+ # TODO store entity_id?
35
+ return "MERGE (e:Entity {name: $name}) ON CREATE SET e.pdf_file = $pdf_file"
36
+
37
+ @property
38
+ def neo4j_create_args(self) -> Neo4jDict:
39
+ return {
40
+ "name": self.entity_name,
41
+ }
42
+
43
+
44
+ @dataclasses.dataclass(frozen=True)
45
+ class RelationshipD(DomainProtocol[entity_pb2.Relationship]):
46
+ relationship_id: str
47
+ start_date: str
48
+ end_date: str
49
+ source_text: str
50
+ predicted_movement: entity_pb2.PredictedMovement
51
+
52
+ @property
53
+ def id(self) -> str:
54
+ return self.relationship_id
55
+
56
+ def __post_init__(self):
57
+ if self.start_date and self.end_date:
58
+ start = parse_date(self.start_date)
59
+ end = parse_date(self.end_date)
60
+
61
+ if end < start:
62
+ logging.warning("end_date %s is before start_date %s",
63
+ self.end_date,
64
+ self.start_date)
65
+ # raise ValueError(f"end_date {self.end_date} is before start_date {self.start_date}")
66
+
67
+ @classmethod
68
+ def _from_proto(cls, proto: entity_pb2.Relationship) -> RelationshipD:
69
+ return RelationshipD(relationship_id=proto.relationship_id,
70
+ start_date=proto.start_date,
71
+ end_date=proto.end_date,
72
+ source_text=proto.source_text,
73
+ predicted_movement=proto.predicted_movement)
74
+
75
+ def to_proto(self) -> entity_pb2.Relationship:
76
+ return entity_pb2.Relationship(relationship_id=self.relationship_id,
77
+ start_date=self.start_date,
78
+ end_date=self.end_date,
79
+ source_text=self.source_text,
80
+ predicted_movement=self.predicted_movement)
81
+
82
+ @classmethod
83
+ def from_string(cls, relationship: str) -> entity_pb2.PredictedMovement:
84
+ if relationship == "PREDICTED_MOVEMENT_NEUTRAL":
85
+ return entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL
86
+ elif relationship == "PREDICTED_MOVEMENT_INCREASE":
87
+ return entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_INCREASE
88
+ elif relationship == "PREDICTED_MOVEMENT_DECREASE":
89
+ return entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_DECREASE
90
+ else:
91
+ return entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_UNSPECIFIED
92
+
93
+ @property
94
+ def neo4j_create_cmd(self):
95
+ return """MATCH (from:Entity {name: $from_name})
96
+ MATCH (to:Entity {name: $to_name})
97
+ MERGE (from) -[r:Relationship {start_date: $start_date, end_date: $end_date, predicted_movement: $predicted_movement}]-> (to) ON CREATE SET r.source_text = $source_text, r.pdf_file = $pdf_file"""
98
+
99
+ @property
100
+ def neo4j_create_args(self) -> Neo4jDict:
101
+ return {
102
+ "start_date": self.start_date,
103
+ "end_date": self.end_date,
104
+ "predicted_movement": entity_pb2.PredictedMovement.Name(self.predicted_movement),
105
+ "source_text": self.source_text,
106
+ }
107
+
108
+
109
+ @dataclasses.dataclass(frozen=True)
110
+ class EntityRelationshipD(DomainProtocol[entity_pb2.EntityRelationship]):
111
+ from_entity: EntityD
112
+ relationship: RelationshipD
113
+ to_entity: EntityD
114
+
115
+ @property
116
+ def id(self) -> str:
117
+ return hashlib.sha256(self.to_proto().SerializeToString()).hexdigest()
118
+
119
+ @classmethod
120
+ def _from_proto(cls, proto: entity_pb2.EntityRelationship) -> EntityRelationshipD:
121
+ return EntityRelationshipD(from_entity=EntityD._from_proto(proto.from_entity),
122
+ relationship=RelationshipD._from_proto(proto.relationship),
123
+ to_entity=EntityD._from_proto(proto.to_entity))
124
+
125
+ def to_proto(self) -> entity_pb2.EntityRelationship:
126
+ return entity_pb2.EntityRelationship(from_entity=self.from_entity.to_proto(),
127
+ relationship=self.relationship.to_proto(),
128
+ to_entity=self.to_entity.to_proto())
129
+
130
+ @property
131
+ def neo4j_create_cmds(self):
132
+ return [
133
+ self.from_entity.neo4j_create_cmd,
134
+ self.to_entity.neo4j_create_cmd,
135
+ self.relationship.neo4j_create_cmd
136
+ ]
137
+
138
+ @property
139
+ def neo4j_create_args(self) -> list[Neo4jDict]:
140
+ relationship_args = {
141
+ **self.relationship.neo4j_create_args,
142
+ 'from_name': self.from_entity.entity_name,
143
+ 'to_name': self.to_entity.entity_name,
144
+ }
145
+
146
+ return [
147
+ self.from_entity.neo4j_create_args, self.to_entity.neo4j_create_args, relationship_args
148
+ ]
149
+
150
+
151
+ @dataclasses.dataclass(frozen=True)
152
+ class EntityKnowledgeGraphD(DomainProtocol[entity_pb2.EntityKnowledgeGraph]):
153
+ entity_relationships: list[EntityRelationshipD]
154
+
155
+ @property
156
+ def id(self) -> str:
157
+ return hashlib.sha256(self.to_proto().SerializeToString()).hexdigest()
158
+
159
+ @classmethod
160
+ def _from_proto(cls, proto: entity_pb2.EntityKnowledgeGraph) -> EntityKnowledgeGraphD:
161
+ return EntityKnowledgeGraphD(entity_relationships=[
162
+ EntityRelationshipD._from_proto(entity_relationship)
163
+ for entity_relationship in proto.entity_relationships
164
+ ])
165
+
166
+ def to_proto(self) -> entity_pb2.EntityKnowledgeGraph:
167
+ return entity_pb2.EntityKnowledgeGraph(entity_relationships=[
168
+ entity_relationship.to_proto() for entity_relationship in self.entity_relationships
169
+ ])
domain/entity_d_test.py ADDED
@@ -0,0 +1,96 @@
1
+ import logging
2
+ import unittest
3
+
4
+ import proto.entity_pb2 as entity_pb2
5
+
6
+ import domain.entity_d as entity_d
7
+
8
+
9
+ class EntityDTest(unittest.TestCase):
10
+
11
+ @classmethod
12
+ def setUpClass(cls):
13
+ cls.entity_d = entity_d.EntityD(entity_id='demoid', entity_name='demo entity')
14
+
15
+ def test_proto_roundtrip(self):
16
+ proto = self.entity_d.to_proto()
17
+ domain = entity_d.EntityD.from_proto(proto)
18
+ self.assertEqual(self.entity_d.to_proto(), domain.to_proto())
19
+
20
+
21
+ class RelationshipDTest(unittest.TestCase):
22
+
23
+ @classmethod
24
+ def setUpClass(cls):
25
+ cls.relationship_d = entity_d.RelationshipD(
26
+ relationship_id="demoid",
27
+ start_date="2024-06-01",
28
+ end_date="2024-06-02",
29
+ source_text="source text",
30
+ predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)
31
+
32
+ def test_proto_roundtrip(self):
33
+ proto = self.relationship_d.to_proto()
34
+ domain = entity_d.RelationshipD.from_proto(proto)
35
+ self.assertEqual(self.relationship_d.to_proto(), domain.to_proto())
36
+
37
+ def test_end_date_after_start_date(self):
38
+ with self.assertRaises(ValueError):
39
+ _ = entity_d.RelationshipD(
40
+ relationship_id="demoid",
41
+ start_date="2024-06-01",
42
+ end_date="2024-05-02",
43
+ source_text="source text",
44
+ predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)
45
+
46
+
47
+ class EntityRelationshipDTest(unittest.TestCase):
48
+
49
+ @classmethod
50
+ def setUpClass(cls):
51
+ cls.from_entity_d = entity_d.EntityD(entity_id='from_id', entity_name='from entity')
52
+ cls.relationship_d = entity_d.RelationshipD(
53
+ relationship_id='relationship_id',
54
+ start_date='2024-06-01',
55
+ end_date='2024-06-02',
56
+ source_text='source text',
57
+ predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)
58
+ cls.to_entity_d = entity_d.EntityD(entity_id='to_id', entity_name='to entity')
59
+ cls.entity_relationship_d = entity_d.EntityRelationshipD(from_entity=cls.from_entity_d,
60
+ relationship=cls.relationship_d,
61
+ to_entity=cls.to_entity_d)
62
+
63
+ def test_proto_roundtrip(self):
64
+ proto = self.entity_relationship_d.to_proto()
65
+ domain = entity_d.EntityRelationshipD.from_proto(proto)
66
+ self.assertEqual(self.entity_relationship_d.to_proto(), domain.to_proto())
67
+
68
+
69
+ class EntityKnowledgeGraphDTest(unittest.TestCase):
70
+
71
+ @classmethod
72
+ def setUpClass(cls):
73
+ cls.from_entity_d = entity_d.EntityD(entity_id='from_id', entity_name='from entity')
74
+ cls.relationship_d = entity_d.RelationshipD(
75
+ relationship_id='relationship_id',
76
+ start_date='2024-06-01',
77
+ end_date='2024-06-02',
78
+ source_text='source text',
79
+ predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)
80
+ cls.to_entity_d = entity_d.EntityD(entity_id='to_id', entity_name='to entity')
81
+ cls.entity_relationship_d = entity_d.EntityRelationshipD(from_entity=cls.from_entity_d,
82
+ relationship=cls.relationship_d,
83
+ to_entity=cls.to_entity_d)
84
+
85
+ cls.entity_knowledge_graph_d = entity_d.EntityKnowledgeGraphD(
86
+ entity_relationships=[cls.entity_relationship_d for _ in range(2)])
87
+
88
+ def test_proto_roundtrip(self):
89
+ proto = self.entity_knowledge_graph_d.to_proto()
90
+ domain = entity_d.EntityKnowledgeGraphD.from_proto(proto)
91
+ self.assertEqual(self.entity_knowledge_graph_d.to_proto(), domain.to_proto())
92
+
93
+
94
+ if __name__ == '__main__':
95
+ logging.getLogger().setLevel(logging.INFO)
96
+ unittest.main()
extraction_pipeline/base_stage.py ADDED
@@ -0,0 +1,63 @@
1
+ from __future__ import annotations
2
+
3
+ import dataclasses
4
+ import logging
5
+ from typing import Iterable, Protocol, Set, TypeVar
6
+
7
+ from domain.domain_protocol import DomainProtocol
8
+
9
+
10
+ class StageError(Exception):
11
+ pass
12
+
13
+
14
+ StageInputType = TypeVar("StageInputType", bound=DomainProtocol, contravariant=True)
15
+ StageOutputType = TypeVar("StageOutputType", bound=DomainProtocol, covariant=True)
16
+
17
+
18
+ @dataclasses.dataclass(frozen=True)
19
+ class BaseStage(Protocol[StageInputType, StageOutputType]):
20
+
21
+ @property
22
+ def name(self) -> str:
23
+ return self.__class__.__name__
24
+
25
+ def _process_element(self, element: StageInputType) -> Iterable[StageOutputType]:
26
+ ...
27
+
28
+ def process_element(self, element: StageInputType) -> Iterable[StageOutputType]:
29
+ try:
30
+ logging.info(f"{self.name} processing element {element}")
31
+ yield from self._process_element(element)
32
+ except Exception as e:
33
+ logging.error(e)
34
+ raise StageError(f"Error processing element {element}") from e
35
+
36
+
37
+ class TransformError(Exception):
38
+ pass
39
+
40
+
41
+ TransformInputType = TypeVar("TransformInputType", bound=DomainProtocol, contravariant=True)
42
+ TransformOutputType = TypeVar("TransformOutputType", bound=DomainProtocol, covariant=True)
43
+
44
+
45
+ @dataclasses.dataclass(frozen=True)
46
+ class BaseTransform(Protocol[TransformInputType, TransformOutputType]):
47
+
48
+ @property
49
+ def name(self) -> str:
50
+ return self.__class__.__name__
51
+
52
+ def _process_collection(
53
+ self, collection: Iterable[TransformInputType]) -> Iterable[TransformOutputType]:
54
+ ...
55
+
56
+ def process_collection(
57
+ self, collection: Iterable[TransformInputType]) -> Iterable[TransformOutputType]:
58
+ try:
59
+ logging.info(f"{self.name} processing collection")
60
+ yield from self._process_collection(collection)
61
+ except Exception as e:
62
+ logging.error(e)
63
+ raise TransformError(f"Error processing collection") from e
extraction_pipeline/document_metadata_extractor/document_metadata_extractor.py ADDED
@@ -0,0 +1,9 @@
1
+ import dataclasses
2
+
3
+ from extraction_pipeline.base_stage import BaseStage
4
+ from domain.chunk_d import DocumentD
5
+
6
+
7
+ @dataclasses.dataclass(frozen=True)
8
+ class DocumentMetadataExtractor(BaseStage[DocumentD, DocumentD]):
9
+ ...
extraction_pipeline/document_metadata_extractor/openai_document_metadata_extractor.py ADDED
@@ -0,0 +1,73 @@
1
+ import dataclasses
2
+ import json
3
+ from typing import ClassVar, Dict, Iterable, Iterator, List, Optional, Tuple, Union
4
+
5
+ import pymupdf
6
+ from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
7
+
8
+ from domain.chunk_d import DocumentD
9
+ from extraction_pipeline.document_metadata_extractor.document_metadata_extractor import (
10
+ DocumentMetadataExtractor,)
11
+ from extraction_pipeline.document_metadata_extractor.prompts import (
12
+ DOCUMENT_METADATA_PROMPT,)
13
+ from llm_handler.llm_interface import LLMInterface
14
+ from llm_handler.openai_handler import ChatModelVersion, OpenAIHandler
15
+ from utils.dates import parse_date
16
+
17
+
18
+ class CreationDateError(Exception):
19
+ pass
20
+
21
+
22
+ class AuthorsError(Exception):
23
+ pass
24
+
25
+
26
+ class OpenAIDocumentMetadataExtractor(DocumentMetadataExtractor):
27
+
28
+ _handler: LLMInterface
29
+ _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
30
+ _AUTHORS_KEY: ClassVar[str] = "authors"
31
+ _PUBLISH_DATE_KEY: ClassVar[str] = "publish_date"
32
+ _TEMPARATURE: ClassVar[float] = 0.2
33
+
34
+ def __init__(self,
35
+ openai_handler: Optional[LLMInterface] = None,
36
+ model_version: Optional[ChatModelVersion] = None):
37
+ self._handler = openai_handler or OpenAIHandler()
38
+ self._model_version = model_version or self._MODEL_VERSION
39
+
40
+ def _validate_text(self, completion_text: Dict[str, Union[str, List[str]]]):
41
+ if not completion_text.get(self._AUTHORS_KEY):
42
+ raise AuthorsError("No authors found.")
43
+ if not completion_text.get(self._PUBLISH_DATE_KEY):
44
+ raise CreationDateError("No creation date found.")
45
+
46
+ publish_date_str: str = str(completion_text.get(self._PUBLISH_DATE_KEY, ""))
47
+ try:
48
+ parse_date(publish_date_str)
49
+ except ValueError as e:
50
+ raise CreationDateError(
51
+ f"Failed to parse publish date '{publish_date_str}': {e}") from e
52
+
53
+ def _process_element(self, element: DocumentD) -> Iterable[DocumentD]:
54
+ pdf_document_pages: Iterator = pymupdf.open(element.file_path).pages()
55
+ first_page_text: str = next(pdf_document_pages).get_text()
56
+ messages: List[ChatCompletionMessageParam] = [{
57
+ "role": "system", "content": DOCUMENT_METADATA_PROMPT
58
+ },
59
+ {
60
+ "role": "user",
61
+ "content": f"Input:\n{first_page_text}"
62
+ }]
63
+ completion_text_raw = self._handler.get_chat_completion(
64
+ messages=messages,
65
+ model=self._model_version,
66
+ temperature=self._TEMPARATURE,
67
+ response_format={"type": "json_object"})
68
+ completion_text: Dict[str, Union[str, List[str]]] = dict(json.loads(completion_text_raw))
69
+ self._validate_text(completion_text)
70
+ authors: str = ", ".join(completion_text.get(self._AUTHORS_KEY, []))
71
+ publish_date: str = str(completion_text.get(self._PUBLISH_DATE_KEY, ""))
72
+
73
+ yield dataclasses.replace(element, authors=authors, publish_date=publish_date)
extraction_pipeline/document_metadata_extractor/openai_document_metadata_extractor_test.py ADDED
@@ -0,0 +1,74 @@
1
+ import logging
2
+ import unittest
3
+ import uuid
4
+ from typing import List
5
+
6
+ from domain.chunk_d import DocumentD
7
+ from extraction_pipeline.document_metadata_extractor.openai_document_metadata_extractor import OpenAIDocumentMetadataExtractor, AuthorsError, CreationDateError
8
+ from llm_handler.mock_llm_handler import MockLLMHandler
9
+
10
+ DOCUMENT_METADATA_EXTRACTION_RESPONSE = '''
11
+ {
12
+ "authors": ["BofA Global Research", "Michael Hartnett", "Elyas Galou", "Anya Shelekhin", "Myung-Jee Jung"],
13
+ "publish_date": "2023-04-13"
14
+ }
15
+ '''
16
+
17
+
18
+ class TestOpenAIDocumentMetadataExtractor(unittest.TestCase):
19
+
20
+ @classmethod
21
+ def setUpClass(cls) -> None:
22
+ cls.test_pdf_path = "extraction_pipeline/test_data/test.pdf"
23
+ cls.start_document_d = DocumentD(file_path=cls.test_pdf_path, authors="", publish_date="")
24
+ cls.final_document_d = DocumentD(
25
+ file_path=cls.test_pdf_path,
26
+ authors=
27
+ "BofA Global Research, Michael Hartnett, Elyas Galou, Anya Shelekhin, Myung-Jee Jung",
28
+ publish_date="2023-04-13")
29
+ cls.openai_publish_details_extractor = OpenAIDocumentMetadataExtractor()
30
+
31
+ def test__validate_text_missing_publishers(self):
32
+ missing_publishers_text = {"publish_date": "2023-12-13"}
33
+ with self.assertRaises(AuthorsError):
34
+ self.openai_publish_details_extractor._validate_text(
35
+ missing_publishers_text) # type: ignore
36
+
37
+ def test__validate_text_invalid_date(self):
38
+ invalid_date_text = {
39
+ "authors": [
40
+ "BofA Global Research",
41
+ "Michael Hartnett",
42
+ "Elyas Galou",
43
+ "Anya Shelekhin",
44
+ "Myung-Jee Jung"
45
+ ],
46
+ "publish_date": "2-13"
47
+ }
48
+ with self.assertRaises(CreationDateError):
49
+ self.openai_publish_details_extractor._validate_text(invalid_date_text) # type: ignore
50
+
51
+ def test__validate_text_valid(self):
52
+ valid_text = {
53
+ "authors": [
54
+ "BofA Global Research",
55
+ "Michael Hartnett",
56
+ "Elyas Galou",
57
+ "Anya Shelekhin",
58
+ "Myung-Jee Jung"
59
+ ],
60
+ "publish_date": "2023-04-13"
61
+ }
62
+ self.openai_publish_details_extractor._validate_text(valid_text) # type: ignore
63
+
64
+ def test__process_element(self):
65
+ handler = MockLLMHandler(chat_completion=[DOCUMENT_METADATA_EXTRACTION_RESPONSE])
66
+ openai_publish_details_extractor = OpenAIDocumentMetadataExtractor(handler)
67
+ pdf_document_d = self.start_document_d
68
+ output = list(openai_publish_details_extractor._process_element(pdf_document_d))
69
+ self.assertEqual(output[0], self.final_document_d)
70
+
71
+
72
+ if __name__ == '__main__':
73
+ logging.basicConfig(level=logging.INFO)
74
+ unittest.main()
extraction_pipeline/document_metadata_extractor/prompts.py ADDED
@@ -0,0 +1,22 @@
1
+ DOCUMENT_METADATA_PROMPT = '''
2
+ You are a financial reports expert, tasked with extracting the authors and publishing date
3
+ from the first page of a financial report document. Given the first page of text from a financial report,
4
+ you must extract and return the list of authors (which may be only one) and the date of
5
+ publication in a specific string format of 'YYYY-MM-DD' in the form of a JSON object. The JSON object
6
+ you return should have the list of authors as the value to the "authors" key.
7
+ The specifically formatted datetime string should be the value to the "publish_date" key.
8
+ Note, I will be parsing out the datetime string using the python function `datetime.strptime(your_output['publish_date'], '%Y-%m-%d')`.
9
+ Do not include anything besides the JSON object in your response.
10
+ I will be parsing your response using the python code `json.loads(your_output)`.
11
+
12
+ Example input:
13
+ "Trading ideas and investment strategies discussed herein may give rise to significant risk and are \nnot suitable for all investors. Investors should have experience in relevant markets and the financial \nresources to absorb any losses arising from applying these ideas or strategies. \n>> Employed by a non-US affiliate of BofAS and is not registered/qualified as a research analyst \nunder the FINRA rules. \nRefer to "Other Important Disclosures" for information on certain BofA Securities entities that take \nresponsibility for the information herein in particular jurisdictions. \nBofA Securities does and seeks to do business with issuers covered in its research \nreports. As a result, investors should be aware that the firm may have a conflict of \ninterest that could affect the objectivity of this report. Investors should consider this \nreport as only a single factor in making their investment decision. \nRefer to important disclosures on page 11 to 13. \n12544294 \n \nThe Flow Show \n \nSugar Coated Iceberg \n \n \nScores on the Doors: crypto 65.3%, gold 10.1%, stocks 7.7%, HY bonds 4.3%, IG bonds \n4.3%, govt bonds 3.4%, oil 3.7%, cash 1.2%, commodities -0.1%, US dollar -2.0% YTD. \nThe Biggest Picture: everyone's new favorite theme…US dollar debasement; US$ -11% \nsince Sept, gold >$2k, bitcoin >$30k; right secular theme (deficits, debt, geopolitics), \nUS$ in 4th bear market of past 50 years, bullish gold, oil, Euro, international stocks; but \npessimism so high right now, and 200bps Fed cuts priced in from June to election, look \nfor US$ trading bounce after Fed ends hiking cycle May 3rd (Chart 2). \nTale of the Tape: inflation slowing, Fed hiking cycle over, recession expectations \nuniversal, yet UST 30-year can't break below 3.6% (200-dma) because we've already \ntraded 8% to 4% CPI, labor market yet to crack, US govt deficit growing too quickly. \nThe Price is Right: next 6 months it's "recession vs Fed cuts"; best tells who's \nwinning…HY bonds, homebuilders, semiconductors…HYG <73, XHB <70, SOX <2900 \nrecessionary, and if levels hold…it's a no/soft landing. \nChart 2: US dollar has begun 4th bear market of past 50 years \nUS dollar index (DXY) \n \nSource: BofA Global Investment Strategy, Bloomberg. 
\nBofA GLOBAL RESEARCH \nMore on page 2… \n \n \n \n \n \n \nGold ↑\nYen & Deutsche \nMark ↑\nBRICS & Euro ↑\nGold, \nEuro ↑\n70\n80\n90\n100\n110\n120\n130\n140\n150\n160\n1967\n1973\n1979\n1985\n1991\n1997\n2003\n2009\n2015\n2021\nUS dollar index (DXY)\n13 April 2023 Corrected \n \n \n \n \nInvestment Strategy \nGlobal \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nMichael Hartnett \nInvestment Strategist \nBofAS \n+1 646 855 1508 \nmichael.hartnett@bofa.com \n \nElyas Galou >> \nInvestment Strategist \nBofASE (France) \n+33 1 8770 0087 \nelyas.galou@bofa.com \n \nAnya Shelekhin \nInvestment Strategist \nBofAS \n+1 646 855 3753 \nanya.shelekhin@bofa.com \n \nMyung-Jee Jung \nInvestment Strategist \nBofAS \n+1 646 855 0389 \nmyung-jee.jung@bofa.com \n \n \n \n \n \n \n \n \n \nChart 1: BofA Bull & Bear Indicator \nStays at 2.3 \n \nSource: BofA Global Investment Strategy \nThe indicator identified above as the BofA Bull & Bear \nIndicator is intended to be an indicative metric only and \nmay not be used for reference purposes or as a measure \nof performance for any financial instrument or contract, \nor otherwise relied upon by third parties for any other \npurpose, without the prior written consent of BofA \nGlobal Research. This indicator was not created to act as \na benchmark. \nBofA GLOBAL RESEARCH \n \n \nExtreme \nBearish\nExtreme \nBullish\n4\n6\n10\n0\n8\n2\nBuy\nSell\n2.3\nAccessible version \n \n \nTimestamp: 13 April 2023 08:19PM EDT\nW \n"
14
+
15
+ Example output:
16
+ "
17
+ {
18
+ "authors": ["BofA Global Research", "Michael Hartnett", "Elyas Galou", "Anya Shelekhin", "Myung-Jee Jung"],
19
+ "publish_date": "2023-04-13"
20
+ }
21
+ "
22
+ '''
extraction_pipeline/pdf_process_stage.py ADDED
@@ -0,0 +1,131 @@
1
+ from __future__ import annotations
2
+ import os
3
+ from typing import Iterable
4
+ import pymupdf
5
+ import re
6
+
7
+ from domain.domain_protocol import DomainProtocol
8
+ from domain.chunk_d import DocumentD, ChunkD
9
+ import proto.chunk_pb2 as chunk_pb2
10
+ from extraction_pipeline.base_stage import BaseStage, BaseTransform
11
+ from storage.domain_dao import CacheDomainDAO
12
+
13
+
14
+ class IDTagger():
15
+
16
+ _uuid_to_tag: CacheDomainDAO[ChunkD]
17
+
18
+ def __init__(self, cache_save_path: str):
19
+ self._uuid_to_tag = CacheDomainDAO(f"{cache_save_path}.json", ChunkD)
20
+
21
+ def __call__(self, element: ChunkD) -> ChunkD:
22
+ self._uuid_to_tag.insert([element])
23
+ return element
24
+
25
+
26
+ class PdfPageChunkerStage(BaseStage[DocumentD, ChunkD]):
27
+
28
+ def __init__(self):
29
+ self._pdf_cache_dao = CacheDomainDAO("pdf_chunk_id_map.json", DocumentD)
30
+ self.id_tagger = CacheDomainDAO(f"{str.lower((self.__class__.__name__))}_chunk_id_map.json",
31
+ ChunkD)
32
+
33
+ def _process_element(self, element: DocumentD) -> Iterable[ChunkD]:
34
+ pdf_document: Iterable = pymupdf.open(element.file_path) # type: ignore
35
+ for page_index, pdf_page in enumerate(pdf_document):
36
+ page_text: str = pdf_page.get_textpage().extractText() # type: ignore
37
+ yield self.id_tagger(
38
+ ChunkD(parent_reference=element,
39
+ chunk_text=page_text,
40
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
41
+ chunk_index=page_index,
42
+ chunk_id=self._pdf_cache_dao.set(element)))
43
+
44
+
45
+ class ParagraphChunkerStage(BaseStage[ChunkD, ChunkD]):
46
+
47
+ def __init__(self):
48
+ self.id_tagger = CacheDomainDAO(f"{str.lower((self.__class__.__name__))}_chunk_id_map.json",
49
+ ChunkD)
50
+
51
+ def _process_element(self, element: ChunkD) -> Iterable[ChunkD]:
52
+ paragraphs = re.split(r'\n+', element.chunk_text)
53
+ paragraphs = [para.strip() for para in paragraphs if para.strip()]
54
+ for chunk_index, paragraph in enumerate(paragraphs):
55
+ yield self.id_tagger(
56
+ ChunkD(parent_reference=element.chunk_id,
57
+ chunk_text=paragraph,
58
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PARAGRAPH,
59
+ chunk_index=chunk_index))
60
+
61
+
62
+ class SentenceChunkerStage(BaseStage[ChunkD, ChunkD]):
63
+
64
+ def __init__(self):
65
+ self.id_tagger = CacheDomainDAO(f"{str.lower((self.__class__.__name__))}_chunk_id_map.json",
66
+ ChunkD)
67
+
68
+ def _process_element(self, element: ChunkD) -> Iterable[ChunkD]:
69
+ sentence_endings = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)(\s|$)'
70
+ sentences = re.split(sentence_endings, element.chunk_text)
71
+ sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
72
+ for chunk_index, sentence in enumerate(sentences):
73
+ yield self.id_tagger(
74
+ ChunkD(parent_reference=element.chunk_id,
75
+ chunk_text=sentence,
76
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
77
+ chunk_index=chunk_index))
78
+
79
+
80
+ class PdfToSentencesTransform(BaseTransform[DocumentD, ChunkD]):
81
+
82
+ _pdf_page_chunker: PdfPageChunkerStage
83
+ _paragraph_chunker: ParagraphChunkerStage
84
+ _sentence_chunker: SentenceChunkerStage
85
+
86
+ def __init__(self):
87
+ self._pdf_page_chunker = PdfPageChunkerStage()
88
+ self._paragraph_chunker = ParagraphChunkerStage()
89
+ self._sentence_chunker = SentenceChunkerStage()
90
+
91
+ def _process_collection(self, collection: Iterable[DocumentD]) -> Iterable[ChunkD]:
92
+ for pdf_document in collection:
93
+ pdf_pages = self._pdf_page_chunker.process_element(pdf_document)
94
+ for pdf_page in pdf_pages:
95
+ paragraphs = self._paragraph_chunker.process_element(pdf_page)
96
+ for paragraph in paragraphs:
97
+ sentences = self._sentence_chunker.process_element(paragraph)
98
+ yield from sentences
99
+
100
+
101
+ class PdfToParagraphTransform(BaseTransform[DocumentD, ChunkD]):
102
+
103
+ _pdf_page_chunker: PdfPageChunkerStage
104
+ _paragraph_chunker: ParagraphChunkerStage
105
+
106
+ def __init__(self):
107
+ self._pdf_page_chunker = PdfPageChunkerStage()
108
+ self._paragraph_chunker = ParagraphChunkerStage()
109
+
110
+ def _process_collection(self, collection: Iterable[DocumentD]) -> Iterable[ChunkD]:
111
+ for pdf_document in collection:
112
+ pdf_pages = self._pdf_page_chunker.process_element(pdf_document)
113
+ for pdf_page in pdf_pages:
114
+ paragraphs = self._paragraph_chunker.process_element(pdf_page)
115
+ for paragraph in paragraphs:
116
+ yield paragraph
117
+
118
+
119
+ class PdfToPageTransform(BaseTransform[DocumentD, ChunkD]):
120
+
121
+ _pdf_page_chunker: PdfPageChunkerStage
122
+
123
+ def __init__(self):
124
+ self._pdf_page_chunker = PdfPageChunkerStage()
125
+ self._paragraph_chunker = ParagraphChunkerStage()
126
+
127
+ def _process_collection(self, collection: Iterable[DocumentD]) -> Iterable[ChunkD]:
128
+ for pdf_document in collection:
129
+ pdf_pages = self._pdf_page_chunker.process_element(pdf_document)
130
+ for pdf_page in pdf_pages:
131
+ yield pdf_page
extraction_pipeline/pdf_process_stage_test.py ADDED
@@ -0,0 +1,103 @@
1
+ import unittest
2
+ import logging
3
+ import uuid
4
+
5
+ from domain.chunk_d import DocumentD, ChunkD
6
+ import proto.chunk_pb2 as chunk_pb2
7
+ from extraction_pipeline.pdf_process_stage import PdfPageChunkerStage, ParagraphChunkerStage, SentenceChunkerStage, PdfToSentencesTransform
8
+
9
+
10
+ def compare_chunk_without_id(chunk1: ChunkD, chunk2: ChunkD):
11
+ assert chunk1.parent_reference == chunk2.parent_reference
12
+ assert chunk1.chunk_text == chunk2.chunk_text
13
+ assert chunk1.chunk_type == chunk2.chunk_type
14
+ assert chunk1.chunk_index == chunk2.chunk_index
15
+
16
+
17
+ class PdfPageChunkerStageTest(unittest.TestCase):
18
+
19
+ @classmethod
20
+ def setUpClass(cls):
21
+ cls.test_pdf_path = "extraction_pipeline/test_data/test.pdf"
22
+ cls.expected_pdf_text = """This research presents a thorough examination of Large Language Models (LLMs) in the
23
+ domain of code re-engineering.
24
+ Our focus on the methodical development and utilization of robust pipelines highlights both the
25
+ potential and the current challenges of employing LLMs for these complex tasks.
26
+ """
27
+ cls.document_d = DocumentD(file_path="extraction_pipeline/test_data/test.pdf",
28
+ authors="BofA John Doe Herb Johnson Taylor Mason",
29
+ publish_date="2021-01-01")
30
+
31
+ def test__process_element(self):
32
+ chunks = list(PdfPageChunkerStage()._process_element(self.document_d))
33
+ self.assertEqual(len(chunks), 1)
34
+ expected_chunk = ChunkD(parent_reference=self.document_d,
35
+ chunk_text=self.expected_pdf_text,
36
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
37
+ chunk_index=0)
38
+ compare_chunk_without_id(chunks[0], expected_chunk)
39
+
40
+
41
+ class ParagraphChunkerStageTest(unittest.TestCase):
42
+
43
+ @classmethod
44
+ def setUpClass(cls):
45
+ cls.document_d = DocumentD(file_path="extraction_pipeline/test_data/test.pdf",
46
+ authors="BofA John Doe Herb Johnson Taylor Mason",
47
+ publish_date="2021-01-01")
48
+
49
+ def test__process_element(self):
50
+ chunk = ChunkD(parent_reference=self.document_d,
51
+ chunk_text="paragraph1\nparagraph2\n\nparagraph3",
52
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
53
+ chunk_index=0)
54
+ chunks = list(ParagraphChunkerStage()._process_element(chunk))
55
+ expected_chunks = [
56
+ ChunkD(parent_reference=chunk.chunk_id,
57
+ chunk_text="paragraph1",
58
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PARAGRAPH,
59
+ chunk_index=0),
60
+ ChunkD(parent_reference=chunk.chunk_id,
61
+ chunk_text="paragraph2",
62
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PARAGRAPH,
63
+ chunk_index=1),
64
+ ChunkD(parent_reference=chunk.chunk_id,
65
+ chunk_text="paragraph3",
66
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PARAGRAPH,
67
+ chunk_index=2)
68
+ ]
69
+ self.assertEqual(len(chunks), len(expected_chunks))
70
+ for chunk, expected_chunk in zip(chunks, expected_chunks):
71
+ compare_chunk_without_id(chunk, expected_chunk)
72
+
73
+
74
+ class SentenceChunkerStageTest(unittest.TestCase):
75
+
76
+ def test__process_element(self):
77
+ chunk = ChunkD(parent_reference=uuid.uuid4(),
78
+ chunk_text="sentence1. sentence2! sentence3?",
79
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PARAGRAPH,
80
+ chunk_index=0)
81
+ chunks = list(SentenceChunkerStage()._process_element(chunk))
82
+ expected_chunks = [
83
+ ChunkD(parent_reference=chunk.chunk_id,
84
+ chunk_text="sentence1.",
85
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
86
+ chunk_index=0),
87
+ ChunkD(parent_reference=chunk.chunk_id,
88
+ chunk_text="sentence2!",
89
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
90
+ chunk_index=1),
91
+ ChunkD(parent_reference=chunk.chunk_id,
92
+ chunk_text="sentence3?",
93
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
94
+ chunk_index=2)
95
+ ]
96
+ self.assertEqual(len(chunks), len(expected_chunks))
97
+ for chunk, expected_chunk in zip(chunks, expected_chunks):
98
+ compare_chunk_without_id(chunk, expected_chunk)
99
+
100
+
101
+ if __name__ == '__main__':
102
+ logging.basicConfig(level=logging.INFO)
103
+ unittest.main()
extraction_pipeline/pdf_to_knowledge_graph_transform.py ADDED
@@ -0,0 +1,153 @@
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import sys
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+ from pathlib import Path
7
+ from typing import Iterable, Optional
8
+
9
+ from domain.chunk_d import ChunkD, DocumentD
10
+ from domain.entity_d import EntityKnowledgeGraphD
11
+ from extraction_pipeline.base_stage import BaseTransform
12
+ from extraction_pipeline.document_metadata_extractor.openai_document_metadata_extractor import (
13
+ OpenAIDocumentMetadataExtractor,)
14
+ from extraction_pipeline.pdf_process_stage import (
15
+ PdfToPageTransform,
16
+ PdfToParagraphTransform,
17
+ PdfToSentencesTransform,
18
+ )
19
+ from extraction_pipeline.relationship_extractor.entity_relationship_extractor import (
20
+ RelationshipExtractor,)
21
+ from extraction_pipeline.relationship_extractor.openai_relationship_extractor import (
22
+ OpenAIRelationshipExtractor,)
23
+ from llm_handler.openai_handler import OpenAIHandler
24
+ from storage.domain_dao import InMemDomainDAO
25
+ from storage.neo4j_dao import Neo4jDomainDAO
26
+
27
+
28
+ class PdfToKnowledgeGraphTransform(BaseTransform[DocumentD, EntityKnowledgeGraphD]):
29
+ _metadata_extractor: OpenAIDocumentMetadataExtractor
30
+ _pdf_chunker: BaseTransform[DocumentD, ChunkD]
31
+ _relationship_extractor: RelationshipExtractor
32
+
33
+ def __init__(self, pdf_chunker: BaseTransform[DocumentD, ChunkD]):
34
+ openai_handler = OpenAIHandler()
35
+ self._metadata_extractor = OpenAIDocumentMetadataExtractor(openai_handler=openai_handler)
36
+ self._pdf_chunker = pdf_chunker
37
+ self._relationship_extractor = OpenAIRelationshipExtractor(openai_handler=openai_handler)
38
+
39
+ def _process_collection(self,
40
+ collection: Iterable[DocumentD]) -> Iterable[EntityKnowledgeGraphD]:
41
+ # produce 1 EntityKnowledgeGraphD per DocumentD
42
+ for pdf_document in collection:
43
+ # metadata extractor only yields 1 filled in DocumentD for each input DocumentD
44
+ document = next(iter(self._metadata_extractor.process_element(pdf_document)))
45
+ entity_relationships = []
46
+ for pdf_chunk in self._pdf_chunker.process_collection([document]):
47
+ for relationship in self._relationship_extractor.process_element(pdf_chunk):
48
+ entity_relationships.append(relationship)
49
+ yield EntityKnowledgeGraphD(entity_relationships=entity_relationships)
50
+
51
+
52
+ if __name__ == '__main__':
53
+ ## CLI Arguments for running transform as multi-threaded script
54
+ parser = argparse.ArgumentParser(description='Extract knowledge graphs from PDF files')
55
+ parser.add_argument('--pdf_folder',
56
+ type=str,
57
+ help='Path to folder of PDF files to process',
58
+ default='')
59
+ parser.add_argument('--pdf_file',
60
+ type=str,
61
+ help='Path to the one PDF file to process',
62
+ default='')
63
+ parser.add_argument('--output_json_file',
64
+ type=str,
65
+ help='Path for output json file of knowledge graphs',
66
+ default='./knowledge_graphs.json')
67
+ parser.add_argument('--log_folder', type=str, help='Path to log folder', default='log')
68
+ parser.add_argument('--chunk_to',
69
+ type=str,
70
+ help='What level to chunk PDF text',
71
+ default='page',
72
+ choices=['page', 'paragraph', 'sentence'])
73
+ parser.add_argument('--verbose', help='Enable DEBUG level logs', action='store_true')
74
+ parser.add_argument('--upload_to_neo4j',
75
+ help='Enable uploads to Neo4j database',
76
+ action='store_true')
77
+ args = parser.parse_args()
78
+
79
+ ## Setup logging
80
+ if args.verbose:
81
+ log_level = logging.DEBUG
82
+ else:
83
+ log_level = logging.INFO
84
+ os.makedirs(args.log_folder, exist_ok=True)
85
+ script_name = os.path.splitext(os.path.basename(__file__))[0]
86
+ logging.basicConfig(level=log_level,
87
+ format='%(asctime)s - %(levelname)s - %(message)s',
88
+ filename=f'{args.log_folder}/{script_name}.log',
89
+ filemode='w')
90
+ logger = logging.getLogger(__name__)
91
+
92
+ ## Setup PDF Chunking
93
+ if args.chunk_to == 'page':
94
+ pdf_transform = PdfToPageTransform()
95
+ elif args.chunk_to == 'paragraph':
96
+ pdf_transform = PdfToParagraphTransform()
97
+ elif args.chunk_to == 'sentence':
98
+ pdf_transform = PdfToSentencesTransform()
99
+ else:
100
+ logging.error('Invalid chunking level: %s', args.chunk_to)
101
+ sys.exit(1)
102
+
103
+ ## Process PDF Files
104
+ if not args.pdf_folder and not args.pdf_file:
105
+ logging.error('No PDF file or folder provided')
106
+ sys.exit(1)
107
+ elif args.pdf_folder:
108
+ pdf_folder = Path(args.pdf_folder)
109
+ if not pdf_folder.exists():
110
+ logging.error('PDF folder does not exist: %s', pdf_folder)
111
+ sys.exit(1)
112
+ pdf_files = list(pdf_folder.glob('*.pdf'))
113
+ if len(pdf_files) == 0:
114
+ logging.warning('No PDF files found in folder: %s', pdf_folder)
115
+ sys.exit(0)
116
+ pdfs = [
117
+ DocumentD(file_path=str(pdf_file), authors='', publish_date='')
118
+ for pdf_file in pdf_files
119
+ ]
120
+ else:
121
+ pdf_file = Path(args.pdf_file)
122
+ if not pdf_file.exists():
123
+ logging.error('PDF file does not exist: %s', pdf_file)
124
+ sys.exit(1)
125
+ pdfs = [DocumentD(file_path=str(pdf_file), authors='', publish_date='')]
126
+
127
+ pdf_to_kg = PdfToKnowledgeGraphTransform(pdf_transform)
128
+
129
+ def process_pdf(pdf: DocumentD) -> tuple[Optional[EntityKnowledgeGraphD], str]:
130
+ pdf_name = Path(pdf.file_path).name
131
+ try:
132
+ # process collection yields 1 KG per pdf but we are only
133
+ # inputing 1 PDF at a time so we just need the 1st element
134
+ return list(pdf_to_kg.process_collection([pdf]))[0], pdf_name
135
+ except Exception as e:
136
+ logging.error(f"Error processing pdf: {e}")
137
+ return None, pdf_name
138
+
139
+ results: list[EntityKnowledgeGraphD] = []
140
+ with ThreadPoolExecutor() as executor, Neo4jDomainDAO() as dao:
141
+ futures = [executor.submit(process_pdf, pdf) for pdf in pdfs]
142
+
143
+ for future in as_completed(futures):
144
+ kg, pdf_name = future.result()
145
+ if not kg:
146
+ continue
147
+ results.append(kg)
148
+ if args.upload_to_neo4j:
149
+ dao.insert(kg, pdf_name)
150
+
151
+ dao = InMemDomainDAO()
152
+ dao.insert(results)
153
+ dao.save_to_file(args.output_json_file)
extraction_pipeline/relationship_extractor/entity_relationship_extractor.py ADDED
@@ -0,0 +1,10 @@
1
+ import dataclasses
2
+
3
+ from extraction_pipeline.base_stage import BaseStage
4
+ from domain.chunk_d import ChunkD
5
+ from domain.entity_d import EntityRelationshipD
6
+
7
+
8
+ @dataclasses.dataclass(frozen=True)
9
+ class RelationshipExtractor(BaseStage[ChunkD, EntityRelationshipD]):
10
+ ...
extraction_pipeline/relationship_extractor/openai_relationship_extractor.py ADDED
@@ -0,0 +1,108 @@
1
+ import json
2
+ import logging
3
+ import uuid
4
+ from typing import ClassVar, Dict, Iterator, List, Optional
5
+
6
+ from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
7
+
8
+ from domain.chunk_d import (ChunkD, DocumentD)
9
+ from domain.entity_d import (
10
+ EntityD,
11
+ EntityRelationshipD,
12
+ RelationshipD,
13
+ )
14
+ from extraction_pipeline.relationship_extractor.entity_relationship_extractor import (
15
+ RelationshipExtractor,)
16
+ from extraction_pipeline.relationship_extractor.prompts import (
17
+ EXTRACT_RELATIONSHIPS_PROMPT,
18
+ NER_TAGGING_PROMPT,
19
+ )
20
+ from llm_handler.llm_interface import LLMInterface
21
+ from llm_handler.openai_handler import ChatModelVersion, OpenAIHandler
22
+
23
+
24
+ class OpenAIRelationshipExtractor(RelationshipExtractor):
25
+
26
+ _handler: LLMInterface
27
+ _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
28
+ _RELATIONSHIP_KEY: ClassVar[str] = "relationships"
29
+ _ENTITIES_KEY: ClassVar[str] = "entities"
30
+ _RELATIONSHIPS_TYPES: ClassVar[List[str]] = ["PREDICTION"]
31
+ _TEMPERATURE: ClassVar[float] = 0.2
32
+
33
+ def __init__(self,
34
+ openai_handler: Optional[LLMInterface] = None,
35
+ model_version: Optional[ChatModelVersion] = None):
36
+ self._handler = openai_handler or OpenAIHandler()
37
+ self._model_version = model_version or self._MODEL_VERSION
38
+
39
+ def _extract_entity_names(self, chunk_text: str) -> List[Dict[str, str]]:
40
+
41
+ messages: List[ChatCompletionMessageParam] = [{
42
+ "role": "system", "content": NER_TAGGING_PROMPT
43
+ },
44
+ {
45
+ "role": "user",
46
+ "content": f"Input:\n{chunk_text}"
47
+ }]
48
+ completion_text = self._handler.get_chat_completion(messages=messages,
49
+ model=self._model_version,
50
+ temperature=self._TEMPERATURE,
51
+ response_format={"type": "json_object"})
52
+ logging.info(f"entity extraction results: {completion_text}")
53
+ return dict(json.loads(completion_text)).get(self._ENTITIES_KEY, [])
54
+
55
+ def _extract_relationships(self, chunk: ChunkD,
56
+ entity_nodes: List[Dict[str, str]]) -> Iterator[EntityRelationshipD]:
57
+ if isinstance(chunk.parent_reference, DocumentD):
58
+ analyst_names: str = chunk.parent_reference.authors
59
+ document_datetime: str = chunk.parent_reference.publish_date
60
+ else:
61
+ raise NotImplementedError("Parent reference is not a DocumentD")
62
+ messages: List[ChatCompletionMessageParam] = [
63
+ {
64
+ "role": "system", "content": EXTRACT_RELATIONSHIPS_PROMPT
65
+ },
66
+ {
67
+ "role":
68
+ "user",
69
+ "content":
70
+ f"Analyst: {analyst_names} \n Date: {document_datetime} \n Text Chunk: {chunk.chunk_text} \n {str(entity_nodes)}"
71
+ }
72
+ ]
73
+
74
+ completion_text = self._handler.get_chat_completion(messages=messages,
75
+ model=self._model_version,
76
+ temperature=self._TEMPERATURE,
77
+ response_format={"type": "json_object"})
78
+ logging.info(f"relationship results: {completion_text}")
79
+ completion_text = dict(json.loads(completion_text))
80
+ relationships: List[Dict[str, Dict[str,
81
+ str]]] = completion_text.get(self._RELATIONSHIP_KEY, [])
82
+ for extracted_relationship in relationships:
83
+ key: str = list(extracted_relationship.keys())[0]
84
+ if key in self._RELATIONSHIPS_TYPES:
85
+ relationship_kw_attr: Dict[str, str] = extracted_relationship[key]
86
+ relationship_d = RelationshipD(
87
+ relationship_id=str(uuid.uuid4()),
88
+ start_date=relationship_kw_attr.get("start_date", ""),
89
+ end_date=relationship_kw_attr.get("end_date", ""),
90
+ source_text=chunk.chunk_text,
91
+ predicted_movement=RelationshipD.from_string(
92
+ relationship_kw_attr.get("predicted_movement", "")))
93
+ else:
94
+ raise ValueError(f"No valid relationships in {extracted_relationship}")
95
+
96
+ entity_l_d = EntityD(entity_id=str(uuid.uuid4()),
97
+ entity_name=relationship_kw_attr.get("from_entity", ""))
98
+
99
+ entity_r_d = EntityD(entity_id=str(uuid.uuid4()),
100
+ entity_name=relationship_kw_attr.get("to_entity", ""))
101
+
102
+ yield EntityRelationshipD(relationship=relationship_d,
103
+ from_entity=entity_l_d,
104
+ to_entity=entity_r_d)
105
+
106
+ def _process_element(self, element: ChunkD) -> Iterator[EntityRelationshipD]:
107
+ entities_text: List[Dict[str, str]] = self._extract_entity_names(element.chunk_text)
108
+ yield from self._extract_relationships(element, entities_text)
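Aside: a rough usage sketch for the extractor on a single chunk; the unit test below exercises the same path with a mocked handler, so field names here follow that test and the document values are hypothetical.

import proto.chunk_pb2 as chunk_pb2
from domain.chunk_d import ChunkD, DocumentD
from extraction_pipeline.relationship_extractor.openai_relationship_extractor import (
    OpenAIRelationshipExtractor,)

# Sketch only: run NER + relationship extraction over one page-level chunk.
doc = DocumentD(file_path='report.pdf', authors='Jane Analyst', publish_date='2024-06-12')
chunk = ChunkD(chunk_text='We expect oil to rally into year-end.',
               chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
               chunk_index=0,
               parent_reference=doc)
extractor = OpenAIRelationshipExtractor()  # builds an OpenAIHandler from OPENAI_API_KEY
for er in extractor._process_element(chunk):
    print(er.from_entity.entity_name, '->', er.to_entity.entity_name)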
extraction_pipeline/relationship_extractor/openai_relationship_extractor_test.py ADDED
@@ -0,0 +1,87 @@
1
+ import logging
2
+ import unittest
3
+ import uuid
4
+
5
+ import proto.chunk_pb2 as chunk_pb2
6
+
7
+ from domain.chunk_d import ChunkD, DocumentD
8
+ from domain.entity_d import RelationshipD
9
+ from extraction_pipeline.relationship_extractor.openai_relationship_extractor import (
10
+ OpenAIRelationshipExtractor,)
11
+ from llm_handler.mock_llm_handler import MockLLMHandler
12
+
13
+ ENTITY_EXTRACTION_RESPONSE = '''
14
+ {
15
+ "entities": [
16
+ {"name": "entity 1", "type": "Financial Securities"},
17
+ {"name": "entity 2", "type": "Financial Securities"}
18
+ ]
19
+ }
20
+ '''
21
+
22
+ RELATIONSHIP_EXTRACTION_RESPONSE = '''
23
+ {
24
+ "relationships": [
25
+ {
26
+ "PREDICTION": {
27
+ "from_entity": "John Doe",
28
+ "to_entity": "entity 1",
29
+ "start_date": "2024-06-12",
30
+ "end_date": "2024-11-05",
31
+ "predicted_movement": "PREDICTED_MOVEMENT_DECREASE",
32
+ "description": "200bps Fed cuts priced in from June to election, look for US$ trading bounce after Fed ends hiking cycle May 3rd."
33
+ }
34
+ },
35
+ {
36
+ "PREDICTION": {
37
+ "from_entity": "John Doe",
38
+ "to_entity": "entity 2",
39
+ "start_date": "2024-06-12",
40
+ "end_date": "2024-11-05",
41
+ "predicted_movement": "PREDICTED_MOVEMENT_DECREASE",
42
+ "description": "200bps Fed cuts priced in from June to election, look for US$ trading bounce after Fed ends hiking cycle May 3rd."
43
+ }
44
+ }
45
+ ]
46
+ }
47
+ '''
48
+
49
+
50
+ class TestOpenAIRelationshipExtractor(unittest.TestCase):
51
+
52
+ def setUp(self) -> None:
53
+ self.document_d = DocumentD(file_path="Morgan_Stanley_Research_2024.pdf",
54
+ authors='John Doe',
55
+ publish_date='2024-06-12')
56
+ self.chunkd = ChunkD(chunk_text='entity 1 relates to entity 2.',
57
+ chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
58
+ chunk_index=1,
59
+ parent_reference=self.document_d)
60
+
61
+ def test__process_elements(self):
62
+ chat_responses = [ENTITY_EXTRACTION_RESPONSE, RELATIONSHIP_EXTRACTION_RESPONSE]
63
+ handler = MockLLMHandler(chat_completion=chat_responses)
64
+ openai_relationship_extractor = OpenAIRelationshipExtractor(handler)
65
+ output = list(openai_relationship_extractor._process_element(self.chunkd))
66
+ extracted_relationship1 = output[0]
67
+ extracted_relationship2 = output[1]
68
+ self.assertEqual(len(output), 2)
69
+ self.assertEqual(extracted_relationship1.from_entity.entity_name, 'John Doe')
70
+ self.assertEqual(extracted_relationship1.to_entity.entity_name, 'entity 1')
71
+ self.assertEqual(extracted_relationship1.relationship.start_date, '2024-06-12')
72
+ self.assertEqual(extracted_relationship1.relationship.end_date, '2024-11-05')
73
+ self.assertEqual(extracted_relationship1.relationship.predicted_movement, 3)
74
+ self.assertEqual(extracted_relationship1.relationship.source_text,
75
+ 'entity 1 relates to entity 2.')
76
+ self.assertEqual(extracted_relationship2.from_entity.entity_name, 'John Doe')
77
+ self.assertEqual(extracted_relationship2.to_entity.entity_name, 'entity 2')
78
+ self.assertEqual(extracted_relationship2.relationship.start_date, '2024-06-12')
79
+ self.assertEqual(extracted_relationship2.relationship.end_date, '2024-11-05')
80
+ self.assertEqual(extracted_relationship2.relationship.predicted_movement, 3)
81
+ self.assertEqual(extracted_relationship2.relationship.source_text,
82
+ 'entity 1 relates to entity 2.')
83
+
84
+
85
+ if __name__ == '__main__':
86
+ logging.basicConfig(level=logging.INFO)
87
+ unittest.main()
extraction_pipeline/relationship_extractor/prompts.py ADDED
@@ -0,0 +1,481 @@
1
+ NER_TAGGING_PROMPT = '''
2
+ You are an expert in finance who will be performing Named Entity Recognition (NER)
3
+ with the high-level objective of analyzing financial documents. The output should
4
+ be formatted as a JSON object, where the JSON object has a key "entities" that
5
+ contains a list of the extracted entities. If no entities are found, return an
6
+ empty JSON object. The entity types to extract are:
7
+
8
+ - Financial Markets (grouping of financial securities)
9
+ - Financial Securities
10
+ - Economic Indicators
11
+
12
+ Now, given a text chunk, extract the relevant entities and format the output as a JSON object.
13
+
14
+ Examples of entities:
15
+
16
+ - Financial Markets: "US Stock Market", "US Bond Market", "US Real Estate Market", "Tech Stocks", "Energy Stocks"
17
+ - Financial Securities: "Apple", "Tesla", "Microsoft", "USD", "Bitcoin", "Gold", "Oil"
18
+ - Economic Indicators: "Consumer Price Index (CPI)", "Unemployment Rate", "GDP Growth Rate", "Federal Deficit"
19
+
20
+ Input:
21
+ Scores on the Doors: crypto 65.3%, gold 10.1%, stocks 7.7%, HY bonds 4.3%, IG bonds 4.3%, govt bonds 3.4%,
22
+ oil 3.7%, cash 1.2%, commodities -0.1%, US dollar -2.0% YTD.
23
+ The Biggest Picture: everyone's new favorite theme...US dollar debasement; US$ -11%
24
+ since Sept, gold >$2k, bitcoin >$30k; right secular theme (deficits, debt, geopolitics),
25
+ US$ in 4th bear market of past 50 years, bullish gold, bearish Euro, bullish international
26
+ stocks; but pessimism so high right now, and 200bps Fed cuts priced in from June to election,
27
+ look for US$ trading bounce after Fed ends hiking cycle May 3rd (Chart 2).
28
+ 2024 Tale of the Tape: inflation slowing, Fed hiking cycle over, recession
29
+ expectations universal, yet UST 30-year can't break below 3.6% (200-dma) because we've already traded 8% to 4% CPI,
30
+ labor market yet to crack, US govt deficit growing too quickly. The Price is Right: next 6 months it's
31
+ “recession vs Fed cuts”; best tells who's winning...HY bonds, homebuilders, semiconductors...HYG <73, XHB <70,
32
+ SOX <2900 recessionary, and if levels hold...it's a no/soft landing.
33
+
34
+ Output:
35
+ {
36
+ "entities": [
37
+ {
38
+ "name": "crypto",
39
+ "type": "Financial Securities",
40
+ },
41
+ {
42
+ "name": "gold",
43
+ "type": "Financial Securities",
44
+ },
45
+ {
46
+ "name": "stocks",
47
+ "type": "Financial Securities",
48
+ },
49
+ {
50
+ "name": "HY bonds",
51
+ "type": "Financial Securities",
52
+ },
53
+ {
54
+ "name": "IG bonds",
55
+ "type": "Financial Securities",
56
+ },
57
+ {
58
+ "name": "govt bonds",
59
+ "type": "Financial Securities",
60
+ },
61
+ {
62
+ "name": "oil",
63
+ "type": "Financial Securities",
64
+ },
65
+ {
66
+ "name": "cash",
67
+ "type": "Financial Securities",
68
+ },
69
+ {
70
+ "name": "commodities",
71
+ "type": "Financial Securities",
72
+ },
73
+ {
74
+ "name": "US dollar",
75
+ "type": "Financial Securities",
76
+ },
77
+ {
78
+ "name": "US$",
79
+ "type": "Financial Securities",
80
+ },
81
+ {
82
+ "name": "gold",
83
+ "type": "Financial Securities",
84
+ },
85
+ {
86
+ "name": "bitcoin",
87
+ "type": "Financial Securities",
88
+ },
89
+ {
90
+ "name": "UST 30-year",
91
+ "type": "Financial Securities",
92
+ },
93
+ {
94
+ "name": "CPI",
95
+ "type": "Economic Indicators",
96
+ },
97
+ {
98
+ "name": "labor market",
99
+ "type": "Economic Indicators",
100
+ },
101
+ {
102
+ "name": "US govt deficit",
103
+ "type": "Economic Indicators",
104
+ },
105
+ {
106
+ "name": "HY bonds",
107
+ "type": "Financial Securities",
108
+ },
109
+ {
110
+ "name": "homebuilders",
111
+ "type": "Financial Markets",
112
+ },
113
+ {
114
+ "name": "semiconductors",
115
+ "type": "Financial Markets",
116
+ },
117
+ {
118
+ "name": "HYG",
119
+ "type": "Financial Securities",
120
+ },
121
+ {
122
+ "name": "XHB",
123
+ "type": "Financial Securities",
124
+ },
125
+ {
126
+ "name": "SOX",
127
+ "type": "Financial Securities",
128
+ }
129
+ ]
130
+ }
131
+
132
+ Input:
133
+ the latest tech stock resurgence of 2024; NASDAQ +15% since January, Apple >$180, Tesla >$250;
134
+ driven by strong earnings, innovation, and AI advancements, NASDAQ in its 3rd bull market of the past decade,
135
+ bullish on semiconductors, cloud computing, and cybersecurity stocks; however, optimism
136
+ is soaring, and 150bps rate hikes anticipated from July to year-end, expect a potential
137
+ pullback in tech stocks after the next earnings season concludes in August (Chart 3).
138
+ Anyway, the bipartisan agreement to suspend the US debt ceiling until 2025 with only limited
139
+ fiscal restraint (roughly 0.1-0.2% of GDP yoy in 2024 and 2025 compared with the baseline)
140
+ means that we have avoided another potential significant negative shock to demand.
141
+ On the brightside, the boost to funding that Congress approved late last year for FY23
142
+ was so large (nearly 10% yoy) that overall discretionary spending is likely to be slightly
143
+ higher in real terms next year despite the new caps.
144
+
145
+ Output:
146
+ {
147
+ "entities": [
148
+ {
149
+ "name": "NASDAQ",
150
+ "type": "Financial Securities",
151
+ },
152
+ {
153
+ "name": "Apple",
154
+ "type": "Financial Securities",
155
+ },
156
+ {
157
+ "name": "Tesla",
158
+ "type": "Financial Securities",
159
+ },
160
+ {
161
+ "name": "semiconductors",
162
+ "type": "Financial Markets",
163
+ },
164
+ {
165
+ "name": "cloud computing",
166
+ "type": "Financial Markets",
167
+ },
168
+ {
169
+ "name": "cybersecurity stocks",
170
+ "type": "Financial Markets",
171
+ },
172
+ {
173
+ "name": "tech stocks",
174
+ "type": "Financial Markets",
175
+ },
176
+ {
177
+ "name": "GDP",
178
+ "type": "Economic Indicators",
179
+ },
180
+ {
181
+ "name": "FY23",
182
+ "type": "Economic Indicators",
183
+ }
184
+ ]
185
+ }
186
+
187
+ '''
188
+
189
+ EXTRACT_RELATIONSHIPS_PROMPT = '''
190
+ As a finance expert, your role is to extract relationships between a set of entities.
191
+ You will receive an excerpt from a financial report, the name of the analyst who authored
192
+ it, the date the report was published, and a list of entities that have already been extracted.
193
+ Your task is to identify and output relationships between these entities to construct a
194
+ knowledge graph for financial document analysis.
195
+
196
+ Your output should be a JSON object containing a single key, "relationships,"
197
+ which maps to a list of the extracted relationships. These relationships should be
198
+ formatted as structured JSON objects. Each relationship type must include "from_entity"
199
+ and "to_entity" attributes, using previously extracted entities as guidance,
200
+ to denote the source and target of the relationship, respectively. Additionally, all
201
+ relationship types should have "start_date" and "end_date" attributes, formatted as
202
+ "YYYY-MM-DD". The "source_text" attribute within a relationship should directly quote
203
+ the description of the relationship from the financial report excerpt.
204
+
205
+ The relationship types you need to extract include:
206
+
207
+ [Relationship Type 1: "PREDICTION"]: This relationship captures any forecasts or predictions the analyst makes
208
+ regarding trends in financial securities, indicators, or market entities. Search for phrases such as "we forecast,"
209
+ "is likely that," or "expect the market to." You must also assess the sentiment of the prediction—whether it is neutral,
210
+ positive (increase), or negative (decrease)—and include this in the relationship. Map the sentiment to one of the following values:
211
+
212
+ - PREDICTED_MOVEMENT_NEUTRAL (neutral sentiment)
213
+ - PREDICTED_MOVEMENT_INCREASE (positive sentiment)
214
+ - PREDICTED_MOVEMENT_DECREASE (negative sentiment)
215
+
216
+ The "from_entity" should be the analyst's name(s), while the "to_entity" should be the entity the prediction concerns (utilize the previously extracted entities as guidance).
217
+ Each prediction must have a defined "start_date" and "end_date" indicating the expected timeframe of the prediction. If
218
+ the prediction is not explicitly time-bound, omit this relationship. Use the report's publishing date as the "start_date"
219
+ unless the author specifies a different date. The "end_date" should correspond to when the prediction is anticipated to conclude.
220
+
221
+ Be sure to include the source excerpt you used to derive each relationship in the "source_text" attribute of the corresponding relationship.
222
+ Realize that predictions are about the future. Lastly, note that sometimes you will need to infer what entity is being referred to,
223
+ even if it is not explicitly named in the source sentence.
224
+
225
+ Example Input:
226
+
227
+ Analyst: Helen Brooks
228
+ Publish Date: 2024-06-12
229
+
230
+ Text Chunk: Labor markets are showing signs of normalization to end 2024; unemployment could drift higher in 2024 while remaining low in historical context.
231
+ Momentum in the job market is starting to wane with slowing payroll growth and modestly rising unemployment, as well as declining quit rates and temporary help.
232
+ Increased labor force participation and elevated immigration patterns over the past year have added labor supply, while a shortening work week indicates moderating demand for labor.
233
+ Considering the challenges to add and retain workers coming out of the pandemic, businesses could be more reluctant than normal to shed workers in a slowing economic environment.
234
+ Even so, less hiring activity could be enough to cause the unemployment rate to tick up to the mid-4% area by the end of next year due to worker churn.
235
+ Already slowing wage gains should slow further in the context of a softer labor market.
236
+ Anyway, the bipartisan agreement to suspend the US debt ceiling until 2025 with only limited
237
+ fiscal restraint (roughly 0.1-0.2% of GDP yoy in 2024 and 2025 compared with the baseline)
238
+ means that we have avoided another potential significant negative shock to demand.
239
+ On the brightside, the boost to funding that Congress approved late last year for FY23
240
+ was so large (nearly 10% yoy) that overall discretionary spending is likely to be slightly
241
+ higher in real terms next year despite the new caps.
242
+
243
+ {
244
+ "entities": [
245
+ {
246
+ "name": "labor markets",
247
+ "type": "Economic Market",
248
+ },
249
+ {
250
+ "name": "unemployment",
251
+ "type": "Economic Indicators",
252
+ },
253
+ {
254
+ "name": "job market",
255
+ "type": "Economic Market",
256
+ },
257
+ {
258
+ "name": "payroll growth",
259
+ "type": "Economic Indicators",
260
+ },
261
+ {
262
+ "name": "labor force participation",
263
+ "type": "Economic Indicators",
264
+ },
265
+ {
266
+ "name": "rising unemployment",
267
+ "type": "Economic Indicators",
268
+ },
269
+ {
270
+ "name": "immigration patterns",
271
+ "type": "Economic Indicators",
272
+ },
273
+ {
274
+ "name": "hiring activity",
275
+ "type": "Economic Indicators",
276
+ },
277
+ {
278
+ "name": "worker churn",
279
+ "type": "Economic Indicators",
280
+ },
281
+ {
282
+ "name": "unemployment rate",
283
+ "type": "Economic Indicators",
284
+ },
285
+ {
286
+ "name": "discretionary spending",
287
+ "type": "Economic Indicators",
288
+ },
289
+ {
290
+ "name": "GDP",
291
+ "type": "Economic Indicators",
292
+ }
293
+ ]
294
+ }
295
+
296
+
297
+ Example Output:
298
+ {
299
+ "relationships": [
300
+ {
301
+ "PREDICTION": {
302
+ "from_entity": "Helen Brooks",
303
+ "to_entity": "labor market",
304
+ "start_date": "2024-06-12",
305
+ "end_date": "2024-12-31",
306
+ "predicted_movement": "PREDICTED_MOVEMENT_NEUTRAL",
307
+ "source_text": "Labor markets are showing signs of normalization to end 2024;"
308
+ }
309
+ },
310
+ {
311
+ "PREDICTION": {
312
+ "from_entity": "Helen Brooks",
313
+ "to_entity": "unemployment",
314
+ "start_date": "2024-06-12",
315
+ "end_date": "2024-12-31",
316
+ "predicted_movement": "PREDICTED_MOVEMENT_INCREASE",
317
+ "source_text": "unemployment could drift higher in 2024 while remaining low in historical context."
318
+ }
319
+ },
320
+ {
321
+ "PREDICTION": {
322
+ "from_entity": "Helen Brooks",
323
+ "to_entity": "unemployment rate",
324
+ "start_date": "2024-06-12",
325
+ "end_date": "2025-12-31",
326
+ "predicted_movement": "PREDICTED_MOVEMENT_INCREASE",
327
+ "source_text": "Even so, less hiring activity could be enough to cause the unemployment rate to tick up to the mid-4% area by the end of next year due to worker churn."
328
+ }
329
+ },
330
+ {
331
+ "PREDICTION": {
332
+ "from_entity": "Helen Brooks",
333
+ "to_entity": "discretionary spending",
334
+ "start_date": "2024-06-12",
335
+ "end_date": "2025-12-31",
336
+ "predicted_movement": "PREDICTED_MOVEMENT_INCREASE",
337
+ "source_text": "On the brightside, the boost to funding that Congress approved late last year for FY23 was so large (nearly 10% yoy) that overall discretionary spending is likely to be slightly higher in real terms next year despite the new caps."
338
+ }
339
+ },
340
+ ]
341
+ }
342
+
343
+ Example Input:
344
+
345
+ Analyst: John Doe, Herb Johnson, Taylor Mason
346
+
347
+ Publish Date: 2024-01-02
348
+
349
+ Text Chunk: Scores on the Doors: crypto 65.3%, gold 10.1%, stocks 7.7%, HY bonds 4.3%, IG bonds 4.3%, govt bonds 3.4%,
350
+ oil 3.7%, cash 1.2%, commodities -0.1%, US dollar -2.0% YTD. The Biggest Picture: everyone’s new favorite
351
+ theme...US dollar debasement; US$ -11% since Sept, gold >$2k, bitcoin >$30k; right secular theme (deficits,
352
+ debt, geopolitics), US$ in 4th bear market of past 50 years, bullish gold, oil, Euro, international stocks; but pessimism
353
+ so high right now, and 200bps Fed cuts priced in from June to election, look for US$ trading bounce after Fed
354
+ ends hiking cycle May 3rd (Chart 2). Tale of the Tape: inflation slowing, Fed hiking cycle over, recession
355
+ expectations universal, yet UST 30-year can’t break below 3.6% (200-dma) because we’ve already traded 8% to 4% CPI,
356
+ labor market yet to crack, US govt deficit growing too quickly. The Price is Right: next 6 months it’s
357
+ “recession vs Fed cuts”; best tells who’s winning...HY bonds, homebuilders, semiconductors...HYG <73, XHB <70,
358
+ SOX <2900 recessionary, and if levels hold...it’s a no/soft landing.
359
+
360
+ {
361
+ "entities": [
362
+ {
363
+ "name": "crypto",
364
+ "type": "Financial Securities",
365
+ },
366
+ {
367
+ "name": "gold",
368
+ "type": "Financial Securities",
369
+ },
370
+ {
371
+ "name": "stocks",
372
+ "type": "Financial Securities",
373
+ },
374
+ {
375
+ "name": "HY bonds",
376
+ "type": "Financial Securities",
377
+ },
378
+ {
379
+ "name": "IG bonds",
380
+ "type": "Financial Securities",
381
+ },
382
+ {
383
+ "name": "govt bonds",
384
+ "type": "Financial Securities",
385
+ },
386
+ {
387
+ "name": "oil",
388
+ "type": "Financial Securities",
389
+ },
390
+ {
391
+ "name": "cash",
392
+ "type": "Financial Securities",
393
+ },
394
+ {
395
+ "name": "commodities",
396
+ "type": "Financial Securities",
397
+ },
398
+ {
399
+ "name": "US dollar",
400
+ "type": "Financial Securities",
401
+ },
402
+ {
403
+ "name": "US$",
404
+ "type": "Financial Securities",
405
+ },
406
+ {
407
+ "name": "gold",
408
+ "type": "Financial Securities",
409
+ },
410
+ {
411
+ "name": "bitcoin",
412
+ "type": "Financial Securities",
413
+ },
414
+ {
415
+ "name": "UST 30-year",
416
+ "type": "Financial Securities",
417
+ },
418
+ {
419
+ "name": "CPI",
420
+ "type": "Economic Indicators",
421
+ },
422
+ {
423
+ "name": "labor market",
424
+ "type": "Economic Indicators",
425
+ },
426
+ {
427
+ "name": "US govt deficit",
428
+ "type": "Economic Indicators",
429
+ },
430
+ {
431
+ "name": "HY bonds",
432
+ "type": "Financial Securities",
433
+ },
434
+ {
435
+ "name": "homebuilders",
436
+ "type": "Financial Markets",
437
+ },
438
+ {
439
+ "name": "semiconductors",
440
+ "type": "Financial Markets",
441
+ },
442
+ {
443
+ "name": "HYG",
444
+ "type": "Financial Securities",
445
+ },
446
+ {
447
+ "name": "XHB",
448
+ "type": "Financial Securities",
449
+ },
450
+ {
451
+ "name": "SOX",
452
+ "type": "Financial Securities",
453
+ }
454
+ ]
455
+ }
456
+
457
+ Example Output:
458
+
459
+ {
460
+ "relationships": [
461
+ {
462
+ "PREDICTION": {
463
+ "from_entity": "John Doe, Herb Johnson, Taylor Mason",
464
+ "to_entity": "US$",
465
+ "start_date": "2024-06-12",
466
+ "end_date": "2024-11-05",
467
+ "predicted_movement": "PREDICTED_MOVEMENT_INCREASE",
468
+ "source_text": "200bps Fed cuts priced in from June to election, look for US$ trading bounce after Fed ends hiking cycle May 3rd."
469
+ }
+ },
+ {
470
+ "PREDICTION": {
471
+ "from_entity": "John Doe, Herb Johnson, Taylor Mason",
472
+ "to_entity": "Fed cuts",
473
+ "start_date": "2024-06-12",
474
+ "end_date": "2024-11-05",
475
+ "predicted_movement": "PREDICTED_MOVEMENT_DECREASE",
476
+ "source_text": "200bps Fed cuts priced in from June to election, look for US$ trading bounce after Fed ends hiking cycle May 3rd."
477
+ }
478
+ },
479
+ ]
480
+ }
481
+ '''
llm_handler/llm_interface.py ADDED
@@ -0,0 +1,21 @@
1
+ from __future__ import annotations
2
+ from typing import List, Protocol, Type, Dict
3
+ import enum
4
+
5
+
6
+ class DefaultEnumMeta(enum.EnumMeta):
7
+
8
+ def __call__(cls, value=None, *args, **kwargs) -> DefaultEnumMeta:
9
+ if value is None:
10
+ return next(iter(cls))
11
+ return super().__call__(value, *args, **kwargs) # type: ignore
12
+
13
+
14
+ class LLMInterface(Protocol):
15
+
16
+ def get_chat_completion(self, messages: List, model: enum.Enum, temperature: float,
17
+ **kwargs) -> str:
18
+ ...
19
+
20
+ def get_text_embedding(self, input: str, model: enum.Enum) -> List[float]:
21
+ ...
llm_handler/mock_llm_handler.py ADDED
@@ -0,0 +1,34 @@
1
+ from typing import List, Optional, Dict
2
+ import enum
3
+
4
+ import llm_handler.llm_interface as llm_interface
5
+
6
+
7
+ class MockLLMHandler(llm_interface.LLMInterface):
8
+
9
+ _chat_completion: Optional[List[str]]
10
+ _text_embedding: Optional[List[float]]
11
+
12
+ def __init__(self,
13
+ chat_completion: Optional[List[str]] = None,
14
+ text_embedding: Optional[List[float]] = None):
15
+ self._chat_completion = chat_completion
16
+ self._text_embedding = text_embedding
17
+
18
+ def get_chat_completion(self,
19
+ messages: List[Dict],
20
+ model: Optional[enum.Enum] = None,
21
+ temperature: float = 0.2,
22
+ **kwargs) -> str:
23
+ if not self._chat_completion:
24
+ raise ValueError('_chat_completion not set')
25
+ return self._chat_completion.pop(0)
26
+
27
+ def get_text_embedding(
28
+ self,
29
+ input: str,
30
+ model: Optional[enum.Enum] = None,
31
+ ) -> List[float]:
32
+ if not self._text_embedding:
33
+ raise ValueError('_text_embedding not set')
34
+ return self._text_embedding
llm_handler/openai_handler.py ADDED
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+ import openai
3
+ from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
4
+ import os
5
+ from typing import List, Optional, ClassVar
6
+ import enum
7
+
8
+ from llm_handler.llm_interface import LLMInterface, DefaultEnumMeta
9
+
10
+
11
+ class ChatModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
12
+ GPT_3_5 = 'gpt-3.5-turbo-1106'
13
+ GPT_4 = 'gpt-4'
14
+ GPT_4_TURBO = 'gpt-4-1106-preview'
15
+ GPT_4_O = 'gpt-4o'
16
+
17
+
18
+ class EmbeddingModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
19
+ SMALL_3 = 'text-embedding-3-small'
20
+ ADA_002 = 'text-embedding-ada-002'
21
+ LARGE = 'text-embedding-3-large'
22
+
23
+
24
+ class OpenAIHandler(LLMInterface):
25
+
26
+ _ENV_KEY_NAME: ClassVar[str] = 'OPENAI_API_KEY'
27
+ _client: openai.Client
28
+
29
+ def __init__(self, openai_api_key: Optional[str] = None):
30
+ _openai_api_key = openai_api_key or os.environ.get(self._ENV_KEY_NAME)
31
+ if not _openai_api_key:
32
+ raise ValueError(f'{self._ENV_KEY_NAME} not set')
33
+ openai.api_key = _openai_api_key
34
+ self._client = openai.Client()
35
+
36
+ def get_chat_completion( # type: ignore
37
+ self,
38
+ messages: List[ChatCompletionMessageParam],
39
+ model: ChatModelVersion = ChatModelVersion.GPT_4_O,
40
+ temperature: float = 0.2,
41
+ **kwargs) -> str:
42
+ response = self._client.chat.completions.create(model=model.value,
43
+ messages=messages,
44
+ temperature=temperature,
45
+ **kwargs)
46
+ responses: List[str] = []
47
+ for choice in response.choices:
48
+ if choice.finish_reason != 'stop' or not choice.message.content:
49
+ raise ValueError(f'Choice did not complete correctly: {choice}')
50
+ responses.append(choice.message.content)
51
+ if len(responses) != 1:
52
+ raise ValueError(f'Expected one response, got {len(responses)}: {responses}')
53
+ return responses[0]
54
+
55
+ def get_text_embedding( # type: ignore
56
+ self, input: str, model: EmbeddingModelVersion) -> List[float]:
57
+ response = self._client.embeddings.create(model=model.value,
58
+ encoding_format='float',
59
+ input=input)
60
+ if not response.data:
61
+ raise ValueError(f'No embedding in response: {response}')
62
+ elif len(response.data) != 1:
63
+ raise ValueError(f'More than one embedding in response: {response}')
64
+ return response.data[0].embedding
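Aside: a short usage sketch of the handler above, assuming OPENAI_API_KEY is set in the environment; the message content and embedding input are placeholders.

from llm_handler.openai_handler import ChatModelVersion, EmbeddingModelVersion, OpenAIHandler

# Sketch only: one chat completion and one embedding through the wrapper above.
handler = OpenAIHandler()  # raises ValueError if OPENAI_API_KEY is not set
reply = handler.get_chat_completion(
    messages=[{"role": "user", "content": "Summarize the outlook for oil in one sentence."}],
    model=ChatModelVersion.GPT_4_O,
    temperature=0.2)
embedding = handler.get_text_embedding(input='oil prices', model=EmbeddingModelVersion.SMALL_3)
print(reply, len(embedding))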
llm_handler/openai_handler_test.py ADDED
@@ -0,0 +1,69 @@
1
+ import os
2
+ from typing import List
3
+ from unittest import mock
4
+ import pytest
5
+ import openai
6
+ from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
7
+
8
+ import llm_handler.openai_handler as openai_handler
9
+
10
+
11
+ # Clear openai.api_key before each test
12
+ @pytest.fixture(autouse=True)
13
+ def test_env_setup():
14
+ openai.api_key = None
15
+
16
+
17
+ # This allows us to clear the OPENAI_API_KEY before any test we want
18
+ @pytest.fixture()
19
+ def mock_settings_env_vars():
20
+ with mock.patch.dict(os.environ, clear=True):
21
+ yield
22
+
23
+
24
+ # Define some constants for our tests
25
+ _TEST_KEY: str = "TEST_KEY"
26
+ _TEST_MESSAGE: str = "Hello how are you?"
27
+ _TEST_MESSAGE_FOR_CHAT_COMPLETION: List[ChatCompletionMessageParam] = [
28
+ {
29
+ "role": "system", "content": "You are serving as an endpoint to verify a test."
30
+ }, {
31
+ "role": "user", "content": "Respond with something to help us verify our code is working."
32
+ }
33
+ ]
34
+ _TEXT_EMBEDDING_SMALL_3_LENGTH: int = 1536
35
+
36
+
37
+ def test_init_without_key(mock_settings_env_vars):
38
+ # Ensure key not set as env var and openai.api_key not set
39
+ with pytest.raises(KeyError):
40
+ os.environ[openai_handler.OpenAIHandler._ENV_KEY_NAME]
41
+ assert openai.api_key == None
42
+ # Ensure proper exception raised when instantiating handler without key as param or env var
43
+ with pytest.raises(ValueError) as excinfo:
44
+ openai_handler.OpenAIHandler()
45
+ assert f'{openai_handler.OpenAIHandler._ENV_KEY_NAME} not set' in str(excinfo.value)
46
+ assert openai.api_key == None
47
+
48
+
49
+ def test_init_with_key_as_param():
50
+ # Ensure key is set as env var, key value is unique from _TEST_KEY, and openai.api_key not set
51
+ assert not os.environ[openai_handler.OpenAIHandler._ENV_KEY_NAME] == _TEST_KEY
52
+ assert openai.api_key == None
53
+ # Ensure successful instantiation and openai.api_key properly set
54
+ handler = openai_handler.OpenAIHandler(openai_api_key=_TEST_KEY)
55
+ assert isinstance(handler, openai_handler.OpenAIHandler)
56
+ assert openai.api_key == _TEST_KEY
57
+
58
+
59
+ def test_init_with_key_as_env_var(mock_settings_env_vars):
60
+ # Ensure key not set as env var and openai.api_key not set
61
+ with pytest.raises(KeyError):
62
+ os.environ[openai_handler.OpenAIHandler._ENV_KEY_NAME]
63
+ assert openai.api_key == None
64
+ # Set key as env var
65
+ os.environ.setdefault(openai_handler.OpenAIHandler._ENV_KEY_NAME, _TEST_KEY)
66
+ # Ensure successful instantiation and openai.api_key properly set
67
+ handler = openai_handler.OpenAIHandler()
68
+ assert isinstance(handler, openai_handler.OpenAIHandler)
69
+ assert openai.api_key == _TEST_KEY
proto/chunk_pb2.py ADDED
@@ -0,0 +1,30 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: proto/chunk.proto
4
+ # Protobuf Python Version: 5.26.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11proto/chunk.proto\x12\x06\x66inllm\"D\n\x08\x44ocument\x12\x11\n\tfile_path\x18\x01 \x01(\t\x12\x0f\n\x07\x61uthors\x18\x02 \x01(\t\x12\x14\n\x0cpublish_date\x18\x03 \x01(\t\"\xbe\x01\n\x05\x43hunk\x12\x10\n\x08\x63hunk_id\x18\x01 \x01(\t\x12\x12\n\nchunk_text\x18\x02 \x01(\t\x12%\n\nchunk_type\x18\x03 \x01(\x0e\x32\x11.finllm.ChunkType\x12\x13\n\x0b\x63hunk_index\x18\x04 \x01(\x03\x12\x19\n\x0fparent_chunk_id\x18\x05 \x01(\tH\x00\x12$\n\x08\x64ocument\x18\x06 \x01(\x0b\x32\x10.finllm.DocumentH\x00\x42\x12\n\x10parent_reference*k\n\tChunkType\x12\x16\n\x12\x43HUNK_TYPE_UNKNOWN\x10\x00\x12\x17\n\x13\x43HUNK_TYPE_SENTENCE\x10\x01\x12\x18\n\x14\x43HUNK_TYPE_PARAGRAPH\x10\x02\x12\x13\n\x0f\x43HUNK_TYPE_PAGE\x10\x03\x62\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.chunk_pb2', _globals)
22
+ if not _descriptor._USE_C_DESCRIPTORS:
23
+ DESCRIPTOR._loaded_options = None
24
+ _globals['_CHUNKTYPE']._serialized_start=292
25
+ _globals['_CHUNKTYPE']._serialized_end=399
26
+ _globals['_DOCUMENT']._serialized_start=29
27
+ _globals['_DOCUMENT']._serialized_end=97
28
+ _globals['_CHUNK']._serialized_start=100
29
+ _globals['_CHUNK']._serialized_end=290
30
+ # @@protoc_insertion_point(module_scope)
proto/chunk_pb2.pyi ADDED
@@ -0,0 +1,43 @@
1
+ from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
2
+ from google.protobuf import descriptor as _descriptor
3
+ from google.protobuf import message as _message
4
+ from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union
5
+
6
+ DESCRIPTOR: _descriptor.FileDescriptor
7
+
8
+ class ChunkType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
9
+ __slots__ = ()
10
+ CHUNK_TYPE_UNKNOWN: _ClassVar[ChunkType]
11
+ CHUNK_TYPE_SENTENCE: _ClassVar[ChunkType]
12
+ CHUNK_TYPE_PARAGRAPH: _ClassVar[ChunkType]
13
+ CHUNK_TYPE_PAGE: _ClassVar[ChunkType]
14
+ CHUNK_TYPE_UNKNOWN: ChunkType
15
+ CHUNK_TYPE_SENTENCE: ChunkType
16
+ CHUNK_TYPE_PARAGRAPH: ChunkType
17
+ CHUNK_TYPE_PAGE: ChunkType
18
+
19
+ class Document(_message.Message):
20
+ __slots__ = ("file_path", "authors", "publish_date")
21
+ FILE_PATH_FIELD_NUMBER: _ClassVar[int]
22
+ AUTHORS_FIELD_NUMBER: _ClassVar[int]
23
+ PUBLISH_DATE_FIELD_NUMBER: _ClassVar[int]
24
+ file_path: str
25
+ authors: str
26
+ publish_date: str
27
+ def __init__(self, file_path: _Optional[str] = ..., authors: _Optional[str] = ..., publish_date: _Optional[str] = ...) -> None: ...
28
+
29
+ class Chunk(_message.Message):
30
+ __slots__ = ("chunk_id", "chunk_text", "chunk_type", "chunk_index", "parent_chunk_id", "document")
31
+ CHUNK_ID_FIELD_NUMBER: _ClassVar[int]
32
+ CHUNK_TEXT_FIELD_NUMBER: _ClassVar[int]
33
+ CHUNK_TYPE_FIELD_NUMBER: _ClassVar[int]
34
+ CHUNK_INDEX_FIELD_NUMBER: _ClassVar[int]
35
+ PARENT_CHUNK_ID_FIELD_NUMBER: _ClassVar[int]
36
+ DOCUMENT_FIELD_NUMBER: _ClassVar[int]
37
+ chunk_id: str
38
+ chunk_text: str
39
+ chunk_type: ChunkType
40
+ chunk_index: int
41
+ parent_chunk_id: str
42
+ document: Document
43
+ def __init__(self, chunk_id: _Optional[str] = ..., chunk_text: _Optional[str] = ..., chunk_type: _Optional[_Union[ChunkType, str]] = ..., chunk_index: _Optional[int] = ..., parent_chunk_id: _Optional[str] = ..., document: _Optional[_Union[Document, _Mapping]] = ...) -> None: ...
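Aside: for reference, the generated stubs above can be exercised directly; a small sketch with placeholder values (the ChunkD/DocumentD classes in domain/chunk_d.py appear to wrap these messages).

import proto.chunk_pb2 as chunk_pb2

# Sketch only: build the raw protobuf messages that the domain layer wraps.
doc_proto = chunk_pb2.Document(file_path='report.pdf', authors='Jane Analyst', publish_date='2024-06-12')
chunk_proto = chunk_pb2.Chunk(chunk_id='chunk-0',
                              chunk_text='Example page text.',
                              chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
                              chunk_index=0,
                              document=doc_proto)
print(chunk_proto.document.publish_date)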
proto/entity_pb2.py ADDED
@@ -0,0 +1,34 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: proto/entity.proto
4
+ # Protobuf Python Version: 5.26.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12proto/entity.proto\x12\x06\x66inllm\"0\n\x06\x45ntity\x12\x11\n\tentity_id\x18\x01 \x01(\t\x12\x13\n\x0b\x65ntity_name\x18\x02 \x01(\t\"\x99\x01\n\x0cRelationship\x12\x17\n\x0frelationship_id\x18\x01 \x01(\t\x12\x12\n\nstart_date\x18\x02 \x01(\t\x12\x10\n\x08\x65nd_date\x18\x03 \x01(\t\x12\x13\n\x0bsource_text\x18\x04 \x01(\t\x12\x35\n\x12predicted_movement\x18\x05 \x01(\x0e\x32\x19.finllm.PredictedMovement\"\x88\x01\n\x12\x45ntityRelationship\x12#\n\x0b\x66rom_entity\x18\x01 \x01(\x0b\x32\x0e.finllm.Entity\x12*\n\x0crelationship\x18\x02 \x01(\x0b\x32\x14.finllm.Relationship\x12!\n\tto_entity\x18\x03 \x01(\x0b\x32\x0e.finllm.Entity\"P\n\x14\x45ntityKnowledgeGraph\x12\x38\n\x14\x65ntity_relationships\x18\x01 \x03(\x0b\x32\x1a.finllm.EntityRelationship*\x99\x01\n\x11PredictedMovement\x12\"\n\x1ePREDICTED_MOVEMENT_UNSPECIFIED\x10\x00\x12\x1e\n\x1aPREDICTED_MOVEMENT_NEUTRAL\x10\x01\x12\x1f\n\x1bPREDICTED_MOVEMENT_INCREASE\x10\x02\x12\x1f\n\x1bPREDICTED_MOVEMENT_DECREASE\x10\x03\x62\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.entity_pb2', _globals)
22
+ if not _descriptor._USE_C_DESCRIPTORS:
23
+ DESCRIPTOR._loaded_options = None
24
+ _globals['_PREDICTEDMOVEMENT']._serialized_start=458
25
+ _globals['_PREDICTEDMOVEMENT']._serialized_end=611
26
+ _globals['_ENTITY']._serialized_start=30
27
+ _globals['_ENTITY']._serialized_end=78
28
+ _globals['_RELATIONSHIP']._serialized_start=81
29
+ _globals['_RELATIONSHIP']._serialized_end=234
30
+ _globals['_ENTITYRELATIONSHIP']._serialized_start=237
31
+ _globals['_ENTITYRELATIONSHIP']._serialized_end=373
32
+ _globals['_ENTITYKNOWLEDGEGRAPH']._serialized_start=375
33
+ _globals['_ENTITYKNOWLEDGEGRAPH']._serialized_end=455
34
+ # @@protoc_insertion_point(module_scope)
proto/entity_pb2.pyi ADDED
@@ -0,0 +1,56 @@
1
+ from google.protobuf.internal import containers as _containers
2
+ from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
3
+ from google.protobuf import descriptor as _descriptor
4
+ from google.protobuf import message as _message
5
+ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
6
+
7
+ DESCRIPTOR: _descriptor.FileDescriptor
8
+
9
+ class PredictedMovement(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
10
+ __slots__ = ()
11
+ PREDICTED_MOVEMENT_UNSPECIFIED: _ClassVar[PredictedMovement]
12
+ PREDICTED_MOVEMENT_NEUTRAL: _ClassVar[PredictedMovement]
13
+ PREDICTED_MOVEMENT_INCREASE: _ClassVar[PredictedMovement]
14
+ PREDICTED_MOVEMENT_DECREASE: _ClassVar[PredictedMovement]
15
+ PREDICTED_MOVEMENT_UNSPECIFIED: PredictedMovement
16
+ PREDICTED_MOVEMENT_NEUTRAL: PredictedMovement
17
+ PREDICTED_MOVEMENT_INCREASE: PredictedMovement
18
+ PREDICTED_MOVEMENT_DECREASE: PredictedMovement
19
+
20
+ class Entity(_message.Message):
21
+ __slots__ = ("entity_id", "entity_name")
22
+ ENTITY_ID_FIELD_NUMBER: _ClassVar[int]
23
+ ENTITY_NAME_FIELD_NUMBER: _ClassVar[int]
24
+ entity_id: str
25
+ entity_name: str
26
+ def __init__(self, entity_id: _Optional[str] = ..., entity_name: _Optional[str] = ...) -> None: ...
27
+
28
+ class Relationship(_message.Message):
29
+ __slots__ = ("relationship_id", "start_date", "end_date", "source_text", "predicted_movement")
30
+ RELATIONSHIP_ID_FIELD_NUMBER: _ClassVar[int]
31
+ START_DATE_FIELD_NUMBER: _ClassVar[int]
32
+ END_DATE_FIELD_NUMBER: _ClassVar[int]
33
+ SOURCE_TEXT_FIELD_NUMBER: _ClassVar[int]
34
+ PREDICTED_MOVEMENT_FIELD_NUMBER: _ClassVar[int]
35
+ relationship_id: str
36
+ start_date: str
37
+ end_date: str
38
+ source_text: str
39
+ predicted_movement: PredictedMovement
40
+ def __init__(self, relationship_id: _Optional[str] = ..., start_date: _Optional[str] = ..., end_date: _Optional[str] = ..., source_text: _Optional[str] = ..., predicted_movement: _Optional[_Union[PredictedMovement, str]] = ...) -> None: ...
41
+
42
+ class EntityRelationship(_message.Message):
43
+ __slots__ = ("from_entity", "relationship", "to_entity")
44
+ FROM_ENTITY_FIELD_NUMBER: _ClassVar[int]
45
+ RELATIONSHIP_FIELD_NUMBER: _ClassVar[int]
46
+ TO_ENTITY_FIELD_NUMBER: _ClassVar[int]
47
+ from_entity: Entity
48
+ relationship: Relationship
49
+ to_entity: Entity
50
+ def __init__(self, from_entity: _Optional[_Union[Entity, _Mapping]] = ..., relationship: _Optional[_Union[Relationship, _Mapping]] = ..., to_entity: _Optional[_Union[Entity, _Mapping]] = ...) -> None: ...
51
+
52
+ class EntityKnowledgeGraph(_message.Message):
53
+ __slots__ = ("entity_relationships",)
54
+ ENTITY_RELATIONSHIPS_FIELD_NUMBER: _ClassVar[int]
55
+ entity_relationships: _containers.RepeatedCompositeFieldContainer[EntityRelationship]
56
+ def __init__(self, entity_relationships: _Optional[_Iterable[_Union[EntityRelationship, _Mapping]]] = ...) -> None: ...
proto/pdf_document_pb2.py ADDED
@@ -0,0 +1,26 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # source: proto/pdf_document.proto
4
+ # Protobuf Python Version: 5.26.1
5
+ """Generated protocol buffer code."""
6
+ from google.protobuf import descriptor as _descriptor
7
+ from google.protobuf import descriptor_pool as _descriptor_pool
8
+ from google.protobuf import symbol_database as _symbol_database
9
+ from google.protobuf.internal import builder as _builder
10
+ # @@protoc_insertion_point(imports)
11
+
12
+ _sym_db = _symbol_database.Default()
13
+
14
+
15
+
16
+
17
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18proto/pdf_document.proto\x12\x06\x66inllm\"!\n\x0cPdfDocumentD\x12\x11\n\tfile_path\x18\x01 \x01(\tb\x06proto3')
18
+
19
+ _globals = globals()
20
+ _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
21
+ _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.pdf_document_pb2', _globals)
22
+ if not _descriptor._USE_C_DESCRIPTORS:
23
+ DESCRIPTOR._loaded_options = None
24
+ _globals['_PDFDOCUMENTD']._serialized_start=36
25
+ _globals['_PDFDOCUMENTD']._serialized_end=69
26
+ # @@protoc_insertion_point(module_scope)
proto/pdf_document_pb2.pyi ADDED
@@ -0,0 +1,11 @@
1
+ from google.protobuf import descriptor as _descriptor
2
+ from google.protobuf import message as _message
3
+ from typing import ClassVar as _ClassVar, Optional as _Optional
4
+
5
+ DESCRIPTOR: _descriptor.FileDescriptor
6
+
7
+ class PdfDocumentD(_message.Message):
8
+ __slots__ = ("file_path",)
9
+ FILE_PATH_FIELD_NUMBER: _ClassVar[int]
10
+ file_path: str
11
+ def __init__(self, file_path: _Optional[str] = ...) -> None: ...
query_pipeline/evaluation_engine.py ADDED
@@ -0,0 +1,219 @@
1
+ import json
2
+ from typing import Optional
3
+
4
+ from proto.entity_pb2 import PredictedMovement
5
+ from tabulate import tabulate
6
+
7
+ from domain.entity_d import (
8
+ EntityD,
9
+ EntityKnowledgeGraphD,
10
+ EntityRelationshipD,
11
+ RelationshipD,
12
+ )
13
+ from llm_handler.openai_handler import (
14
+ ChatCompletionMessageParam,
15
+ ChatModelVersion,
16
+ OpenAIHandler,
17
+ )
18
+ from utils.dates import parse_date
19
+
20
+ FUZZY_MATCH_ENTITIES_PROMPT = '''
21
+ You are an expert in financial analysis. You will be given two lists of entities. Your task is to output a semantic mapping from the entities in list A to the entities in list B: an entity in list A that is semantically similar to an entity in list B should be mapped to that entity. If there is no reasonable semantic match for an entity in list A, output an empty string. Output should be formatted as a JSON object. Ensure the entity keys are in the same order as the input list A.
22
+
23
+ Input:
24
+ List A: ["BofA", "Bank of America Corp", "GDP", "Inflation", "Yen"]
25
+ List B: ["Bank of America", "inflation", "Gross Domestic Product", "oil"]
26
+
27
+ Output:
28
+ {
29
+ "BofA": "Bank of America",
30
+ "Bank of America Corp": "Bank of America",
31
+ "GDP": "Gross Domestic Product",
32
+ "Inflation": "inflation",
33
+ "Yen": ""
34
+ }
35
+ '''
36
+
37
+
38
+ class EvaluationEngine:
39
+
40
+ _handler: OpenAIHandler
41
+ _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
42
+ _TEMPERATURE: float = 0.2
43
+
44
+ def __init__(self,
45
+ ground_truth_kg: EntityKnowledgeGraphD,
46
+ openai_handler: Optional[OpenAIHandler] = None,
47
+ model_version: Optional[ChatModelVersion] = None):
48
+ self._handler = openai_handler or OpenAIHandler()
49
+ self._model_version = model_version or self._MODEL_VERSION
50
+
51
+ # setup adjacency list representation of ground truth knowledge graph
52
+ self.kg: dict[str, list[EntityRelationshipD]] = {}
53
+ for entity_relationship in ground_truth_kg.entity_relationships:
54
+ to_entity_name = entity_relationship.to_entity.entity_name
55
+ relationships = self.kg.get(to_entity_name, [])
56
+ relationships.append(entity_relationship)
57
+ self.kg[to_entity_name] = relationships
58
+
59
+ def _get_thesis_to_gt_entity_map(self, thesis_kg: EntityKnowledgeGraphD) -> dict[str, str]:
60
+ thesis_entities = []
61
+ for entity_relationship in thesis_kg.entity_relationships:
62
+ thesis_entities.append(entity_relationship.to_entity.entity_name)
63
+
64
+ # LLM call to return out the matched entities
65
+ messages: list[ChatCompletionMessageParam] = [
66
+ {
67
+ "role": "system", "content": FUZZY_MATCH_ENTITIES_PROMPT
68
+ }, {
69
+ "role": "user",
70
+ "content": f"List A: {thesis_entities}\nList B: {list(self.kg.keys())}"
71
+ }
72
+ ]
73
+ completion_text = self._handler.get_chat_completion(messages=messages,
74
+ model=self._model_version,
75
+ temperature=self._TEMPERATURE,
76
+ response_format={"type": "json_object"})
77
+
78
+ thesis_to_gt_entity_mapping: dict[str, str] = json.loads(completion_text)
79
+
80
+ return thesis_to_gt_entity_mapping
81
+
82
+ def _get_relationships_matching_timeperiod(
83
+ self, gt_kg_to_node: str, relationship: RelationshipD) -> list[EntityRelationshipD]:
84
+ matching_relationships = []
85
+ thesis_relationship_start = parse_date(relationship.start_date)
86
+ thesis_relationship_end = parse_date(relationship.end_date)
87
+ for gt_relationship in self.kg[gt_kg_to_node]:
88
+ gt_relationship_start = parse_date(gt_relationship.relationship.start_date)
89
+ gt_relationship_end = parse_date(gt_relationship.relationship.end_date)
90
+
91
+ if (gt_relationship_start <= thesis_relationship_start <= gt_relationship_end and \
92
+ gt_relationship_start <= thesis_relationship_end <= gt_relationship_end):
93
+ # thesis relationship timeframe falls entirely within the gt relationship timeframe
94
+ matching_relationships.append(gt_relationship)
95
+
96
+ return matching_relationships
97
+
98
+ def evaluate_thesis(
99
+ self, thesis_kg: EntityKnowledgeGraphD
100
+ ) -> list[tuple[EntityRelationshipD, bool, Optional[EntityRelationshipD]]]:
101
+ thesis_to_kg_map = self._get_thesis_to_gt_entity_map(thesis_kg)
102
+ results = []
103
+ for thesis_relationship in thesis_kg.entity_relationships:
104
+ thesis_to_node = thesis_relationship.to_entity.entity_name
105
+ kg_node = thesis_to_kg_map[thesis_to_node]
106
+ if not kg_node: # no matching entity in KG
107
+ results.append((thesis_relationship, False, None))
108
+ continue
109
+
110
+ matching_relationships = self._get_relationships_matching_timeperiod(
111
+ kg_node, thesis_relationship.relationship)
112
+
113
+ for entity_relationship in matching_relationships:
114
+ if entity_relationship.relationship.predicted_movement == thesis_relationship.relationship.predicted_movement:
115
+ results.append((thesis_relationship, True, entity_relationship))
116
+ else:
117
+ results.append((thesis_relationship, False, entity_relationship))
118
+ if len(matching_relationships) == 0:
119
+ results.append((thesis_relationship, False, None))
120
+
121
+ return results
122
+
123
+ def evaluate_and_display_thesis(self, thesis_kg: EntityKnowledgeGraphD):
124
+ results = self.evaluate_thesis(thesis_kg)
125
+
126
+ int_to_str = {1: 'Neutral', 2: 'Increase', 3: 'Decrease'}
127
+
128
+ headers = ["Thesis Claim", "Supported by KG", "Related KG Relationship"]
129
+ table_data = []
130
+ for triplet in results:
131
+ claim_entity = triplet[0].to_entity.entity_name
132
+ claim_movement = int_to_str[triplet[0].relationship.predicted_movement]
133
+ claim = f'{claim_entity} {claim_movement}'
134
+ if triplet[2]:
135
+ evidence = int_to_str[triplet[2].relationship.predicted_movement]
136
+ evidence += f' ({triplet[2].from_entity.entity_name}) '
137
+ else:
138
+ evidence = "No evidence in KG"
139
+ table_data.append([claim, triplet[1], evidence])
140
+ return tabulate(table_data, tablefmt="html", headers=headers)
141
+
142
+
143
+ if __name__ == '__main__':
144
+ # TODO: extract the cases into pytest tests
145
+ kg = EntityKnowledgeGraphD(entity_relationships=[
146
+ EntityRelationshipD(from_entity=EntityD(entity_id='3', entity_name="analyst A"),
147
+ relationship=RelationshipD(
148
+ relationship_id='2',
149
+ start_date='2021-01-01',
150
+ end_date='2024-12-31',
151
+ source_text='',
152
+ predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
153
+ to_entity=EntityD(entity_id='1', entity_name="GDP")),
154
+ EntityRelationshipD(from_entity=EntityD(entity_id='5', entity_name="analyst B"),
155
+ relationship=RelationshipD(
156
+ relationship_id='3',
157
+ start_date='2021-01-01',
158
+ end_date='2021-12-31',
159
+ source_text='',
160
+ predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_DECREASE),
161
+ to_entity=EntityD(entity_id='1', entity_name="GDP")),
162
+ EntityRelationshipD(from_entity=EntityD(entity_id='7', entity_name="analyst C"),
163
+ relationship=RelationshipD(
164
+ relationship_id='4',
165
+ start_date='2021-01-01',
166
+ end_date='2021-12-31',
167
+ source_text='',
168
+ predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
169
+ to_entity=EntityD(entity_id='1', entity_name="GDP")),
170
+ EntityRelationshipD(from_entity=EntityD(entity_id='9', entity_name="analyst D"),
171
+ relationship=RelationshipD(
172
+ relationship_id='5',
173
+ start_date='2021-01-01',
174
+ end_date='2021-12-31',
175
+ source_text='',
176
+ predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
177
+ to_entity=EntityD(entity_id='10', entity_name="USD")),
178
+ EntityRelationshipD( # out of time range for thesis
179
+ from_entity=EntityD(entity_id='9', entity_name="analyst E"),
180
+ relationship=RelationshipD(
181
+ relationship_id='5',
182
+ start_date='2024-01-01',
183
+ end_date='2024-12-31',
184
+ source_text='',
185
+ predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
186
+ to_entity=EntityD(entity_id='10', entity_name="USD")),
187
+ ])
188
+
189
+ thesis_claims = [
190
+ EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
191
+ relationship=RelationshipD(
192
+ relationship_id='1',
193
+ start_date='2021-01-01',
194
+ end_date='2021-12-31',
195
+ source_text='',
196
+ predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
197
+ to_entity=EntityD(entity_id='1', entity_name="Gross Domestic Product")),
198
+ EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
199
+ relationship=RelationshipD(
200
+ relationship_id='1',
201
+ start_date='2021-01-01',
202
+ end_date='2021-12-31',
203
+ source_text='',
204
+ predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
205
+ to_entity=EntityD(entity_id='1', entity_name="US$")),
206
+ EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
207
+ relationship=RelationshipD(
208
+ relationship_id='1',
209
+ start_date='2021-01-01',
210
+ end_date='2021-12-31',
211
+ source_text='',
212
+ predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
213
+ to_entity=EntityD(entity_id='1', entity_name="Yen")),
214
+ ]
215
+ thesis = EntityKnowledgeGraphD(entity_relationships=thesis_claims)
216
+
217
+ eval_engine = EvaluationEngine(kg)
218
+
219
+ print(eval_engine.evaluate_and_display_thesis(thesis))
query_pipeline/thesis_extractor.py ADDED
@@ -0,0 +1,18 @@
1
+ from domain.chunk_d import ChunkD
2
+ from domain.entity_d import EntityKnowledgeGraphD
3
+ from extraction_pipeline.relationship_extractor.openai_relationship_extractor import (
4
+ OpenAIRelationshipExtractor,)
5
+
6
+
7
+ class ThesisExtractor:
8
+ _relationship_extractor: OpenAIRelationshipExtractor
9
+
10
+ def __init__(self):
11
+ self._relationship_extractor = OpenAIRelationshipExtractor()
12
+
13
+ def extract_relationships(self, thesis: ChunkD) -> EntityKnowledgeGraphD:
14
+ # Thesis chunk will have DocumentD for its parent_reference field
15
+ # which will contain "user" as the author and more importantly
16
+ # the time of the thesis query for use by the relationship extractor
17
+ entity_relationships = list(self._relationship_extractor.process_element(thesis))
18
+ return EntityKnowledgeGraphD(entity_relationships=entity_relationships)
requirements.txt ADDED
@@ -0,0 +1,39 @@
1
+ annotated-types==0.7.0
2
+ anyio==4.4.0
3
+ certifi==2024.6.2
4
+ distro==1.9.0
5
+ exceptiongroup==1.2.1
6
+ grpcio==1.64.1
7
+ grpcio-tools==1.64.1
8
+ h11==0.14.0
9
+ httpcore==1.0.5
10
+ httpx==0.27.0
11
+ idna==3.7
12
+ importlib_metadata==7.1.0
13
+ iniconfig==2.0.0
14
+ neo4j==5.21.0
15
+ nodeenv==1.9.1
16
+ numpy==1.26.4
17
+ openai==1.33.0
18
+ packaging==24.0
19
+ pandas==2.2.2
20
+ platformdirs==4.2.2
21
+ pluggy==1.5.0
22
+ protobuf==5.27.1
23
+ pydantic==2.7.3
24
+ pydantic_core==2.18.4
25
+ PyMuPDF==1.24.5
26
+ PyMuPDFb==1.24.3
27
+ pyright==1.1.366
28
+ pytest==8.2.2
29
+ python-dateutil==2.9.0.post0
30
+ pytz==2024.1
31
+ six==1.16.0
32
+ sniffio==1.3.1
33
+ tabulate==0.9.0
34
+ tomli==2.0.1
35
+ tqdm==4.66.4
36
+ typing_extensions==4.12.2
37
+ tzdata==2024.1
38
+ yapf==0.40.2
39
+ zipp==3.19.2
storage/domain_dao.py ADDED
@@ -0,0 +1,128 @@
+ from __future__ import annotations
+ from typing import List, Type, Protocol, TypeVar, Dict, Set
+ import os
+ import json
+ import uuid
+
+ from domain.domain_protocol import DomainProtocol
+
+ DomainT = TypeVar('DomainT', bound=DomainProtocol)
+ MAP_BIN = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), ".bin", "maps")
+
+
+ class DomainDAO(Protocol[DomainT]):
+
+     def insert(self, domain_objs: List[DomainT]):
+         ...
+
+     def read_by_id(self, domain_id: str) -> DomainT:
+         ...
+
+     def read_all(self) -> Set[DomainT]:
+         ...
+
+
+ class InMemDomainDAO(DomainDAO[DomainT]):
+
+     _id_to_domain_obj: Dict[str, DomainT]
+
+     def __init__(self):
+         self._id_to_domain_obj = {}
+
+     def insert(self, domain_objs: List[DomainT]):
+         new_id_to_domain_obj = {domain_obj.id: domain_obj for domain_obj in domain_objs}
+         if len(new_id_to_domain_obj) != len(domain_objs):
+             raise ValueError("Duplicate IDs exist within incoming domain_objs")
+         if duplicate_ids := set(new_id_to_domain_obj.keys()) & set(self._id_to_domain_obj.keys()):
+             raise ValueError(f"Duplicate ids exist in DB: {duplicate_ids}")
+         self._id_to_domain_obj.update(new_id_to_domain_obj)
+
+     def read_by_id(self, domain_id: str) -> DomainT:
+         if domain_obj := self._id_to_domain_obj.get(domain_id):
+             return domain_obj
+         raise ValueError(f"Domain obj with id {domain_id} not found")
+
+     def read_all(self) -> Set[DomainT]:
+         return set(self._id_to_domain_obj.values())
+
+     @classmethod
+     def load_from_file(cls, file_path: str, domain_cls: Type[DomainT]) -> InMemDomainDAO[DomainT]:
+         if not os.path.isfile(file_path):
+             raise ValueError(f"File not found: {file_path}")
+         with open(file_path, 'r') as f:
+             domain_objs = [domain_cls.from_json(line) for line in f]
+         dao = cls()
+         dao.insert(domain_objs)
+         return dao
+
+     def save_to_file(self, file_path: str):
+         os.makedirs(os.path.dirname(file_path), exist_ok=True)
+         domain_jsons = [domain_obj.to_json() for domain_obj in self._id_to_domain_obj.values()]
+         with open(file_path, 'w') as f:
+             f.write('\n'.join(domain_jsons) + '\n')
+
+
+ class CacheDomainDAO(DomainDAO[DomainT]):
+
+     _id_to_domain_obj: Dict[str, DomainT]
+     _save_path: str
+
+     def __init__(self, save_path: str, domain_cls: Type[DomainT]):
+         self._id_to_domain_obj = {}
+         self._save_path = os.path.join(MAP_BIN, save_path)
+         self._load_cache(domain_cls)
+
+     def __enter__(self):
+         return self
+
+     def __call__(self, element: DomainT) -> DomainT:
+         self.insert([element])
+         return element
+
+     def __exit__(self, exc_type, exc_val, exc_tb):
+         self._save_cache()
+
+     def set(self, element: DomainT) -> uuid.UUID:
+         id = uuid.uuid4()
+         self._id_to_domain_obj[str(id)] = element
+         self._save_cache()
+         return id
+
+     def _save_cache(self):
+         os.makedirs(MAP_BIN, exist_ok=True)
+         cache = {}
+         if os.path.isfile(self._save_path):
+             with open(self._save_path, 'r') as f:
+                 cache = json.load(f)
+         domain_json_map = {
+             id: domain_obj.to_json()
+             for id, domain_obj in self._id_to_domain_obj.items()
+         }
+         cache.update(domain_json_map)
+         with open(self._save_path, 'w') as f:
+             json.dump(cache, f, indent=4)
+
+     def _load_cache(self, domain_cls: Type[DomainT]):
+         if not os.path.isfile(self._save_path):
+             return
+         with open(self._save_path, 'r') as f:
+             domain_json_map = json.load(f)
+         for id, domain_json in domain_json_map.items():
+             self._id_to_domain_obj[id] = domain_cls.from_json(domain_json)
+
+     def read_by_id(self, domain_id: str) -> DomainT:
+         if domain_obj := self._id_to_domain_obj.get(domain_id):
+             return domain_obj
+         raise ValueError(f"Domain obj with id {domain_id} not found")
+
+     def read_all(self) -> Set[DomainT]:
+         return set(self._id_to_domain_obj.values())
+
+     def insert(self, domain_objs: List[DomainT]):
+         new_id_to_domain_obj = {domain_obj.id: domain_obj for domain_obj in domain_objs}
+         if len(new_id_to_domain_obj) != len(domain_objs):
+             raise ValueError("Duplicate IDs exist within incoming domain_objs")
+         if duplicate_ids := set(new_id_to_domain_obj.keys()) & set(self._id_to_domain_obj.keys()):
+             raise ValueError(f"Duplicate ids exist in DB: {duplicate_ids}")
+         self._id_to_domain_obj.update(new_id_to_domain_obj)
+         self._save_cache()
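Not part of the commit, but a minimal sketch of how CacheDomainDAO is meant to sit in a pipeline: __call__ inserts its argument and hands it straight back, so the DAO works as a pass-through cache, and __exit__ flushes the map to the JSON file one final time (insert already persists on each call). The cache file name is hypothetical; it lands under .bin/maps.

    from typing import List
    from domain.chunk_d import ChunkD
    from storage.domain_dao import CacheDomainDAO

    def cache_chunks(chunks: List[ChunkD]) -> List[ChunkD]:
        # "chunk_cache.json" is a made-up name for this sketch.
        with CacheDomainDAO("chunk_cache.json", ChunkD) as chunk_cache:
            # Each element is inserted (and persisted) then returned unchanged.
            return [chunk_cache(chunk) for chunk in chunks]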
storage/domain_dao_test.py ADDED
@@ -0,0 +1,71 @@
+ import logging
+ import unittest
+ import uuid
+ import os
+ from storage.domain_dao import InMemDomainDAO
+ from domain.domain_protocol_test import TimestampTestD
+
+
+ class InMemDomainDAOTest(unittest.TestCase):
+
+     def setUp(self):
+         self.in_mem_dao: InMemDomainDAO = InMemDomainDAO[TimestampTestD]()
+         self.timestamp_d: TimestampTestD = TimestampTestD(nanos=1)
+
+     def test_insert_domain_obj(self):
+         self.in_mem_dao.insert([self.timestamp_d])
+         expected_map = {self.timestamp_d.id: self.timestamp_d}
+         self.assertEqual(self.in_mem_dao._id_to_domain_obj, expected_map)
+
+     def test_insert_domain_obj_raise_on_duplicate_id_in_db(self):
+         self.in_mem_dao.insert([self.timestamp_d])
+         with self.assertRaises(ValueError) as context:
+             self.in_mem_dao.insert([self.timestamp_d])
+         self.assertIn("Duplicate ids exist in DB", str(context.exception))
+
+     def test_insert_domain_obj_raise_on_duplicate_id_arguments(self):
+         with self.assertRaises(ValueError) as context:
+             self.in_mem_dao.insert([self.timestamp_d, self.timestamp_d])
+         self.assertIn("Duplicate IDs exist within incoming domain_objs", str(context.exception))
+
+     def test_read_by_id(self):
+         self.in_mem_dao.insert([self.timestamp_d])
+         timestamp_d = self.in_mem_dao.read_by_id(self.timestamp_d.id)
+         self.assertEqual(timestamp_d, self.timestamp_d)
+
+     def test_read_by_id_raise_not_found(self):
+         with self.assertRaises(ValueError):
+             self.in_mem_dao.read_by_id(self.timestamp_d.id)
+
+     def test_read_all(self):
+         timestamp_d_b = TimestampTestD(2)
+         self.in_mem_dao.insert([self.timestamp_d, timestamp_d_b])
+         expected_timestamps = {self.timestamp_d, timestamp_d_b}
+         self.assertEqual(expected_timestamps, self.in_mem_dao.read_all())
+
+     def test_load_from_file(self):
+         file_path = f".bin/{uuid.uuid4()}.jsonl"
+         with open(file_path, 'w') as f:
+             f.write(self.timestamp_d.to_json() + '\n')
+         dao = InMemDomainDAO[TimestampTestD].load_from_file(file_path, TimestampTestD)
+         os.remove(file_path)
+         self.assertEqual({self.timestamp_d}, dao.read_all())
+
+     def test_load_from_file_fail_not_found(self):
+         with self.assertRaises(ValueError):
+             _ = InMemDomainDAO[TimestampTestD].load_from_file("file_path", TimestampTestD)
+
+     def test_save_to_file(self):
+         file_path = f".bin/{uuid.uuid4()}.jsonl"
+         self.in_mem_dao.insert([self.timestamp_d])
+         self.in_mem_dao.save_to_file(file_path)
+         created_dao = InMemDomainDAO[TimestampTestD].load_from_file(file_path, TimestampTestD)
+         os.remove(file_path)
+         self.assertEqual(self.in_mem_dao.read_all(), created_dao.read_all())
+
+
+ # TODO: Add test for CacheDomainDAO
+
+ if __name__ == '__main__':
+     logging.basicConfig(level=logging.DEBUG)
+     unittest.main()
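One possible shape for the TODO above, offered only as a sketch: it assumes TimestampTestD instances compare equal after a to_json/from_json round trip (the same property the load_from_file test relies on), and it cleans up the cache file it creates under MAP_BIN.

    import os
    import unittest
    import uuid
    from storage.domain_dao import CacheDomainDAO, MAP_BIN
    from domain.domain_protocol_test import TimestampTestD

    class CacheDomainDAOTest(unittest.TestCase):

        def test_insert_then_reload_from_cache_file(self):
            cache_name = f"{uuid.uuid4()}.json"   # throwaway cache file for this test
            timestamp_d = TimestampTestD(nanos=1)
            CacheDomainDAO(cache_name, TimestampTestD).insert([timestamp_d])
            reloaded_dao = CacheDomainDAO(cache_name, TimestampTestD)
            os.remove(os.path.join(MAP_BIN, cache_name))
            self.assertEqual({timestamp_d}, reloaded_dao.read_all())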
storage/neo4j_dao.py ADDED
@@ -0,0 +1,102 @@
+ import logging
+ import os
+ from typing import Optional
+
+ import neo4j
+
+ from domain.entity_d import (
+     EntityD,
+     EntityKnowledgeGraphD,
+     EntityRelationshipD,
+     RelationshipD,
+ )
+
+
+ class Neo4jError(Exception):
+     ...
+
+
+ class Neo4jDomainDAO:
+     """To be used as a context manager so the connection is closed after use."""
+
+     def __enter__(self):
+         uri = os.environ.get("NEO4J_URI", "")
+         user = os.environ.get("NEO4J_USER", "")
+         password = os.environ.get("NEO4J_PASSWORD", "")
+
+         if not uri:
+             raise ValueError("NEO4J_URI environment variable not set")
+         if not user:
+             raise ValueError("NEO4J_USER environment variable not set")
+         if not password:
+             raise ValueError("NEO4J_PASSWORD environment variable not set")
+
+         try:
+             self.driver = neo4j.GraphDatabase.driver(uri, auth=(user, password))
+             self.driver.verify_connectivity()
+         except Exception as e:
+             logging.error(f"Failed to connect to Neo4j: {e}")
+             raise Neo4jError("Failed to connect to Neo4j")
+
+         return self
+
+     def insert(self, knowledge_graph: EntityKnowledgeGraphD, pdf_file: str = ""):
+         for entity_relationship in knowledge_graph.entity_relationships:
+             create_cmds = entity_relationship.neo4j_create_cmds
+             create_cmds_args = entity_relationship.neo4j_create_args
+
+             for create_cmd, args in zip(create_cmds, create_cmds_args):
+                 args['pdf_file'] = pdf_file
+                 try:
+                     self.driver.execute_query(
+                         create_cmd,  # type: ignore
+                         parameters_=args,  # type: ignore
+                         database_='neo4j')  # type: ignore
+                 except Exception as e:
+                     logging.warning(
+                         f"Failed to insert entity relationship: {entity_relationship} due to {e}")
+
+     def query(self, query, query_args):
+         return self.driver.execute_query(query, parameters_=query_args,
+                                          database_='neo4j')  # type: ignore
+
+     def get_knowledge_graph(self) -> Optional[EntityKnowledgeGraphD]:
+         records = []  # list[dict[str, Neo4jDict]]
+         try:
+             records, _, _ = self.driver.execute_query("MATCH (from:Entity) -[r:Relationship]-> (to:Entity) RETURN from, properties(r), to", database_='neo4j')  # type: ignore
+         except Exception as e:
+             logging.exception(e)
+             return None
+
+         entity_relationships = []
+         for record in records:
+             er_dict = record.data()
+
+             from_args = er_dict['from']
+             from_entity = EntityD(entity_id='', entity_name=from_args['name'])
+
+             to_args = er_dict['to']
+             to_entity = EntityD(entity_id='', entity_name=to_args['name'])
+
+             relationship_args = er_dict['properties(r)']
+             relationship = RelationshipD(relationship_id='',
+                                          start_date=relationship_args['start_date'],
+                                          end_date=relationship_args['end_date'],
+                                          source_text=relationship_args['source_text'],
+                                          predicted_movement=RelationshipD.from_string(
+                                              relationship_args['predicted_movement']))
+
+             entity_relationships.append(
+                 EntityRelationshipD(from_entity=from_entity,
+                                     relationship=relationship,
+                                     to_entity=to_entity))
+
+         return EntityKnowledgeGraphD(entity_relationships=entity_relationships)
+
+     def __exit__(self, exception_type, exception_value, traceback):
+         if traceback:
+             logging.error("Neo4jDomainDAO error: %s | %s | %s",
+                           exception_type,
+                           exception_value,
+                           traceback)
+         self.driver.close()
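A brief usage sketch (not in the commit): the DAO only connects inside __enter__, so it has to be used as a context manager, and the NEO4J_* environment variables are assumed to be set. The one-edge graph below is hypothetical test data built with the same constructors as the evaluation demo; the PredictedMovement import path is an assumption.

    from domain.entity_d import (EntityD, EntityKnowledgeGraphD, EntityRelationshipD, RelationshipD)
    from proto.entity_pb2 import PredictedMovement  # import path assumed
    from storage.neo4j_dao import Neo4jDomainDAO

    tiny_kg = EntityKnowledgeGraphD(entity_relationships=[
        EntityRelationshipD(from_entity=EntityD(entity_id='1', entity_name="analyst"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
                            to_entity=EntityD(entity_id='2', entity_name="USD")),
    ])

    with Neo4jDomainDAO() as dao:            # connects in __enter__, closes the driver in __exit__
        dao.insert(tiny_kg, pdf_file="")     # per-edge failures are logged as warnings, not raised
        round_trip = dao.get_knowledge_graph()  # None if the MATCH query itself fails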
utils/dates.py ADDED
@@ -0,0 +1,7 @@
+ import datetime
+
+ DATE_FMT = "%Y-%m-%d"
+
+
+ def parse_date(date_str: str) -> datetime.datetime:
+     return datetime.datetime.strptime(date_str, DATE_FMT)
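For completeness, a quick illustration of the helper's behavior:

    from utils.dates import parse_date

    parse_date("2021-01-01")   # -> datetime.datetime(2021, 1, 1, 0, 0)
    parse_date("01/01/2021")   # raises ValueError: does not match the %Y-%m-%d format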