Spaces
Gateston Johns committed
Commit · 9041389
Parent(s): fc4bc08
first real commit
Browse files
- app.py +61 -0
- domain/chunk_d.py +89 -0
- domain/chunk_d_test.py +71 -0
- domain/domain_protocol.py +89 -0
- domain/domain_protocol_test.py +69 -0
- domain/entity_d.py +169 -0
- domain/entity_d_test.py +96 -0
- extraction_pipeline/base_stage.py +63 -0
- extraction_pipeline/document_metadata_extractor/document_metadata_extractor.py +9 -0
- extraction_pipeline/document_metadata_extractor/openai_document_metadata_extractor.py +73 -0
- extraction_pipeline/document_metadata_extractor/openai_document_metadata_extractor_test.py +74 -0
- extraction_pipeline/document_metadata_extractor/prompts.py +22 -0
- extraction_pipeline/pdf_process_stage.py +131 -0
- extraction_pipeline/pdf_process_stage_test.py +103 -0
- extraction_pipeline/pdf_to_knowledge_graph_transform.py +153 -0
- extraction_pipeline/relationship_extractor/entity_relationship_extractor.py +10 -0
- extraction_pipeline/relationship_extractor/openai_relationship_extractor.py +108 -0
- extraction_pipeline/relationship_extractor/openai_relationship_extractor_test.py +87 -0
- extraction_pipeline/relationship_extractor/prompts.py +481 -0
- llm_handler/llm_interface.py +21 -0
- llm_handler/mock_llm_handler.py +34 -0
- llm_handler/openai_handler.py +64 -0
- llm_handler/openai_handler_test.py +69 -0
- proto/chunk_pb2.py +30 -0
- proto/chunk_pb2.pyi +43 -0
- proto/entity_pb2.py +34 -0
- proto/entity_pb2.pyi +56 -0
- proto/pdf_document_pb2.py +26 -0
- proto/pdf_document_pb2.pyi +11 -0
- query_pipeline/evaluation_engine.py +219 -0
- query_pipeline/thesis_extractor.py +18 -0
- requirements.txt +39 -0
- storage/domain_dao.py +128 -0
- storage/domain_dao_test.py +71 -0
- storage/neo4j_dao.py +102 -0
- utils/dates.py +7 -0
app.py
ADDED
@@ -0,0 +1,61 @@
import gradio as gr
from typing import Dict, Union, List
import json
import uuid
from storage.neo4j_dao import Neo4jDomainDAO
from domain.chunk_d import DocumentD, ChunkD
from query_pipeline.thesis_extractor import ThesisExtractor
from query_pipeline.evaluation_engine import EvaluationEngine
from proto.chunk_pb2 import Chunk, ChunkType
from datetime import datetime
import os

os.environ['OPENAI_API_KEY'] = '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'

os.environ['NEO4J_USER'] = "neo4j"
os.environ['NEO4J_PASSWORD'] = "dOIZwzF_GImgwjF-smChe60QGQgicq8ER8RUlZvornU"
os.environ['NEO4J_URI'] = "neo4j+s://2317ae21.databases.neo4j.io"

def thesis_evaluation(thesis: str) -> str:
    thesis_document_d = DocumentD(file_path="",
                                  authors="user",
                                  publish_date="2024-06-18")

    thesis_chunk_d = ChunkD(chunk_text=thesis,
                            chunk_type=ChunkType.CHUNK_TYPE_PAGE,
                            chunk_index=0,
                            parent_reference=thesis_document_d)

    thesis_entity_kg = ThesisExtractor().extract_relationships(thesis_chunk_d)

    with Neo4jDomainDAO() as neo4j_dao:
        full_db_kg = neo4j_dao.get_knowledge_graph()

    return EvaluationEngine(full_db_kg).evaluate_and_display_thesis(thesis_entity_kg)


theme = gr.themes.Soft(
    primary_hue="cyan",
    secondary_hue="green",
    neutral_hue="slate",
)

with gr.Blocks(theme=theme) as demo:
    gr.HTML("""
        <div style="background-color: rgb(171, 188, 251);">
            <div style="background-color: rgb(151, 168, 231); padding-top: 2px; padding-bottom: 7px; text-align: center;">
                <h1>AthenaAIC MetisLLM Thesis Evaluation Tool</h1>
            </div>
        </div>""")
    with gr.Row():
        inp = gr.Textbox(placeholder="What is your thesis?", scale=3, lines=3, show_label=False)
        submit = gr.Button(scale=1, value="Evaluate")
    with gr.Row():
        out = gr.HTML()

    submit.click(thesis_evaluation, inputs=inp, outputs=out)

demo.launch(
    auth=('athena-admin', 'athena'),
    share=True
)
domain/chunk_d.py
ADDED
@@ -0,0 +1,89 @@
from __future__ import annotations
import dataclasses
import uuid
from typing import Union
import hashlib

import proto.chunk_pb2 as chunk_pb2
from domain.domain_protocol import DomainProtocol


@dataclasses.dataclass(frozen=True)
class DocumentD(DomainProtocol[chunk_pb2.Document]):
    file_path: str
    authors: str
    publish_date: str

    @property
    def id(self) -> str:
        return hashlib.sha256(self.to_proto().SerializeToString()).hexdigest()

    @classmethod
    def _from_proto(cls, proto: chunk_pb2.Document) -> DocumentD:
        return cls(file_path=proto.file_path,
                   authors=proto.authors,
                   publish_date=proto.publish_date)

    def to_proto(self) -> chunk_pb2.Document:
        return chunk_pb2.Document(file_path=self.file_path,
                                  authors=self.authors,
                                  publish_date=self.publish_date)


@dataclasses.dataclass(frozen=True)
class ChunkD(DomainProtocol[chunk_pb2.Chunk]):

    @property
    def id(self) -> str:
        return str(self.chunk_id)

    chunk_text: str
    chunk_type: chunk_pb2.ChunkType
    chunk_index: int
    parent_reference: Union[uuid.UUID, DocumentD]
    chunk_id: uuid.UUID = dataclasses.field(default_factory=uuid.uuid4)

    def __post_init__(self):
        if self.chunk_type == chunk_pb2.ChunkType.CHUNK_TYPE_PAGE:
            if not isinstance(self.parent_reference, DocumentD):
                raise ValueError(
                    f"Chunk (id: {self.chunk_id}) with type {self.chunk_type} must have a DocumentD parent_reference."
                )
        elif not isinstance(self.parent_reference, uuid.UUID):
            raise ValueError(
                f"Chunk (id: {self.chunk_id}) with type {self.chunk_type} must have a uuid.UUID parent_reference."
            )

    @classmethod
    def _from_proto(cls, proto: chunk_pb2.Chunk) -> ChunkD:
        if proto.HasField('parent_chunk_id'):
            return cls(chunk_id=uuid.UUID(proto.chunk_id),
                       parent_reference=uuid.UUID(proto.parent_chunk_id),
                       chunk_text=proto.chunk_text,
                       chunk_type=proto.chunk_type,
                       chunk_index=proto.chunk_index)
        elif proto.HasField('document'):
            return cls(chunk_id=uuid.UUID(proto.chunk_id),
                       parent_reference=DocumentD._from_proto(proto.document),
                       chunk_text=proto.chunk_text,
                       chunk_type=proto.chunk_type,
                       chunk_index=proto.chunk_index)
        else:
            raise ValueError(
                f"Chunk proto (id: {proto.chunk_id}) has no 'parent' or 'document' field.")

    def to_proto(self) -> chunk_pb2.Chunk:
        chunk_proto = chunk_pb2.Chunk()
        chunk_proto.chunk_id = str(self.chunk_id)
        chunk_proto.chunk_text = self.chunk_text
        chunk_proto.chunk_type = self.chunk_type
        chunk_proto.chunk_index = self.chunk_index
        if isinstance(self.parent_reference, uuid.UUID):
            chunk_proto.parent_chunk_id = str(self.parent_reference)
        elif isinstance(self.parent_reference, DocumentD):
            chunk_proto.document.CopyFrom(self.parent_reference.to_proto())
        else:
            raise ValueError(
                f"Chunk (id: {self.chunk_id}) parent_reference is of unknown type: {type(self.parent_reference)}"
            )
        return chunk_proto
domain/chunk_d_test.py
ADDED
@@ -0,0 +1,71 @@
import unittest
import logging
import uuid

from domain.chunk_d import DocumentD, ChunkD
import proto.chunk_pb2 as chunk_pb2


class DocumentDTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.document_d = DocumentD(file_path='test_file',
                                   authors='BofA John Doe Herb Johnson Taylor Mason',
                                   publish_date='2021-01-01')

    def test_proto_roundtrip(self):
        document_proto = self.document_d.to_proto()
        roundtrip_document_d = DocumentD.from_proto(document_proto)
        self.assertEqual(self.document_d, roundtrip_document_d)


class ChunkDTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.parent_reference_document_d = DocumentD(
            file_path='test_file',
            authors='BofA John Doe Herb Johnson Taylor Mason',
            publish_date='2021-01-01')

    def test_validate(self):
        with self.assertRaises(ValueError,
                               msg="CHUNK_TYPE_PAGE must have a DocumentD parent_reference"):
            ChunkD(chunk_text='test chunk text',
                   chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
                   chunk_index=9,
                   parent_reference=uuid.uuid4(),
                   chunk_id=uuid.uuid4())
        with self.assertRaises(ValueError,
                               msg="CHUNK_TYPE_SENTENCE must have a uuid.UUID parent_reference"):
            ChunkD(chunk_text='test chunk text',
                   chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
                   chunk_index=0,
                   parent_reference=self.parent_reference_document_d,
                   chunk_id=uuid.uuid4())
        ChunkD(chunk_text='test chunk text',
               chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
               chunk_index=9,
               parent_reference=self.parent_reference_document_d,
               chunk_id=uuid.uuid4())
        ChunkD(chunk_text='test chunk text',
               chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
               chunk_index=0,
               parent_reference=uuid.uuid4(),
               chunk_id=uuid.uuid4())

    def test_proto_roundtrip(self):
        test_chunk_d = ChunkD(chunk_id=uuid.uuid4(),
                              parent_reference=self.parent_reference_document_d,
                              chunk_text='test chunk text',
                              chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
                              chunk_index=9)
        chunk_proto = test_chunk_d.to_proto()
        roundtrip_chunk_d = ChunkD.from_proto(chunk_proto)
        self.assertEqual(test_chunk_d, roundtrip_chunk_d)


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    unittest.main()
domain/domain_protocol.py
ADDED
@@ -0,0 +1,89 @@
from __future__ import annotations

from typing import Optional, Protocol, Tuple, Type, TypeVar, get_args

from google.protobuf import json_format, message

MessageType = TypeVar("MessageType", bound=message.Message)
DomainProtocolType = TypeVar("DomainProtocolType", bound='DomainProtocol')


class ProtoDeserializationError(Exception):
    ...


class DomainProtocol(Protocol[MessageType]):

    @property
    def id(self) -> str:
        ...

    @classmethod
    def _from_proto(cls: Type[DomainProtocolType], proto: MessageType) -> DomainProtocolType:
        ...

    def to_proto(self) -> MessageType:
        ...

    @classmethod
    def message_cls(cls: Type[DomainProtocolType]) -> Type[MessageType]:
        orig_bases: Optional[Tuple[Type[MessageType], ...]] = getattr(cls, "__orig_bases__", None)
        if not orig_bases:
            raise ValueError(f"Class {cls} does not have __orig_bases__")
        if len(orig_bases) != 1:
            raise ValueError(f"Class {cls} has unexpected number of bases: {orig_bases}")
        return get_args(orig_bases[0])[0]

    @classmethod
    def from_proto(cls: Type[DomainProtocolType],
                   proto: MessageType,
                   allow_empty: bool = False) -> DomainProtocolType:
        try:
            if not allow_empty:
                cls.validate_proto_not_empty(proto)
            return cls._from_proto(proto)
        except Exception as e:
            error_str = f"Failed to convert {cls} - {e}"
            raise ProtoDeserializationError(error_str) from e

    @classmethod
    def from_json(cls: Type[DomainProtocolType], json_str: str) -> DomainProtocolType:
        try:
            proto_cls = cls.message_cls()
            proto = proto_cls()
            json_format.Parse(json_str, proto)
            return cls.from_proto(proto)
        except json_format.ParseError as e:
            error_str = f"{cls} failed to parse json string: {json_str} - {e}"
            raise ProtoDeserializationError(error_str) from e

    def to_json(self) -> str:
        return json_format.MessageToJson(self.to_proto()).replace("\n", " ")

    @classmethod
    def validate_proto_not_empty(cls, proto: message.Message):
        if cls.is_empty(proto):
            raise ValueError("Proto is empty")

    @classmethod
    def is_empty(cls, proto: message.Message) -> bool:
        descriptor = getattr(proto, 'DESCRIPTOR', None)
        fields = list(descriptor.fields) if descriptor else []
        while fields:
            field = fields.pop()

            if field.label == field.LABEL_REPEATED:
                eval_func = lambda x: x == field.default_value
                if field.type == field.TYPE_MESSAGE:
                    eval_func = cls.is_empty
                if not all([eval_func(item) for item in getattr(proto, field.name)]):
                    return False

            elif field.type == field.TYPE_MESSAGE:
                if not cls.is_empty(getattr(proto, field.name)):
                    return False
            else:
                field_value = getattr(proto, field.name)
                if field_value != field.default_value:
                    return False
        return True
domain/domain_protocol_test.py
ADDED
@@ -0,0 +1,69 @@
from __future__ import annotations
import logging
import unittest
from google.protobuf import timestamp_pb2
import dataclasses

from domain.domain_protocol import DomainProtocol, ProtoDeserializationError


@dataclasses.dataclass(frozen=True)
class TimestampTestD(DomainProtocol[timestamp_pb2.Timestamp]):

    nanos: int

    @property
    def id(self) -> str:
        return str(self.nanos)

    @classmethod
    def _from_proto(cls, proto: timestamp_pb2.Timestamp) -> TimestampTestD:
        return cls(nanos=proto.nanos)

    def to_proto(self) -> timestamp_pb2.Timestamp:
        return timestamp_pb2.Timestamp(nanos=self.nanos)


class DomainProtocolTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls) -> None:
        cls.timestamp_d = TimestampTestD(nanos=1)
        cls.timestamp_proto = timestamp_pb2.Timestamp(nanos=1)

    def test_proto_roundtrip(self):
        proto = self.timestamp_d.to_proto()
        domain_from_proto = TimestampTestD.from_proto(proto)
        self.assertEqual(self.timestamp_d, domain_from_proto)

    def test_json_roundtrip(self):
        json_str = self.timestamp_d.to_json()
        domain_from_json = TimestampTestD.from_json(json_str)
        self.assertEqual(self.timestamp_d, domain_from_json)

    def test_from_proto_empty_fail(self):
        empty_proto = timestamp_pb2.Timestamp()
        with self.assertRaises(ProtoDeserializationError):
            TimestampTestD.from_proto(empty_proto)

    def test_from_proto_empty_allowed_flag(self):
        empty_proto = timestamp_pb2.Timestamp()
        domain_from_proto = TimestampTestD.from_proto(empty_proto, allow_empty=True)
        self.assertEqual(TimestampTestD(nanos=0), domain_from_proto)

    def test_validate_proto_not_empty(self):
        empty_proto = timestamp_pb2.Timestamp()
        with self.assertRaises(ValueError):
            TimestampTestD.validate_proto_not_empty(empty_proto)

    def test_is_empty(self):
        empty_proto = timestamp_pb2.Timestamp()
        self.assertTrue(TimestampTestD.is_empty(empty_proto))

    def test_message_cls(self):
        self.assertEqual(timestamp_pb2.Timestamp, TimestampTestD.message_cls())


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    unittest.main()
domain/entity_d.py
ADDED
@@ -0,0 +1,169 @@
from __future__ import annotations

import dataclasses
import hashlib
import logging
from typing import TypeAlias, Union

import proto.entity_pb2 as entity_pb2

from domain.domain_protocol import DomainProtocol
from utils.dates import parse_date

Neo4jDict: TypeAlias = dict[str, Union[str, int, bool, list[str], list[int], list[bool]]]


@dataclasses.dataclass(frozen=True)
class EntityD(DomainProtocol[entity_pb2.Entity]):
    entity_id: str
    entity_name: str

    @property
    def id(self) -> str:
        return self.entity_id

    @classmethod
    def _from_proto(cls, proto: entity_pb2.Entity) -> EntityD:
        return EntityD(entity_id=proto.entity_id, entity_name=proto.entity_name)

    def to_proto(self) -> entity_pb2.Entity:
        return entity_pb2.Entity(entity_id=self.entity_id, entity_name=self.entity_name)

    @property
    def neo4j_create_cmd(self):
        # TODO store entity_id?
        return "MERGE (e:Entity {name: $name}) ON CREATE SET e.pdf_file = $pdf_file"

    @property
    def neo4j_create_args(self) -> Neo4jDict:
        return {
            "name": self.entity_name,
        }


@dataclasses.dataclass(frozen=True)
class RelationshipD(DomainProtocol[entity_pb2.Relationship]):
    relationship_id: str
    start_date: str
    end_date: str
    source_text: str
    predicted_movement: entity_pb2.PredictedMovement

    @property
    def id(self) -> str:
        return self.relationship_id

    def __post_init__(self):
        if self.start_date and self.end_date:
            start = parse_date(self.start_date)
            end = parse_date(self.end_date)

            if end < start:
                logging.warning("end_date %s is before start_date %s",
                                self.end_date,
                                self.start_date)
                # raise ValueError(f"end_date {self.end_date} is before start_date {self.start_date}")

    @classmethod
    def _from_proto(cls, proto: entity_pb2.Relationship) -> RelationshipD:
        return RelationshipD(relationship_id=proto.relationship_id,
                             start_date=proto.start_date,
                             end_date=proto.end_date,
                             source_text=proto.source_text,
                             predicted_movement=proto.predicted_movement)

    def to_proto(self) -> entity_pb2.Relationship:
        return entity_pb2.Relationship(relationship_id=self.relationship_id,
                                       start_date=self.start_date,
                                       end_date=self.end_date,
                                       source_text=self.source_text,
                                       predicted_movement=self.predicted_movement)

    @classmethod
    def from_string(cls, relationship: str) -> entity_pb2.PredictedMovement:
        if relationship == "PREDICTED_MOVEMENT_NEUTRAL":
            return entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL
        elif relationship == "PREDICTED_MOVEMENT_INCREASE":
            return entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_INCREASE
        elif relationship == "PREDICTED_MOVEMENT_DECREASE":
            return entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_DECREASE
        else:
            return entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_UNSPECIFIED

    @property
    def neo4j_create_cmd(self):
        return """MATCH (from:Entity {name: $from_name})
        MATCH (to:Entity {name: $to_name})
        MERGE (from) -[r:Relationship {start_date: $start_date, end_date: $end_date, predicted_movement: $predicted_movement}]-> (to) ON CREATE SET r.source_text = $source_text, r.pdf_file = $pdf_file"""

    @property
    def neo4j_create_args(self) -> Neo4jDict:
        return {
            "start_date": self.start_date,
            "end_date": self.end_date,
            "predicted_movement": entity_pb2.PredictedMovement.Name(self.predicted_movement),
            "source_text": self.source_text,
        }


@dataclasses.dataclass(frozen=True)
class EntityRelationshipD(DomainProtocol[entity_pb2.EntityRelationship]):
    from_entity: EntityD
    relationship: RelationshipD
    to_entity: EntityD

    @property
    def id(self) -> str:
        return hashlib.sha256(self.to_proto().SerializeToString()).hexdigest()

    @classmethod
    def _from_proto(cls, proto: entity_pb2.EntityRelationship) -> EntityRelationshipD:
        return EntityRelationshipD(from_entity=EntityD._from_proto(proto.from_entity),
                                   relationship=RelationshipD._from_proto(proto.relationship),
                                   to_entity=EntityD._from_proto(proto.to_entity))

    def to_proto(self) -> entity_pb2.EntityRelationship:
        return entity_pb2.EntityRelationship(from_entity=self.from_entity.to_proto(),
                                             relationship=self.relationship.to_proto(),
                                             to_entity=self.to_entity.to_proto())

    @property
    def neo4j_create_cmds(self):
        return [
            self.from_entity.neo4j_create_cmd,
            self.to_entity.neo4j_create_cmd,
            self.relationship.neo4j_create_cmd
        ]

    @property
    def neo4j_create_args(self) -> list[Neo4jDict]:
        relationship_args = {
            **self.relationship.neo4j_create_args,
            'from_name': self.from_entity.entity_name,
            'to_name': self.to_entity.entity_name,
        }

        return [
            self.from_entity.neo4j_create_args, self.to_entity.neo4j_create_args, relationship_args
        ]


@dataclasses.dataclass(frozen=True)
class EntityKnowledgeGraphD(DomainProtocol[entity_pb2.EntityKnowledgeGraph]):
    entity_relationships: list[EntityRelationshipD]

    @property
    def id(self) -> str:
        return hashlib.sha256(self.to_proto().SerializeToString()).hexdigest()

    @classmethod
    def _from_proto(cls, proto: entity_pb2.EntityKnowledgeGraph) -> EntityKnowledgeGraphD:
        return EntityKnowledgeGraphD(entity_relationships=[
            EntityRelationshipD._from_proto(entity_relationship)
            for entity_relationship in proto.entity_relationships
        ])

    def to_proto(self) -> entity_pb2.EntityKnowledgeGraph:
        return entity_pb2.EntityKnowledgeGraph(entity_relationships=[
            entity_relationship.to_proto() for entity_relationship in self.entity_relationships
        ])
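Editor's note: a minimal sketch, not part of this commit, of how the Cypher commands and argument dictionaries above are meant to pair up when an EntityRelationshipD is written to the graph. The open `session` object and the `pdf_file` value are assumptions here; the actual write logic lives in storage/neo4j_dao.py.

# Hypothetical usage sketch: neo4j_create_cmds and neo4j_create_args are parallel lists,
# so each MERGE command runs with its matching parameter dict. `entity_relationship`,
# `session`, and the pdf_file path are placeholders.
for cmd, args in zip(entity_relationship.neo4j_create_cmds,
                     entity_relationship.neo4j_create_args):
    # pdf_file is supplied by the caller since it is not part of the domain object.
    session.run(cmd, {**args, "pdf_file": "reports/example.pdf"})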
domain/entity_d_test.py
ADDED
@@ -0,0 +1,96 @@
import logging
import unittest

import proto.entity_pb2 as entity_pb2

import domain.entity_d as entity_d


class EntityDTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.entity_d = entity_d.EntityD(entity_id='demoid', entity_name='demo entity')

    def test_proto_roundtrip(self):
        proto = self.entity_d.to_proto()
        domain = entity_d.EntityD.from_proto(proto)
        self.assertEqual(self.entity_d.to_proto(), domain.to_proto())


class RelationshipDTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.relationship_d = entity_d.RelationshipD(
            relationship_id="demoid",
            start_date="2024-06-01",
            end_date="2024-06-02",
            source_text="source text",
            predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)

    def test_proto_roundtrip(self):
        proto = self.relationship_d.to_proto()
        domain = entity_d.RelationshipD.from_proto(proto)
        self.assertEqual(self.relationship_d.to_proto(), domain.to_proto())

    def test_end_date_after_start_date(self):
        with self.assertRaises(ValueError):
            _ = entity_d.RelationshipD(
                relationship_id="demoid",
                start_date="2024-06-01",
                end_date="2024-05-02",
                source_text="source text",
                predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)


class EntityRelationshipDTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.from_entity_d = entity_d.EntityD(entity_id='from_id', entity_name='from entity')
        cls.relationship_d = entity_d.RelationshipD(
            relationship_id='relationship_id',
            start_date='2024-06-01',
            end_date='2024-06-02',
            source_text='source text',
            predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)
        cls.to_entity_d = entity_d.EntityD(entity_id='to_id', entity_name='to entity')
        cls.entity_relationship_d = entity_d.EntityRelationshipD(from_entity=cls.from_entity_d,
                                                                 relationship=cls.relationship_d,
                                                                 to_entity=cls.to_entity_d)

    def test_proto_roundtrip(self):
        proto = self.entity_relationship_d.to_proto()
        domain = entity_d.EntityRelationshipD.from_proto(proto)
        self.assertEqual(self.entity_relationship_d.to_proto(), domain.to_proto())


class EntityKnowledgeGraphDTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.from_entity_d = entity_d.EntityD(entity_id='from_id', entity_name='from entity')
        cls.relationship_d = entity_d.RelationshipD(
            relationship_id='relationship_id',
            start_date='2024-06-01',
            end_date='2024-06-02',
            source_text='source text',
            predicted_movement=entity_pb2.PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL)
        cls.to_entity_d = entity_d.EntityD(entity_id='to_id', entity_name='to entity')
        cls.entity_relationship_d = entity_d.EntityRelationshipD(from_entity=cls.from_entity_d,
                                                                 relationship=cls.relationship_d,
                                                                 to_entity=cls.to_entity_d)

        cls.entity_knowledge_graph_d = entity_d.EntityKnowledgeGraphD(
            entity_relationships=[cls.entity_relationship_d for _ in range(2)])

    def test_proto_roundtrip(self):
        proto = self.entity_knowledge_graph_d.to_proto()
        domain = entity_d.EntityKnowledgeGraphD.from_proto(proto)
        self.assertEqual(self.entity_knowledge_graph_d.to_proto(), domain.to_proto())


if __name__ == '__main__':
    logging.getLogger().setLevel(logging.INFO)
    unittest.main()
extraction_pipeline/base_stage.py
ADDED
@@ -0,0 +1,63 @@
from __future__ import annotations

import dataclasses
import logging
from typing import Iterable, Protocol, Set, TypeVar

from domain.domain_protocol import DomainProtocol


class StageError(Exception):
    pass


StageInputType = TypeVar("StageInputType", bound=DomainProtocol, contravariant=True)
StageOutputType = TypeVar("StageOutputType", bound=DomainProtocol, covariant=True)


@dataclasses.dataclass(frozen=True)
class BaseStage(Protocol[StageInputType, StageOutputType]):

    @property
    def name(self) -> str:
        return self.__class__.__name__

    def _process_element(self, element: StageInputType) -> Iterable[StageOutputType]:
        ...

    def process_element(self, element: StageInputType) -> Iterable[StageOutputType]:
        try:
            logging.info(f"{self.name} processing element {element}")
            yield from self._process_element(element)
        except Exception as e:
            logging.error(e)
            raise StageError(f"Error processing element {element}") from e


class TransformError(Exception):
    pass


TransformInputType = TypeVar("TransformInputType", bound=DomainProtocol, contravariant=True)
TransformOutputType = TypeVar("TransformOutputType", bound=DomainProtocol, covariant=True)


@dataclasses.dataclass(frozen=True)
class BaseTransform(Protocol[TransformInputType, TransformOutputType]):

    @property
    def name(self) -> str:
        return self.__class__.__name__

    def _process_collection(
            self, collection: Iterable[TransformInputType]) -> Iterable[TransformOutputType]:
        ...

    def process_collection(
            self, collection: Iterable[TransformInputType]) -> Iterable[TransformOutputType]:
        try:
            logging.info(f"{self.name} processing collection")
            yield from self._process_collection(collection)
        except Exception as e:
            logging.error(e)
            raise TransformError(f"Error processing collection") from e
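Editor's note: a minimal sketch, not part of this commit, showing how a stage built on BaseStage is intended to look. Subclasses implement only _process_element; callers go through process_element to get the logging and StageError wrapping above. UppercaseChunkStage is an invented name for illustration.

# Hypothetical sketch of a BaseStage subclass, assuming the ChunkD domain type in this commit.
import dataclasses
from typing import Iterable

from domain.chunk_d import ChunkD
from extraction_pipeline.base_stage import BaseStage


class UppercaseChunkStage(BaseStage[ChunkD, ChunkD]):

    def _process_element(self, element: ChunkD) -> Iterable[ChunkD]:
        # Emit a copy of the chunk with its text upper-cased; chunk_id, chunk_type and
        # parent_reference are preserved by dataclasses.replace.
        yield dataclasses.replace(element, chunk_text=element.chunk_text.upper())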
extraction_pipeline/document_metadata_extractor/document_metadata_extractor.py
ADDED
@@ -0,0 +1,9 @@
import dataclasses

from extraction_pipeline.base_stage import BaseStage
from domain.chunk_d import DocumentD


@dataclasses.dataclass(frozen=True)
class DocumentMetadataExtractor(BaseStage[DocumentD, DocumentD]):
    ...
extraction_pipeline/document_metadata_extractor/openai_document_metadata_extractor.py
ADDED
@@ -0,0 +1,73 @@
import dataclasses
import json
from typing import ClassVar, Dict, Iterable, Iterator, List, Optional, Tuple, Union

import pymupdf
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

from domain.chunk_d import DocumentD
from extraction_pipeline.document_metadata_extractor.document_metadata_extractor import (
    DocumentMetadataExtractor,)
from extraction_pipeline.document_metadata_extractor.prompts import (
    DOCUMENT_METADATA_PROMPT,)
from llm_handler.llm_interface import LLMInterface
from llm_handler.openai_handler import ChatModelVersion, OpenAIHandler
from utils.dates import parse_date


class CreationDateError(Exception):
    pass


class AuthorsError(Exception):
    pass


class OpenAIDocumentMetadataExtractor(DocumentMetadataExtractor):

    _handler: LLMInterface
    _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
    _AUTHORS_KEY: ClassVar[str] = "authors"
    _PUBLISH_DATE_KEY: ClassVar[str] = "publish_date"
    _TEMPARATURE: ClassVar[float] = 0.2

    def __init__(self,
                 openai_handler: Optional[LLMInterface] = None,
                 model_version: Optional[ChatModelVersion] = None):
        self._handler = openai_handler or OpenAIHandler()
        self._model_version = model_version or self._MODEL_VERSION

    def _validate_text(self, completion_text: Dict[str, Union[str, List[str]]]):
        if not completion_text.get(self._AUTHORS_KEY):
            raise AuthorsError("No authors found.")
        if not completion_text.get(self._PUBLISH_DATE_KEY):
            raise CreationDateError("No creation date found.")

        publish_date_str: str = str(completion_text.get(self._PUBLISH_DATE_KEY, ""))
        try:
            parse_date(publish_date_str)
        except ValueError as e:
            raise CreationDateError(
                f"Failed to parse publish date '{publish_date_str}': {e}") from e

    def _process_element(self, element: DocumentD) -> Iterable[DocumentD]:
        pdf_document_pages: Iterator = pymupdf.open(element.file_path).pages()
        first_page_text: str = next(pdf_document_pages).get_text()
        messages: List[ChatCompletionMessageParam] = [{
            "role": "system", "content": DOCUMENT_METADATA_PROMPT
        }, {
            "role": "user", "content": f"Input:\n{first_page_text}"
        }]
        completion_text_raw = self._handler.get_chat_completion(
            messages=messages,
            model=self._model_version,
            temperature=self._TEMPARATURE,
            response_format={"type": "json_object"})
        completion_text: Dict[str, Union[str, List[str]]] = dict(json.loads(completion_text_raw))
        self._validate_text(completion_text)
        authors: str = ", ".join(completion_text.get(self._AUTHORS_KEY, []))
        publish_date: str = str(completion_text.get(self._PUBLISH_DATE_KEY, ""))

        yield dataclasses.replace(element, authors=authors, publish_date=publish_date)
extraction_pipeline/document_metadata_extractor/openai_document_metadata_extractor_test.py
ADDED
@@ -0,0 +1,74 @@
import logging
import unittest
import uuid
from typing import List

from domain.chunk_d import DocumentD
from extraction_pipeline.document_metadata_extractor.openai_document_metadata_extractor import OpenAIDocumentMetadataExtractor, AuthorsError, CreationDateError
from llm_handler.mock_llm_handler import MockLLMHandler

DOCUMENT_METADATA_EXTRACTION_RESPONSE = '''
{
    "authors": ["BofA Global Research", "Michael Hartnett", "Elyas Galou", "Anya Shelekhin", "Myung-Jee Jung"],
    "publish_date": "2023-04-13"
}
'''


class TestOpenAIDocumentMetadataExtractor(unittest.TestCase):

    @classmethod
    def setUpClass(cls) -> None:
        cls.test_pdf_path = "extraction_pipeline/test_data/test.pdf"
        cls.start_document_d = DocumentD(file_path=cls.test_pdf_path, authors="", publish_date="")
        cls.final_document_d = DocumentD(
            file_path=cls.test_pdf_path,
            authors=
            "BofA Global Research, Michael Hartnett, Elyas Galou, Anya Shelekhin, Myung-Jee Jung",
            publish_date="2023-04-13")
        cls.openai_publish_details_extractor = OpenAIDocumentMetadataExtractor()

    def test__validate_text_missing_publishers(self):
        missing_publishers_text = {"publish_date": "2023-12-13"}
        with self.assertRaises(AuthorsError):
            self.openai_publish_details_extractor._validate_text(
                missing_publishers_text)  # type: ignore

    def test__validate_text_invalid_date(self):
        invalid_date_text = {
            "authors": [
                "BofA Global Research",
                "Michael Hartnett",
                "Elyas Galou",
                "Anya Shelekhin",
                "Myung-Jee Jung"
            ],
            "publish_date": "2-13"
        }
        with self.assertRaises(CreationDateError):
            self.openai_publish_details_extractor._validate_text(invalid_date_text)  # type: ignore

    def test__validate_text_valid(self):
        valid_text = {
            "authors": [
                "BofA Global Research",
                "Michael Hartnett",
                "Elyas Galou",
                "Anya Shelekhin",
                "Myung-Jee Jung"
            ],
            "publish_date": "2023-04-13"
        }
        self.openai_publish_details_extractor._validate_text(valid_text)  # type: ignore

    def test__process_element(self):
        handler = MockLLMHandler(chat_completion=[DOCUMENT_METADATA_EXTRACTION_RESPONSE])
        openai_publish_details_extractor = OpenAIDocumentMetadataExtractor(handler)
        pdf_document_d = self.start_document_d
        output = list(openai_publish_details_extractor._process_element(pdf_document_d))
        self.assertEqual(output[0], self.final_document_d)


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    unittest.main()
extraction_pipeline/document_metadata_extractor/prompts.py
ADDED
@@ -0,0 +1,22 @@
DOCUMENT_METADATA_PROMPT = '''
You are a financial reports expert, tasked with extracting the authors and publishing date
from the first page of a financial report document. Given the first page of text from a financial report,
you must extract and return the list of authors (which may be only one) and the date of
publication in a specific string format of 'YYYY-MM-DD' in the form of a JSON object. The JSON object
you return should have the list of authors as the value to the "authors" key.
The specifically formatted datetime string should be the value to the "publish_date" key.
Note, I will be parsing out the datetime string using the python function `datetime.strptime(your_output['publish_date'], '%Y-%m-%d')`.
Do not include anything besides the JSON object in your response.
I will be parsing your response using the python code `json.loads(your_output)`.

Example input:
"Trading ideas and investment strategies discussed herein may give rise to significant risk and are \nnot suitable for all investors. Investors should have experience in relevant markets and the financial \nresources to absorb any losses arising from applying these ideas or strategies. \n>> Employed by a non-US affiliate of BofAS and is not registered/qualified as a research analyst \nunder the FINRA rules. \nRefer to "Other Important Disclosures" for information on certain BofA Securities entities that take \nresponsibility for the information herein in particular jurisdictions. \nBofA Securities does and seeks to do business with issuers covered in its research \nreports. As a result, investors should be aware that the firm may have a conflict of \ninterest that could affect the objectivity of this report. Investors should consider this \nreport as only a single factor in making their investment decision. \nRefer to important disclosures on page 11 to 13. \n12544294 \n \nThe Flow Show \n \nSugar Coated Iceberg \n \n \nScores on the Doors: crypto 65.3%, gold 10.1%, stocks 7.7%, HY bonds 4.3%, IG bonds \n4.3%, govt bonds 3.4%, oil 3.7%, cash 1.2%, commodities -0.1%, US dollar -2.0% YTD. \nThe Biggest Picture: everyone's new favorite theme…US dollar debasement; US$ -11% \nsince Sept, gold >$2k, bitcoin >$30k; right secular theme (deficits, debt, geopolitics), \nUS$ in 4th bear market of past 50 years, bullish gold, oil, Euro, international stocks; but \npessimism so high right now, and 200bps Fed cuts priced in from June to election, look \nfor US$ trading bounce after Fed ends hiking cycle May 3rd (Chart 2). \nTale of the Tape: inflation slowing, Fed hiking cycle over, recession expectations \nuniversal, yet UST 30-year can't break below 3.6% (200-dma) because we've already \ntraded 8% to 4% CPI, labor market yet to crack, US govt deficit growing too quickly. \nThe Price is Right: next 6 months it's "recession vs Fed cuts"; best tells who's \nwinning…HY bonds, homebuilders, semiconductors…HYG <73, XHB <70, SOX <2900 \nrecessionary, and if levels hold…it's a no/soft landing. \nChart 2: US dollar has begun 4th bear market of past 50 years \nUS dollar index (DXY) \n \nSource: BofA Global Investment Strategy, Bloomberg. 
\nBofA GLOBAL RESEARCH \nMore on page 2… \n \n \n \n \n \n \nGold ↑\nYen & Deutsche \nMark ↑\nBRICS & Euro ↑\nGold, \nEuro ↑\n70\n80\n90\n100\n110\n120\n130\n140\n150\n160\n1967\n1973\n1979\n1985\n1991\n1997\n2003\n2009\n2015\n2021\nUS dollar index (DXY)\n13 April 2023 Corrected \n \n \n \n \nInvestment Strategy \nGlobal \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \nMichael Hartnett \nInvestment Strategist \nBofAS \n+1 646 855 1508 \nmichael.hartnett@bofa.com \n \nElyas Galou >> \nInvestment Strategist \nBofASE (France) \n+33 1 8770 0087 \nelyas.galou@bofa.com \n \nAnya Shelekhin \nInvestment Strategist \nBofAS \n+1 646 855 3753 \nanya.shelekhin@bofa.com \n \nMyung-Jee Jung \nInvestment Strategist \nBofAS \n+1 646 855 0389 \nmyung-jee.jung@bofa.com \n \n \n \n \n \n \n \n \n \nChart 1: BofA Bull & Bear Indicator \nStays at 2.3 \n \nSource: BofA Global Investment Strategy \nThe indicator identified above as the BofA Bull & Bear \nIndicator is intended to be an indicative metric only and \nmay not be used for reference purposes or as a measure \nof performance for any financial instrument or contract, \nor otherwise relied upon by third parties for any other \npurpose, without the prior written consent of BofA \nGlobal Research. This indicator was not created to act as \na benchmark. \nBofA GLOBAL RESEARCH \n \n \nExtreme \nBearish\nExtreme \nBullish\n4\n6\n10\n0\n8\n2\nBuy\nSell\n2.3\nAccessible version \n \n \nTimestamp: 13 April 2023 08:19PM EDT\nW \n"

Example output:
"
{
    "authors": ["BofA Global Research", "Michael Hartnett", "Elyas Galou", "Anya Shelekhin", "Myung-Jee Jung"],
    "publish_date": "2023-04-13"
}
"
'''
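Editor's note: a small sketch, not part of this commit, of how a response in the format requested by DOCUMENT_METADATA_PROMPT gets parsed, matching the `json.loads` and `datetime.strptime` calls the prompt itself references. The `model_response` string is a placeholder.

# Hypothetical parsing sketch for a prompt-conformant model response.
import json
from datetime import datetime

model_response = '{"authors": ["BofA Global Research", "Michael Hartnett"], "publish_date": "2023-04-13"}'
parsed = json.loads(model_response)
authors = ", ".join(parsed["authors"])                                 # joined into one string, as in the extractor
publish_date = datetime.strptime(parsed["publish_date"], "%Y-%m-%d")   # validates the YYYY-MM-DD format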
extraction_pipeline/pdf_process_stage.py
ADDED
@@ -0,0 +1,131 @@
from __future__ import annotations
import os
from typing import Iterable
import pymupdf
import re

from domain.domain_protocol import DomainProtocol
from domain.chunk_d import DocumentD, ChunkD
import proto.chunk_pb2 as chunk_pb2
from extraction_pipeline.base_stage import BaseStage, BaseTransform
from storage.domain_dao import CacheDomainDAO


class IDTagger():

    _uuid_to_tag: CacheDomainDAO[ChunkD]

    def __init__(self, cache_save_path: str):
        self._uuid_to_tag = CacheDomainDAO(f"{cache_save_path}.json", ChunkD)

    def __call__(self, element: ChunkD) -> ChunkD:
        self._uuid_to_tag.insert([element])
        return element


class PdfPageChunkerStage(BaseStage[DocumentD, ChunkD]):

    def __init__(self):
        self._pdf_cache_dao = CacheDomainDAO("pdf_chunk_id_map.json", DocumentD)
        self.id_tagger = CacheDomainDAO(f"{str.lower((self.__class__.__name__))}_chunk_id_map.json",
                                        ChunkD)

    def _process_element(self, element: DocumentD) -> Iterable[ChunkD]:
        pdf_document: Iterable = pymupdf.open(element.file_path)  # type: ignore
        for page_index, pdf_page in enumerate(pdf_document):
            page_text: str = pdf_page.get_textpage().extractText()  # type: ignore
            yield self.id_tagger(
                ChunkD(parent_reference=element,
                       chunk_text=page_text,
                       chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
                       chunk_index=page_index,
                       chunk_id=self._pdf_cache_dao.set(element)))


class ParagraphChunkerStage(BaseStage[ChunkD, ChunkD]):

    def __init__(self):
        self.id_tagger = CacheDomainDAO(f"{str.lower((self.__class__.__name__))}_chunk_id_map.json",
                                        ChunkD)

    def _process_element(self, element: ChunkD) -> Iterable[ChunkD]:
        paragraphs = re.split(r'\n+', element.chunk_text)
        paragraphs = [para.strip() for para in paragraphs if para.strip()]
        for chunk_index, paragraph in enumerate(paragraphs):
            yield self.id_tagger(
                ChunkD(parent_reference=element.chunk_id,
                       chunk_text=paragraph,
                       chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PARAGRAPH,
                       chunk_index=chunk_index))


class SentenceChunkerStage(BaseStage[ChunkD, ChunkD]):

    def __init__(self):
        self.id_tagger = CacheDomainDAO(f"{str.lower((self.__class__.__name__))}_chunk_id_map.json",
                                        ChunkD)

    def _process_element(self, element: ChunkD) -> Iterable[ChunkD]:
        sentence_endings = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|!)(\s|$)'
        sentences = re.split(sentence_endings, element.chunk_text)
        sentences = [sentence.strip() for sentence in sentences if sentence.strip()]
        for chunk_index, sentence in enumerate(sentences):
            yield self.id_tagger(
                ChunkD(parent_reference=element.chunk_id,
                       chunk_text=sentence,
                       chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
                       chunk_index=chunk_index))


class PdfToSentencesTransform(BaseTransform[DocumentD, ChunkD]):

    _pdf_page_chunker: PdfPageChunkerStage
    _paragraph_chunker: ParagraphChunkerStage
    _sentence_chunker: SentenceChunkerStage

    def __init__(self):
        self._pdf_page_chunker = PdfPageChunkerStage()
        self._paragraph_chunker = ParagraphChunkerStage()
        self._sentence_chunker = SentenceChunkerStage()

    def _process_collection(self, collection: Iterable[DocumentD]) -> Iterable[ChunkD]:
        for pdf_document in collection:
            pdf_pages = self._pdf_page_chunker.process_element(pdf_document)
            for pdf_page in pdf_pages:
                paragraphs = self._paragraph_chunker.process_element(pdf_page)
                for paragraph in paragraphs:
                    sentences = self._sentence_chunker.process_element(paragraph)
                    yield from sentences


class PdfToParagraphTransform(BaseTransform[DocumentD, ChunkD]):

    _pdf_page_chunker: PdfPageChunkerStage
    _paragraph_chunker: ParagraphChunkerStage

    def __init__(self):
        self._pdf_page_chunker = PdfPageChunkerStage()
        self._paragraph_chunker = ParagraphChunkerStage()

    def _process_collection(self, collection: Iterable[DocumentD]) -> Iterable[ChunkD]:
        for pdf_document in collection:
            pdf_pages = self._pdf_page_chunker.process_element(pdf_document)
            for pdf_page in pdf_pages:
                paragraphs = self._paragraph_chunker.process_element(pdf_page)
                for paragraph in paragraphs:
                    yield paragraph


class PdfToPageTransform(BaseTransform[DocumentD, ChunkD]):

    _pdf_page_chunker: PdfPageChunkerStage

    def __init__(self):
        self._pdf_page_chunker = PdfPageChunkerStage()
        self._paragraph_chunker = ParagraphChunkerStage()

    def _process_collection(self, collection: Iterable[DocumentD]) -> Iterable[ChunkD]:
        for pdf_document in collection:
            pdf_pages = self._pdf_page_chunker.process_element(pdf_document)
            for pdf_page in pdf_pages:
                yield pdf_page
extraction_pipeline/pdf_process_stage_test.py
ADDED
@@ -0,0 +1,103 @@
import unittest
import logging
import uuid

from domain.chunk_d import DocumentD, ChunkD
import proto.chunk_pb2 as chunk_pb2
from extraction_pipeline.pdf_process_stage import PdfPageChunkerStage, ParagraphChunkerStage, SentenceChunkerStage, PdfToSentencesTransform


def compare_chunk_without_id(chunk1: ChunkD, chunk2: ChunkD):
    assert chunk1.parent_reference == chunk2.parent_reference
    assert chunk1.chunk_text == chunk2.chunk_text
    assert chunk1.chunk_type == chunk2.chunk_type
    assert chunk1.chunk_index == chunk2.chunk_index


class PdfPageChunkerStageTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.test_pdf_path = "extraction_pipeline/test_data/test.pdf"
        cls.expected_pdf_text = """This research presents a thorough examination of Large Language Models (LLMs) in the
domain of code re-engineering.
Our focus on the methodical development and utilization of robust pipelines highlights both the
potential and the current challenges of employing LLMs for these complex tasks.
"""
        cls.document_d = DocumentD(file_path="extraction_pipeline/test_data/test.pdf",
                                   authors="BofA John Doe Herb Johnson Taylor Mason",
                                   publish_date="2021-01-01")

    def test__process_element(self):
        chunks = list(PdfPageChunkerStage()._process_element(self.document_d))
        self.assertEqual(len(chunks), 1)
        expected_chunk = ChunkD(parent_reference=self.document_d,
                                chunk_text=self.expected_pdf_text,
                                chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
                                chunk_index=0)
        compare_chunk_without_id(chunks[0], expected_chunk)


class ParagraphChunkerStageTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.document_d = DocumentD(file_path="extraction_pipeline/test_data/test.pdf",
                                   authors="BofA John Doe Herb Johnson Taylor Mason",
                                   publish_date="2021-01-01")

    def test__process_element(self):
        chunk = ChunkD(parent_reference=self.document_d,
                       chunk_text="paragraph1\nparagraph2\n\nparagraph3",
                       chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
                       chunk_index=0)
        chunks = list(ParagraphChunkerStage()._process_element(chunk))
        expected_chunks = [
            ChunkD(parent_reference=chunk.chunk_id,
                   chunk_text="paragraph1",
                   chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PARAGRAPH,
                   chunk_index=0),
            ChunkD(parent_reference=chunk.chunk_id,
                   chunk_text="paragraph2",
                   chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PARAGRAPH,
                   chunk_index=1),
            ChunkD(parent_reference=chunk.chunk_id,
                   chunk_text="paragraph3",
                   chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PARAGRAPH,
                   chunk_index=2)
        ]
        self.assertEqual(len(chunks), len(expected_chunks))
        for chunk, expected_chunk in zip(chunks, expected_chunks):
            compare_chunk_without_id(chunk, expected_chunk)


class SentenceChunkerStageTest(unittest.TestCase):

    def test__process_element(self):
        chunk = ChunkD(parent_reference=uuid.uuid4(),
                       chunk_text="sentence1. sentence2! sentence3?",
                       chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PARAGRAPH,
                       chunk_index=0)
        chunks = list(SentenceChunkerStage()._process_element(chunk))
        expected_chunks = [
            ChunkD(parent_reference=chunk.chunk_id,
                   chunk_text="sentence1.",
                   chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
                   chunk_index=0),
            ChunkD(parent_reference=chunk.chunk_id,
                   chunk_text="sentence2!",
                   chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
                   chunk_index=1),
            ChunkD(parent_reference=chunk.chunk_id,
                   chunk_text="sentence3?",
                   chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_SENTENCE,
                   chunk_index=2)
        ]
        self.assertEqual(len(chunks), len(expected_chunks))
        for chunk, expected_chunk in zip(chunks, expected_chunks):
            compare_chunk_without_id(chunk, expected_chunk)


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    unittest.main()
extraction_pipeline/pdf_to_knowledge_graph_transform.py
ADDED
@@ -0,0 +1,153 @@
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import Iterable, Optional
|
8 |
+
|
9 |
+
from domain.chunk_d import ChunkD, DocumentD
|
10 |
+
from domain.entity_d import EntityKnowledgeGraphD
|
11 |
+
from extraction_pipeline.base_stage import BaseTransform
|
12 |
+
from extraction_pipeline.document_metadata_extractor.openai_document_metadata_extractor import (
|
13 |
+
OpenAIDocumentMetadataExtractor,)
|
14 |
+
from extraction_pipeline.pdf_process_stage import (
|
15 |
+
PdfToPageTransform,
|
16 |
+
PdfToParagraphTransform,
|
17 |
+
PdfToSentencesTransform,
|
18 |
+
)
|
19 |
+
from extraction_pipeline.relationship_extractor.entity_relationship_extractor import (
|
20 |
+
RelationshipExtractor,)
|
21 |
+
from extraction_pipeline.relationship_extractor.openai_relationship_extractor import (
|
22 |
+
OpenAIRelationshipExtractor,)
|
23 |
+
from llm_handler.openai_handler import OpenAIHandler
|
24 |
+
from storage.domain_dao import InMemDomainDAO
|
25 |
+
from storage.neo4j_dao import Neo4jDomainDAO
|
26 |
+
|
27 |
+
|
28 |
+
class PdfToKnowledgeGraphTransform(BaseTransform[DocumentD, EntityKnowledgeGraphD]):
|
29 |
+
_metadata_extractor: OpenAIDocumentMetadataExtractor
|
30 |
+
_pdf_chunker: BaseTransform[DocumentD, ChunkD]
|
31 |
+
_relationship_extractor: RelationshipExtractor
|
32 |
+
|
33 |
+
def __init__(self, pdf_chunker: BaseTransform[DocumentD, ChunkD]):
|
34 |
+
openai_handler = OpenAIHandler()
|
35 |
+
self._metadata_extractor = OpenAIDocumentMetadataExtractor(openai_handler=openai_handler)
|
36 |
+
self._pdf_chunker = pdf_chunker
|
37 |
+
self._relationship_extractor = OpenAIRelationshipExtractor(openai_handler=openai_handler)
|
38 |
+
|
39 |
+
def _process_collection(self,
|
40 |
+
collection: Iterable[DocumentD]) -> Iterable[EntityKnowledgeGraphD]:
|
41 |
+
# produce 1 EntityKnowledgeGraphD per DocumentD
|
42 |
+
for pdf_document in collection:
|
43 |
+
# metadata extractor only yields 1 filled in DocumentD for each input DocumentD
|
44 |
+
document = next(iter(self._metadata_extractor.process_element(pdf_document)))
|
45 |
+
entity_relationships = []
|
46 |
+
for pdf_chunk in self._pdf_chunker.process_collection([document]):
|
47 |
+
for relationship in self._relationship_extractor.process_element(pdf_chunk):
|
48 |
+
entity_relationships.append(relationship)
|
49 |
+
yield EntityKnowledgeGraphD(entity_relationships=entity_relationships)
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == '__main__':
|
53 |
+
## CLI Arguments for running transform as multi-threaded script
|
54 |
+
parser = argparse.ArgumentParser(description='Extract knowledge graphs from PDF files')
|
55 |
+
parser.add_argument('--pdf_folder',
|
56 |
+
type=str,
|
57 |
+
help='Path to folder of PDF files to process',
|
58 |
+
default='')
|
59 |
+
parser.add_argument('--pdf_file',
|
60 |
+
type=str,
|
61 |
+
help='Path to the one PDF file to process',
|
62 |
+
default='')
|
63 |
+
parser.add_argument('--output_json_file',
|
64 |
+
type=str,
|
65 |
+
help='Path for output json file of knowledge graphs',
|
66 |
+
default='./knowledge_graphs.json')
|
67 |
+
parser.add_argument('--log_folder', type=str, help='Path to log folder', default='log')
|
68 |
+
parser.add_argument('--chunk_to',
|
69 |
+
type=str,
|
70 |
+
help='What level to chunk PDF text',
|
71 |
+
default='page',
|
72 |
+
choices=['page', 'paragraph', 'sentence'])
|
73 |
+
parser.add_argument('--verbose', help='Enable DEBUG level logs', action='store_true')
|
74 |
+
parser.add_argument('--upload_to_neo4j',
|
75 |
+
help='Enable uploads to Neo4j database',
|
76 |
+
action='store_true')
|
77 |
+
args = parser.parse_args()
|
78 |
+
|
79 |
+
## Setup logging
|
80 |
+
if args.verbose:
|
81 |
+
log_level = logging.DEBUG
|
82 |
+
else:
|
83 |
+
log_level = logging.INFO
|
84 |
+
os.makedirs(args.log_folder, exist_ok=True)
|
85 |
+
script_name = os.path.splitext(os.path.basename(__file__))[0]
|
86 |
+
logging.basicConfig(level=log_level,
|
87 |
+
format='%(asctime)s - %(levelname)s - %(message)s',
|
88 |
+
filename=f'{args.log_folder}/{script_name}.log',
|
89 |
+
filemode='w')
|
90 |
+
logger = logging.getLogger(__name__)
|
91 |
+
|
92 |
+
## Setup PDF Chunking
|
93 |
+
if args.chunk_to == 'page':
|
94 |
+
pdf_transform = PdfToPageTransform()
|
95 |
+
elif args.chunk_to == 'paragraph':
|
96 |
+
pdf_transform = PdfToParagraphTransform()
|
97 |
+
elif args.chunk_to == 'sentence':
|
98 |
+
pdf_transform = PdfToSentencesTransform()
|
99 |
+
else:
|
100 |
+
logging.error('Invalid chunking level: %s', args.chunk_to)
|
101 |
+
sys.exit(1)
|
102 |
+
|
103 |
+
## Process PDF Files
|
104 |
+
if not args.pdf_folder and not args.pdf_file:
|
105 |
+
logging.error('No PDF file or folder provided')
|
106 |
+
sys.exit(1)
|
107 |
+
elif args.pdf_folder:
|
108 |
+
pdf_folder = Path(args.pdf_folder)
|
109 |
+
if not pdf_folder.exists():
|
110 |
+
logging.error('PDF folder does not exist: %s', pdf_folder)
|
111 |
+
sys.exit(1)
|
112 |
+
pdf_files = list(pdf_folder.glob('*.pdf'))
|
113 |
+
if len(pdf_files) == 0:
|
114 |
+
logging.warning('No PDF files found in folder: %s', pdf_folder)
|
115 |
+
sys.exit(0)
|
116 |
+
pdfs = [
|
117 |
+
DocumentD(file_path=str(pdf_file), authors='', publish_date='')
|
118 |
+
for pdf_file in pdf_files
|
119 |
+
]
|
120 |
+
else:
|
121 |
+
pdf_file = Path(args.pdf_file)
|
122 |
+
if not pdf_file.exists():
|
123 |
+
logging.error('PDF file does not exist: %s', pdf_file)
|
124 |
+
sys.exit(1)
|
125 |
+
pdfs = [DocumentD(file_path=str(pdf_file), authors='', publish_date='')]
|
126 |
+
|
127 |
+
pdf_to_kg = PdfToKnowledgeGraphTransform(pdf_transform)
|
128 |
+
|
129 |
+
def process_pdf(pdf: DocumentD) -> tuple[Optional[EntityKnowledgeGraphD], str]:
|
130 |
+
pdf_name = Path(pdf.file_path).name
|
131 |
+
try:
|
132 |
+
# process collection yields 1 KG per pdf but we are only
|
133 |
+
# inputting 1 PDF at a time so we just need the 1st element
|
134 |
+
return list(pdf_to_kg.process_collection([pdf]))[0], pdf_name
|
135 |
+
except Exception as e:
|
136 |
+
logging.error(f"Error processing pdf: {e}")
|
137 |
+
return None, pdf_name
|
138 |
+
|
139 |
+
results: list[EntityKnowledgeGraphD] = []
|
140 |
+
with ThreadPoolExecutor() as executor, Neo4jDomainDAO() as dao:
|
141 |
+
futures = [executor.submit(process_pdf, pdf) for pdf in pdfs]
|
142 |
+
|
143 |
+
for future in as_completed(futures):
|
144 |
+
kg, pdf_name = future.result()
|
145 |
+
if not kg:
|
146 |
+
continue
|
147 |
+
results.append(kg)
|
148 |
+
if args.upload_to_neo4j:
|
149 |
+
dao.insert(kg, pdf_name)
|
150 |
+
|
151 |
+
dao = InMemDomainDAO()
|
152 |
+
dao.insert(results)
|
153 |
+
dao.save_to_file(args.output_json_file)
|
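For reference, a minimal sketch of driving the transform directly from Python instead of the CLI above (not part of the commit; it borrows the test PDF path used elsewhere in the repo and assumes EntityKnowledgeGraphD exposes the entity_relationships field it is constructed with):

from domain.chunk_d import DocumentD
from extraction_pipeline.pdf_process_stage import PdfToPageTransform
from extraction_pipeline.pdf_to_knowledge_graph_transform import PdfToKnowledgeGraphTransform

# Page-level chunking, equivalent to --chunk_to=page in the script above.
pdf = DocumentD(file_path='extraction_pipeline/test_data/test.pdf', authors='', publish_date='')
pdf_to_kg = PdfToKnowledgeGraphTransform(PdfToPageTransform())
for kg in pdf_to_kg.process_collection([pdf]):
    # One knowledge graph is yielded per input document.
    print(f'extracted {len(kg.entity_relationships)} relationships')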
extraction_pipeline/relationship_extractor/entity_relationship_extractor.py
ADDED
@@ -0,0 +1,10 @@
1 |
+
import dataclasses
|
2 |
+
|
3 |
+
from extraction_pipeline.base_stage import BaseStage
|
4 |
+
from domain.chunk_d import ChunkD
|
5 |
+
from domain.entity_d import EntityRelationshipD
|
6 |
+
|
7 |
+
|
8 |
+
@dataclasses.dataclass(frozen=True)
|
9 |
+
class RelationshipExtractor(BaseStage[ChunkD, EntityRelationshipD]):
|
10 |
+
...
|
extraction_pipeline/relationship_extractor/openai_relationship_extractor.py
ADDED
@@ -0,0 +1,108 @@
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import uuid
|
4 |
+
from typing import ClassVar, Dict, Iterator, List, Optional
|
5 |
+
|
6 |
+
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
7 |
+
|
8 |
+
from domain.chunk_d import (ChunkD, DocumentD)
|
9 |
+
from domain.entity_d import (
|
10 |
+
EntityD,
|
11 |
+
EntityRelationshipD,
|
12 |
+
RelationshipD,
|
13 |
+
)
|
14 |
+
from extraction_pipeline.relationship_extractor.entity_relationship_extractor import (
|
15 |
+
RelationshipExtractor,)
|
16 |
+
from extraction_pipeline.relationship_extractor.prompts import (
|
17 |
+
EXTRACT_RELATIONSHIPS_PROMPT,
|
18 |
+
NER_TAGGING_PROMPT,
|
19 |
+
)
|
20 |
+
from llm_handler.llm_interface import LLMInterface
|
21 |
+
from llm_handler.openai_handler import ChatModelVersion, OpenAIHandler
|
22 |
+
|
23 |
+
|
24 |
+
class OpenAIRelationshipExtractor(RelationshipExtractor):
|
25 |
+
|
26 |
+
_handler: LLMInterface
|
27 |
+
_MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
|
28 |
+
_RELATIONSHIP_KEY: ClassVar[str] = "relationships"
|
29 |
+
_ENTITIES_KEY: ClassVar[str] = "entities"
|
30 |
+
_RELATIONSHIPS_TYPES: ClassVar[List[str]] = ["PREDICTION"]
|
31 |
+
_TEMPERATURE: ClassVar[float] = 0.2
|
32 |
+
|
33 |
+
def __init__(self,
|
34 |
+
openai_handler: Optional[LLMInterface] = None,
|
35 |
+
model_version: Optional[ChatModelVersion] = None):
|
36 |
+
self._handler = openai_handler or OpenAIHandler()
|
37 |
+
self._model_version = model_version or self._MODEL_VERSION
|
38 |
+
|
39 |
+
def _extract_entity_names(self, chunk_text: str) -> List[Dict[str, str]]:
|
40 |
+
|
41 |
+
messages: List[ChatCompletionMessageParam] = [{
|
42 |
+
"role": "system", "content": NER_TAGGING_PROMPT
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"role": "user",
|
46 |
+
"content": f"Input:\n{chunk_text}"
|
47 |
+
}]
|
48 |
+
completion_text = self._handler.get_chat_completion(messages=messages,
|
49 |
+
model=self._model_version,
|
50 |
+
temperature=self._TEMPERATURE,
|
51 |
+
response_format={"type": "json_object"})
|
52 |
+
logging.info(f"entity extraction results: {completion_text}")
|
53 |
+
return dict(json.loads(completion_text)).get(self._ENTITIES_KEY, [])
|
54 |
+
|
55 |
+
def _extract_relationships(self, chunk: ChunkD,
|
56 |
+
entity_nodes: List[Dict[str, str]]) -> Iterator[EntityRelationshipD]:
|
57 |
+
if isinstance(chunk.parent_reference, DocumentD):
|
58 |
+
analyst_names: str = chunk.parent_reference.authors
|
59 |
+
document_datetime: str = chunk.parent_reference.publish_date
|
60 |
+
else:
|
61 |
+
raise NotImplementedError("Parent reference is not a DocumentD")
|
62 |
+
messages: List[ChatCompletionMessageParam] = [
|
63 |
+
{
|
64 |
+
"role": "system", "content": EXTRACT_RELATIONSHIPS_PROMPT
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"role":
|
68 |
+
"user",
|
69 |
+
"content":
|
70 |
+
f"Analyst: {analyst_names} \n Date: {document_datetime} \n Text Chunk: {chunk.chunk_text} \n {str(entity_nodes)}"
|
71 |
+
}
|
72 |
+
]
|
73 |
+
|
74 |
+
completion_text = self._handler.get_chat_completion(messages=messages,
|
75 |
+
model=self._model_version,
|
76 |
+
temperature=self._TEMPERATURE,
|
77 |
+
response_format={"type": "json_object"})
|
78 |
+
logging.info(f"relationship results: {completion_text}")
|
79 |
+
completion_text = dict(json.loads(completion_text))
|
80 |
+
relationships: List[Dict[str, Dict[str,
|
81 |
+
str]]] = completion_text.get(self._RELATIONSHIP_KEY, [])
|
82 |
+
for extracted_relationship in relationships:
|
83 |
+
key: str = list(extracted_relationship.keys())[0]
|
84 |
+
if key in self._RELATIONSHIPS_TYPES:
|
85 |
+
relationship_kw_attr: Dict[str, str] = extracted_relationship[key]
|
86 |
+
relationship_d = RelationshipD(
|
87 |
+
relationship_id=str(uuid.uuid4()),
|
88 |
+
start_date=relationship_kw_attr.get("start_date", ""),
|
89 |
+
end_date=relationship_kw_attr.get("end_date", ""),
|
90 |
+
source_text=chunk.chunk_text,
|
91 |
+
predicted_movement=RelationshipD.from_string(
|
92 |
+
relationship_kw_attr.get("predicted_movement", "")))
|
93 |
+
else:
|
94 |
+
raise ValueError(f"No valid relationships in {extracted_relationship}")
|
95 |
+
|
96 |
+
entity_l_d = EntityD(entity_id=str(uuid.uuid4()),
|
97 |
+
entity_name=relationship_kw_attr.get("from_entity", ""))
|
98 |
+
|
99 |
+
entity_r_d = EntityD(entity_id=str(uuid.uuid4()),
|
100 |
+
entity_name=relationship_kw_attr.get("to_entity", ""))
|
101 |
+
|
102 |
+
yield EntityRelationshipD(relationship=relationship_d,
|
103 |
+
from_entity=entity_l_d,
|
104 |
+
to_entity=entity_r_d)
|
105 |
+
|
106 |
+
def _process_element(self, element: ChunkD) -> Iterator[EntityRelationshipD]:
|
107 |
+
entities_text: List[Dict[str, str]] = self._extract_entity_names(element.chunk_text)
|
108 |
+
yield from self._extract_relationships(element, entities_text)
|
extraction_pipeline/relationship_extractor/openai_relationship_extractor_test.py
ADDED
@@ -0,0 +1,87 @@
1 |
+
import logging
|
2 |
+
import unittest
|
3 |
+
import uuid
|
4 |
+
|
5 |
+
import proto.chunk_pb2 as chunk_pb2
|
6 |
+
|
7 |
+
from domain.chunk_d import ChunkD, DocumentD
|
8 |
+
from domain.entity_d import RelationshipD
|
9 |
+
from extraction_pipeline.relationship_extractor.openai_relationship_extractor import (
|
10 |
+
OpenAIRelationshipExtractor,)
|
11 |
+
from llm_handler.mock_llm_handler import MockLLMHandler
|
12 |
+
|
13 |
+
ENTITY_EXTRACTION_RESPONSE = '''
|
14 |
+
{
|
15 |
+
"entities": [
|
16 |
+
{"name": "entity 1", "type": "Financial Securities"},
|
17 |
+
{"name": "entity 2", "type": "Financial Securities"}
|
18 |
+
]
|
19 |
+
}
|
20 |
+
'''
|
21 |
+
|
22 |
+
RELATIONSHIP_EXTRACTION_RESPONSE = '''
|
23 |
+
{
|
24 |
+
"relationships": [
|
25 |
+
{
|
26 |
+
"PREDICTION": {
|
27 |
+
"from_entity": "John Doe",
|
28 |
+
"to_entity": "entity 1",
|
29 |
+
"start_date": "2024-06-12",
|
30 |
+
"end_date": "2024-11-05",
|
31 |
+
"predicted_movement": "PREDICTED_MOVEMENT_DECREASE",
|
32 |
+
"description": "200bps Fed cuts priced in from June to election, look for US$ trading bounce after Fed ends hiking cycle May 3rd."
|
33 |
+
}
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"PREDICTION": {
|
37 |
+
"from_entity": "John Doe",
|
38 |
+
"to_entity": "entity 2",
|
39 |
+
"start_date": "2024-06-12",
|
40 |
+
"end_date": "2024-11-05",
|
41 |
+
"predicted_movement": "PREDICTED_MOVEMENT_DECREASE",
|
42 |
+
"description": "200bps Fed cuts priced in from June to election, look for US$ trading bounce after Fed ends hiking cycle May 3rd."
|
43 |
+
}
|
44 |
+
}
|
45 |
+
]
|
46 |
+
}
|
47 |
+
'''
|
48 |
+
|
49 |
+
|
50 |
+
class TestOpenAIRelationshipExtractor(unittest.TestCase):
|
51 |
+
|
52 |
+
def setUp(self) -> None:
|
53 |
+
self.document_d = DocumentD(file_path="Morgan_Stanley_Research_2024.pdf",
|
54 |
+
authors='John Doe',
|
55 |
+
publish_date='2024-06-12')
|
56 |
+
self.chunkd = ChunkD(chunk_text='entity 1 relates to entity 2.',
|
57 |
+
chunk_type=chunk_pb2.ChunkType.CHUNK_TYPE_PAGE,
|
58 |
+
chunk_index=1,
|
59 |
+
parent_reference=self.document_d)
|
60 |
+
|
61 |
+
def test__process_elements(self):
|
62 |
+
chat_responses = [ENTITY_EXTRACTION_RESPONSE, RELATIONSHIP_EXTRACTION_RESPONSE]
|
63 |
+
handler = MockLLMHandler(chat_completion=chat_responses)
|
64 |
+
openai_relationship_extractor = OpenAIRelationshipExtractor(handler)
|
65 |
+
output = list(openai_relationship_extractor._process_element(self.chunkd))
|
66 |
+
extracted_relationship1 = output[0]
|
67 |
+
extracted_relationship2 = output[1]
|
68 |
+
self.assertEqual(len(output), 2)
|
69 |
+
self.assertEqual(extracted_relationship1.from_entity.entity_name, 'John Doe')
|
70 |
+
self.assertEqual(extracted_relationship1.to_entity.entity_name, 'entity 1')
|
71 |
+
self.assertEqual(extracted_relationship1.relationship.start_date, '2024-06-12')
|
72 |
+
self.assertEqual(extracted_relationship1.relationship.end_date, '2024-11-05')
|
73 |
+
self.assertEqual(extracted_relationship1.relationship.predicted_movement, 3)
|
74 |
+
self.assertEqual(extracted_relationship1.relationship.source_text,
|
75 |
+
'entity 1 relates to entity 2.')
|
76 |
+
self.assertEqual(extracted_relationship2.from_entity.entity_name, 'John Doe')
|
77 |
+
self.assertEqual(extracted_relationship2.to_entity.entity_name, 'entity 2')
|
78 |
+
self.assertEqual(extracted_relationship2.relationship.start_date, '2024-06-12')
|
79 |
+
self.assertEqual(extracted_relationship2.relationship.end_date, '2024-11-05')
|
80 |
+
self.assertEqual(extracted_relationship2.relationship.predicted_movement, 3)
|
81 |
+
self.assertEqual(extracted_relationship2.relationship.source_text,
|
82 |
+
'entity 1 relates to entity 2.')
|
83 |
+
|
84 |
+
|
85 |
+
if __name__ == '__main__':
|
86 |
+
logging.basicConfig(level=logging.INFO)
|
87 |
+
unittest.main()
|
extraction_pipeline/relationship_extractor/prompts.py
ADDED
@@ -0,0 +1,481 @@
1 |
+
NER_TAGGING_PROMPT = '''
|
2 |
+
You are an expert in finance who will be performing Named Entity Recognition (NER)
|
3 |
+
with the high-level objective of analyzing financial documents. The output should
|
4 |
+
be formatted as a JSON object, where the JSON object has a key "entities" that
|
5 |
+
contains a list of the extracted entities. If no entities are found, return an
|
6 |
+
empty JSON object.
|
7 |
+
|
8 |
+
The entity types to extract are:

- Financial Markets (a grouping of financial securities)
- Financial Securities
- Economic Indicators
|
11 |
+
|
12 |
+
Now, given a text chunk, extract the relevant entities and format the output as a JSON object.
|
13 |
+
|
14 |
+
Examples of entities:
|
15 |
+
|
16 |
+
- Financial Markets: "US Stock Market", "US Bond Market", "US Real Estate Market", "Tech Stocks", "Energy Stocks"
|
17 |
+
- Financial Securities: "Apple", "Tesla", "Microsoft", "USD", "Bitcoin", "Gold", "Oil"
|
18 |
+
- Economic Indicators: "Consumer Price Index (CPI)", "Unemployment Rate", "GDP Growth Rate", "Federal Deficit"
|
19 |
+
|
20 |
+
Input:
|
21 |
+
Scores on the Doors: crypto 65.3%, gold 10.1%, stocks 7.7%, HY bonds 4.3%, IG bonds 4.3%, govt bonds 3.4%,
|
22 |
+
oil 3.7%, cash 1.2%, commodities -0.1%, US dollar -2.0% YTD.
|
23 |
+
The Biggest Picture: everyone's new favorite theme...US dollar debasement; US$ -11%
|
24 |
+
since Sept, gold >$2k, bitcoin >$30k; right secular theme (deficits, debt, geopolitics),
|
25 |
+
US$ in 4th bear market of past 50 years, bullish gold, bearish Euro, bullish international
|
26 |
+
stocks; but pessimism so high right now, and 200bps Fed cuts priced in from June to election,
|
27 |
+
look for US$ trading bounce after Fed ends hiking cycle May 3rd (Chart 2).
|
28 |
+
2024 Tale of the Tape: inflation slowing, Fed hiking cycle over, recession
|
29 |
+
expectations universal, yet UST 30-year can't break below 3.6% (200-dma) because we've already traded 8% to 4% CPI,
|
30 |
+
labor market yet to crack, US govt deficit growing too quickly. The Price is Right: next 6 months it's
|
31 |
+
“recession vs Fed cuts”; best tells who's winning...HY bonds, homebuilders, semiconductors...HYG <73, XHB <70,
|
32 |
+
SOX <2900 recessionary, and if levels hold...it's a no/soft landing.
|
33 |
+
|
34 |
+
Output:
|
35 |
+
{
|
36 |
+
"entities": [
|
37 |
+
{
|
38 |
+
"name": "crypto",
|
39 |
+
"type": "Financial Securities",
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"name": "gold",
|
43 |
+
"type": "Financial Securities",
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"name": "stocks",
|
47 |
+
"type": "Financial Securities",
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"name": "HY bonds",
|
51 |
+
"type": "Financial Securities",
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"name": "IG bonds",
|
55 |
+
"type": "Financial Securities",
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"name": "govt bonds",
|
59 |
+
"type": "Financial Securities",
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"name": "oil",
|
63 |
+
"type": "Financial Securities",
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"name": "cash",
|
67 |
+
"type": "Financial Securities",
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"name": "commodities",
|
71 |
+
"type": "Financial Securities",
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"name": "US dollar",
|
75 |
+
"type": "Financial Securities",
|
76 |
+
},
|
77 |
+
{
|
78 |
+
"name": "US$",
|
79 |
+
"type": "Financial Securities",
|
80 |
+
},
|
81 |
+
{
|
82 |
+
"name": "gold",
|
83 |
+
"type": "Financial Securities",
|
84 |
+
},
|
85 |
+
{
|
86 |
+
"name": "bitcoin",
|
87 |
+
"type": "Financial Securities",
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"name": "UST 30-year",
|
91 |
+
"type": "Financial Securities",
|
92 |
+
},
|
93 |
+
{
|
94 |
+
"name": "CPI",
|
95 |
+
"type": "Economic Indicators",
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"name": "labor market",
|
99 |
+
"type": "Economic Indicators",
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"name": "US govt deficit",
|
103 |
+
"type": "Economic Indicators",
|
104 |
+
},
|
105 |
+
{
|
106 |
+
"name": "HY bonds",
|
107 |
+
"type": "Financial Securities",
|
108 |
+
},
|
109 |
+
{
|
110 |
+
"name": "homebuilders",
|
111 |
+
"type": "Financial Markets",
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"name": "semiconductors",
|
115 |
+
"type": "Financial Markets",
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"name": "HYG",
|
119 |
+
"type": "Financial Securities",
|
120 |
+
},
|
121 |
+
{
|
122 |
+
"name": "XHB",
|
123 |
+
"type": "Financial Securities",
|
124 |
+
},
|
125 |
+
{
|
126 |
+
"name": "SOX",
|
127 |
+
"type": "Financial Securities",
|
128 |
+
}
|
129 |
+
]
|
130 |
+
}
|
131 |
+
|
132 |
+
Input:
|
133 |
+
the latest tech stock resurgence of 2024; NASDAQ +15% since January, Apple >$180, Tesla >$250;
|
134 |
+
driven by strong earnings, innovation, and AI advancements, NASDAQ in its 3rd bull market of the past decade,
|
135 |
+
bullish on semiconductors, cloud computing, and cybersecurity stocks; however, optimism
|
136 |
+
is soaring, and 150bps rate hikes anticipated from July to year-end, expect a potential
|
137 |
+
pullback in tech stocks after the next earnings season concludes in August (Chart 3).
|
138 |
+
Anyway, the bipartisan agreement to suspend the US debt ceiling until 2025 with only limited
|
139 |
+
fiscal restraint (roughly 0.1-0.2% of GDP yoy in 2024 and 2025 compared with the baseline)
|
140 |
+
means that we have avoided another potential significant negative shock to demand.
|
141 |
+
On the brightside, the boost to funding that Congress approved late last year for FY23
|
142 |
+
was so large (nearly 10% yoy) that overall discretionary spending is likely to be slightly
|
143 |
+
higher in real terms next year despite the new caps.
|
144 |
+
|
145 |
+
Output:
|
146 |
+
{
|
147 |
+
"entities": [
|
148 |
+
{
|
149 |
+
"name": "NASDAQ",
|
150 |
+
"type": "Financial Securities",
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"name": "Apple",
|
154 |
+
"type": "Financial Securities",
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"name": "Tesla",
|
158 |
+
"type": "Financial Securities",
|
159 |
+
},
|
160 |
+
{
|
161 |
+
"name": "semiconductors",
|
162 |
+
"type": "Financial Market",
|
163 |
+
},
|
164 |
+
{
|
165 |
+
"name": "cloud computing",
|
166 |
+
"type": "Financial Market",
|
167 |
+
},
|
168 |
+
{
|
169 |
+
"name": "cybersecurity stocks",
|
170 |
+
"type": "Financial Market",
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"name": "tech stocks",
|
174 |
+
"type": "Financial Marker",
|
175 |
+
},
|
176 |
+
{
|
177 |
+
"name": "GDP",
|
178 |
+
"type": "Economic Indicators",
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"name": "FY23",,
|
182 |
+
"type": "Economic Indiciators",
|
183 |
+
}
|
184 |
+
]
|
185 |
+
}
|
186 |
+
|
187 |
+
'''
|
188 |
+
|
189 |
+
EXTRACT_RELATIONSHIPS_PROMPT = '''
|
190 |
+
As a finance expert, your role is to extract relationships between a set of entities.
|
191 |
+
You will receive an excerpt from a financial report, the name of the analyst who authored
|
192 |
+
it, the date the report was published, and a list of entities that have already been extracted.
|
193 |
+
Your task is to identify and output relationships between these entities to construct a
|
194 |
+
knowledge graph for financial document analysis.
|
195 |
+
|
196 |
+
Your output should be a JSON object containing a single key, "relationships,"
|
197 |
+
which maps to a list of the extracted relationships. These relationships should be
|
198 |
+
formatted as structured JSON objects. Each relationship type must include "from_entity"
|
199 |
+
and "to_entity" attributes, using previously extracted entities as guidance,
|
200 |
+
to denote the source and target of the relationship, respectively. Additionally, all
|
201 |
+
relationship types should have "start_date" and "end_date" attributes, formatted as
|
202 |
+
"YYYY-MM-DD". The "source_text" attribute within a relationship should directly quote
|
203 |
+
the description of the relationship from the financial report excerpt.
|
204 |
+
|
205 |
+
The relationship types you need to extract include:
|
206 |
+
|
207 |
+
[Relationship Type 1: "PREDICTION"]: This relationship captures any forecasts or predictions the analyst makes
|
208 |
+
regarding trends in financial securities, indicators, or market entities. Search for phrases such as "we forecast,"
|
209 |
+
"is likely that," or "expect the market to." You must also assess the sentiment of the prediction—whether it is neutral,
|
210 |
+
positive (increase), or negative (decrease)—and include this in the relationship. Map the sentiment to one of the following values:
|
211 |
+
|
212 |
+
- PREDICTED_MOVEMENT_NEUTRAL (neutral sentiment)
|
213 |
+
- PREDICTED_MOVEMENT_INCREASE (positive sentiment)
|
214 |
+
- PREDICTED_MOVEMENT_DECREASE (negative sentiment)
|
215 |
+
|
216 |
+
The "from_entity" should be the analyst's name(s), while the "to_entity" should be the entity the prediction concerns (utilize the previously extracted entities as guidance).
|
217 |
+
Each prediction must have a defined "start_date" and "end_date" indicating the expected timeframe of the prediction. If
|
218 |
+
the prediction is not explicitly time-bound, omit this relationship. Use the report's publishing date as the "start_date"
|
219 |
+
unless the author specifies a different date. The "end_date" should correspond to when the prediction is anticipated to conclude.
|
220 |
+
|
221 |
+
Be sure to include the source excerpt you used to derive each relationship in the "source_text" attribute of the corresponding relationship.
|
222 |
+
Realize that predictions are about the future. Lastly, note that sometimes you will need to infer what entity is being referred to,
|
223 |
+
even if it is not explicitly named in the source sentence.
|
224 |
+
|
225 |
+
Example Input:
|
226 |
+
|
227 |
+
Analyst: Helen Brooks
|
228 |
+
Publish Date: 2024-06-12
|
229 |
+
|
230 |
+
Text Chunk: Labor markets are showing signs of normalization to end 2024; unemployment could drift higher in 2024 while remaining low in historical context.
|
231 |
+
Momentum in the job market is starting to wane with slowing payroll growth and modestly rising unemployment, as well as declining quit rates and temporary help.
|
232 |
+
Increased labor force participation and elevated immigration patterns over the past year have added labor supply, while a shortening work week indicates moderating demand for labor.
|
233 |
+
Considering the challenges to add and retain workers coming out of the pandemic, businesses could be more reluctant than normal to shed workers in a slowing economic environment.
|
234 |
+
Even so, less hiring activity could be enough to cause the unemployment rate to tick up to the mid-4% area by the end of next year due to worker churn.
|
235 |
+
Already slowing wage gains should slow further in the context of a softer labor market.
|
236 |
+
Anyway, the bipartisan agreement to suspend the US debt ceiling until 2025 with only limited
|
237 |
+
fiscal restraint (roughly 0.1-0.2% of GDP yoy in 2024 and 2025 compared with the baseline)
|
238 |
+
means that we have avoided another potential significant negative shock to demand.
|
239 |
+
On the brightside, the boost to funding that Congress approved late last year for FY23
|
240 |
+
was so large (nearly 10% yoy) that overall discretionary spending is likely to be slightly
|
241 |
+
higher in real terms next year despite the new caps.
|
242 |
+
|
243 |
+
{
|
244 |
+
"entities": [
|
245 |
+
{
|
246 |
+
"name": "labor markets",
|
247 |
+
"type": "Economic Market",
|
248 |
+
},
|
249 |
+
{
|
250 |
+
"name": "unemployment",
|
251 |
+
"type": "Economic Indicators",
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"name": "job market",
|
255 |
+
"type": "Economic Market",
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"name": "payroll growth",
|
259 |
+
"type": "Economic Indicators",
|
260 |
+
},
|
261 |
+
{
|
262 |
+
"name": "labor force participation",
|
263 |
+
"type": "Economic Indicators",
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"name": "rising unemployment",
|
267 |
+
"type": "Economic Indicators",
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"name": "immigration patterns",
|
271 |
+
"type": "Economic Indicators",
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"name": "hiring activity",
|
275 |
+
"type": "Economic Indicators",
|
276 |
+
},
|
277 |
+
{
|
278 |
+
"name": "worker churn",
|
279 |
+
"type": "Economic Indicators",
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"name": "unemployment rate",
|
283 |
+
"type": "Economic Indicators",
|
284 |
+
},
|
285 |
+
{
|
286 |
+
"name": "discretionary spending",
|
287 |
+
"type": "Economic Indicators",
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"name": "GDP",
|
291 |
+
"type": "Economic Indicators",
|
292 |
+
}
|
293 |
+
]
|
294 |
+
}
|
295 |
+
|
296 |
+
|
297 |
+
Example Output:
|
298 |
+
{
|
299 |
+
"relationships": [
|
300 |
+
{
|
301 |
+
"PREDICTION": {
|
302 |
+
"from_entity": "Helen Brooks",
|
303 |
+
"to_entity": "labor market",
|
304 |
+
"start_date": "2024-06-12",
|
305 |
+
"end_date": "2024-12-31",
|
306 |
+
"predicted_movement": "PREDICTED_MOVEMENT_NEUTRAL",
|
307 |
+
"source_text": "Labor markets are showing signs of normalization to end 2024;"
|
308 |
+
}
|
309 |
+
},
|
310 |
+
{
|
311 |
+
"PREDICTION": {
|
312 |
+
"from_entity": "Helen Brooks",
|
313 |
+
"to_entity": "unemployment",
|
314 |
+
"startDate": "2024-06-12",
|
315 |
+
"endDate": "2024-12-31",
|
316 |
+
"predicted_movement": "PREDICTED_MOVEMENT_INCREASE",
|
317 |
+
"source_text": "unemployment could drift higher in 2024 while remaining low in historical context."
|
318 |
+
}
|
319 |
+
},
|
320 |
+
{
|
321 |
+
"PREDICTION": {
|
322 |
+
"from_entity": "Helen Brooks",
|
323 |
+
"to_entity": "unemployment rate",
|
324 |
+
"start_date": "2024-06-12",
|
325 |
+
"end_date": "2025-12-31",
|
326 |
+
"predicted_movement": "PREDICTED_MOVEMENT_INCREASE",
|
327 |
+
"source_text": "Even so, less hiring activity could be enough to cause the unemployment rate to tick up to the mid-4% area by the end of next year due to worker churn."
|
328 |
+
}
|
329 |
+
},
|
330 |
+
{
|
331 |
+
"PREDICTION": {
|
332 |
+
"from_entity": "Helen Brooks",
|
333 |
+
"to_entity": "discretionary spending",
|
334 |
+
"start_date": "2024-06-12",
|
335 |
+
"end_date": "2025-12-31",
|
336 |
+
"predicted_movement": "PREDICTED_MOVEMENT_INCREASE",
|
337 |
+
"source_text": "On the brightside, the boost to funding that Congress approved late last year for FY23 was so large (nearly 10% yoy) that overall discretionary spending is likely to be slightly higher in real terms next year despite the new caps."
|
338 |
+
}
|
339 |
+
},
|
340 |
+
]
|
341 |
+
}
|
342 |
+
|
343 |
+
Example Input:
|
344 |
+
|
345 |
+
Analyst: John Doe, Herb Johnson, Taylor Mason
|
346 |
+
|
347 |
+
Publish Date: 2024-01-02
|
348 |
+
|
349 |
+
Text Chunk: Scores on the Doors: crypto 65.3%, gold 10.1%, stocks 7.7%, HY bonds 4.3%, IG bonds 4.3%, govt bonds 3.4%,
|
350 |
+
oil 3.7%, cash 1.2%, commodities -0.1%, US dollar -2.0% YTD. The Biggest Picture: everyone’s new favorite
|
351 |
+
theme...US dollar debasement; US$ -11% since Sept, gold >$2k, bitcoin >$30k; right secular theme (deficits,
|
352 |
+
debt, geopolitics), US$ in 4th bear market of past 50 years, bullish gold, oil, Euro, international stocks; but pessimism
|
353 |
+
so high right now, and 200bps Fed cuts priced in from June to election, look for US$ trading bounce after Fed
|
354 |
+
ends hiking cycle May 3rd (Chart 2). Tale of the Tape: inflation slowing, Fed hiking cycle over, recession
|
355 |
+
expectations universal, yet UST 30-year can’t break below 3.6% (200-dma) because we’ve already traded 8% to 4% CPI,
|
356 |
+
labor market yet to crack, US govt deficit growing too quickly. The Price is Right: next 6 months it’s
|
357 |
+
“recession vs Fed cuts”; best tells who’s winning...HY bonds, homebuilders, semiconductors...HYG <73, XHB <70,
|
358 |
+
SOX <2900 recessionary, and if levels hold...it’s a no/soft landing.
|
359 |
+
|
360 |
+
{
|
361 |
+
"entities": [
|
362 |
+
{
|
363 |
+
"name": "crypto",
|
364 |
+
"type": "Financial Securities",
|
365 |
+
},
|
366 |
+
{
|
367 |
+
"name": "gold",
|
368 |
+
"type": "Financial Securities",
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"name": "stocks",
|
372 |
+
"type": "Financial Securities",
|
373 |
+
},
|
374 |
+
{
|
375 |
+
"name": "HY bonds",
|
376 |
+
"type": "Financial Securities",
|
377 |
+
},
|
378 |
+
{
|
379 |
+
"name": "IG bonds",
|
380 |
+
"type": "Financial Securities",
|
381 |
+
},
|
382 |
+
{
|
383 |
+
"name": "govt bonds",
|
384 |
+
"type": "Financial Securities",
|
385 |
+
},
|
386 |
+
{
|
387 |
+
"name": "oil",
|
388 |
+
"type": "Financial Securities",
|
389 |
+
},
|
390 |
+
{
|
391 |
+
"name": "cash",
|
392 |
+
"type": "Financial Securities",
|
393 |
+
},
|
394 |
+
{
|
395 |
+
"name": "commodities",
|
396 |
+
"type": "Financial Securities",
|
397 |
+
},
|
398 |
+
{
|
399 |
+
"name": "US dollar",
|
400 |
+
"type": "Financial Securities",
|
401 |
+
},
|
402 |
+
{
|
403 |
+
"name": "US$",
|
404 |
+
"type": "Financial Securities",
|
405 |
+
},
|
406 |
+
{
|
407 |
+
"name": "gold",
|
408 |
+
"type": "Financial Securities",
|
409 |
+
},
|
410 |
+
{
|
411 |
+
"name": "bitcoin",
|
412 |
+
"type": "Financial Securities",
|
413 |
+
},
|
414 |
+
{
|
415 |
+
"name": "UST 30-year",
|
416 |
+
"type": "Financial Securities",
|
417 |
+
},
|
418 |
+
{
|
419 |
+
"name": "CPI",
|
420 |
+
"type": "Economic Indicators",
|
421 |
+
},
|
422 |
+
{
|
423 |
+
"name": "labor market",
|
424 |
+
"type": "Economic Indicators",
|
425 |
+
},
|
426 |
+
{
|
427 |
+
"name": "US govt deficit",
|
428 |
+
"type": "Economic Indicators",
|
429 |
+
},
|
430 |
+
{
|
431 |
+
"name": "HY bonds",
|
432 |
+
"type": "Financial Securities",
|
433 |
+
},
|
434 |
+
{
|
435 |
+
"name": "homebuilders",
|
436 |
+
"type": "Financial Markets",
|
437 |
+
},
|
438 |
+
{
|
439 |
+
"name": "semiconductors",
|
440 |
+
"type": "Financial Markets",
|
441 |
+
},
|
442 |
+
{
|
443 |
+
"name": "HYG",
|
444 |
+
"type": "Financial Securities",
|
445 |
+
},
|
446 |
+
{
|
447 |
+
"name": "XHB",
|
448 |
+
"type": "Financial Securities",
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"name": "SOX",
|
452 |
+
"type": "Financial Securities",
|
453 |
+
}
|
454 |
+
]
|
455 |
+
}
|
456 |
+
|
457 |
+
Example Output:
|
458 |
+
|
459 |
+
{
|
460 |
+
"relationships": [
|
461 |
+
{
|
462 |
+
"PREDICTION": {
|
463 |
+
"from_entity": "John Doe, Herb Johnson, Taylor Mason",
|
464 |
+
"to_entity": "US$",
|
465 |
+
"start_date": "2024-06-12",
|
466 |
+
"end_date": "2024-11-05",
|
467 |
+
"predicted_movement": "PREDICTED_MOVEMENT_INCREASE",
|
468 |
+
"source_text": "200bps Fed cuts priced in from June to election, look for US$ trading bounce after Fed ends hiking cycle May 3rd."
|
469 |
+
}
},
{
|
470 |
+
"PREDICTION": {
|
471 |
+
"from_entity": "John Doe, Herb Johnson, Taylor Mason",
|
472 |
+
"to_entity": "Fed cuts",
|
473 |
+
"start_date": "2024-06-12",
|
474 |
+
"end_date": "2024-11-05",
|
475 |
+
"predicted_movement": "PREDICTED_MOVEMENT_DECREASE",
|
476 |
+
"source_text": "200bps Fed cuts priced in from June to election, look for US$ trading bounce after Fed ends hiking cycle May 3rd."
|
477 |
+
}
|
478 |
+
},
|
479 |
+
]
|
480 |
+
}
|
481 |
+
'''
|
llm_handler/llm_interface.py
ADDED
@@ -0,0 +1,21 @@
1 |
+
from __future__ import annotations
|
2 |
+
from typing import List, Protocol, Type, Dict
|
3 |
+
import enum
|
4 |
+
|
5 |
+
|
6 |
+
class DefaultEnumMeta(enum.EnumMeta):
|
7 |
+
|
8 |
+
def __call__(cls, value=None, *args, **kwargs) -> DefaultEnumMeta:
|
9 |
+
if value is None:
|
10 |
+
return next(iter(cls))
|
11 |
+
return super().__call__(value, *args, **kwargs) # type: ignore
|
12 |
+
|
13 |
+
|
14 |
+
class LLMInterface(Protocol):
|
15 |
+
|
16 |
+
def get_chat_completion(self, messages: List, model: enum.Enum, temperature: float,
|
17 |
+
**kwargs) -> str:
|
18 |
+
...
|
19 |
+
|
20 |
+
def get_text_embedding(self, input: str, model: enum.Enum) -> List[float]:
|
21 |
+
...
|
llm_handler/mock_llm_handler.py
ADDED
@@ -0,0 +1,34 @@
1 |
+
from typing import List, Optional, Dict
|
2 |
+
import enum
|
3 |
+
|
4 |
+
import llm_handler.llm_interface as llm_interface
|
5 |
+
|
6 |
+
|
7 |
+
class MockLLMHandler(llm_interface.LLMInterface):
|
8 |
+
|
9 |
+
_chat_completion: Optional[List[str]]
|
10 |
+
_text_embedding: Optional[List[float]]
|
11 |
+
|
12 |
+
def __init__(self,
|
13 |
+
chat_completion: Optional[List[str]] = None,
|
14 |
+
text_embedding: Optional[List[float]] = None):
|
15 |
+
self._chat_completion = chat_completion
|
16 |
+
self._text_embedding = text_embedding
|
17 |
+
|
18 |
+
def get_chat_completion(self,
|
19 |
+
messages: List[Dict],
|
20 |
+
model: Optional[enum.Enum] = None,
|
21 |
+
temperature: float = 0.2,
|
22 |
+
**kwargs) -> str:
|
23 |
+
if not self._chat_completion:
|
24 |
+
raise ValueError('_chat_completion not set')
|
25 |
+
return self._chat_completion.pop(0)
|
26 |
+
|
27 |
+
def get_text_embedding(
|
28 |
+
self,
|
29 |
+
input: str,
|
30 |
+
model: Optional[enum.Enum] = None,
|
31 |
+
) -> List[float]:
|
32 |
+
if not self._text_embedding:
|
33 |
+
raise ValueError('_text_embedding not set')
|
34 |
+
return self._text_embedding
|
llm_handler/openai_handler.py
ADDED
@@ -0,0 +1,64 @@
1 |
+
from __future__ import annotations
|
2 |
+
import openai
|
3 |
+
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
4 |
+
import os
|
5 |
+
from typing import List, Optional, ClassVar
|
6 |
+
import enum
|
7 |
+
|
8 |
+
from llm_handler.llm_interface import LLMInterface, DefaultEnumMeta
|
9 |
+
|
10 |
+
|
11 |
+
class ChatModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
|
12 |
+
GPT_3_5 = 'gpt-3.5-turbo-1106'
|
13 |
+
GPT_4 = 'gpt-4'
|
14 |
+
GPT_4_TURBO = 'gpt-4-1106-preview'
|
15 |
+
GPT_4_O = 'gpt-4o'
|
16 |
+
|
17 |
+
|
18 |
+
class EmbeddingModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
|
19 |
+
SMALL_3 = 'text-embedding-3-small'
|
20 |
+
ADA_002 = 'text-embedding-ada-002'
|
21 |
+
LARGE = 'text-embedding-3-large'
|
22 |
+
|
23 |
+
|
24 |
+
class OpenAIHandler(LLMInterface):
|
25 |
+
|
26 |
+
_ENV_KEY_NAME: ClassVar[str] = 'OPENAI_API_KEY'
|
27 |
+
_client: openai.Client
|
28 |
+
|
29 |
+
def __init__(self, openai_api_key: Optional[str] = None):
|
30 |
+
_openai_api_key = openai_api_key or os.environ.get(self._ENV_KEY_NAME)
|
31 |
+
if not _openai_api_key:
|
32 |
+
raise ValueError(f'{self._ENV_KEY_NAME} not set')
|
33 |
+
openai.api_key = _openai_api_key
|
34 |
+
self._client = openai.Client()
|
35 |
+
|
36 |
+
def get_chat_completion( # type: ignore
|
37 |
+
self,
|
38 |
+
messages: List[ChatCompletionMessageParam],
|
39 |
+
model: ChatModelVersion = ChatModelVersion.GPT_4_O,
|
40 |
+
temperature: float = 0.2,
|
41 |
+
**kwargs) -> str:
|
42 |
+
response = self._client.chat.completions.create(model=model.value,
|
43 |
+
messages=messages,
|
44 |
+
temperature=temperature,
|
45 |
+
**kwargs)
|
46 |
+
responses: List[str] = []
|
47 |
+
for choice in response.choices:
|
48 |
+
if choice.finish_reason != 'stop' or not choice.message.content:
|
49 |
+
raise ValueError(f'Choice did not complete correctly: {choice}')
|
50 |
+
responses.append(choice.message.content)
|
51 |
+
if len(responses) != 1:
|
52 |
+
raise ValueError(f'Expected one response, got {len(responses)}: {responses}')
|
53 |
+
return responses[0]
|
54 |
+
|
55 |
+
def get_text_embedding( # type: ignore
|
56 |
+
self, input: str, model: EmbeddingModelVersion) -> List[float]:
|
57 |
+
response = self._client.embeddings.create(model=model.value,
|
58 |
+
encoding_format='float',
|
59 |
+
input=input)
|
60 |
+
if not response.data:
|
61 |
+
raise ValueError(f'No embedding in response: {response}')
|
62 |
+
elif len(response.data) != 1:
|
63 |
+
raise ValueError(f'More than one embedding in response: {response}')
|
64 |
+
return response.data[0].embedding
|
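A minimal usage sketch of the handler (not part of the commit; it assumes OPENAI_API_KEY is set in the environment and makes a real API call):

from llm_handler.openai_handler import ChatModelVersion, OpenAIHandler

handler = OpenAIHandler()
# get_chat_completion returns the single choice's message content as a plain string.
reply = handler.get_chat_completion(
    messages=[{'role': 'user', 'content': 'Reply with the word pong.'}],
    model=ChatModelVersion.GPT_4_O,
    temperature=0.0)
print(reply)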
llm_handler/openai_handler_test.py
ADDED
@@ -0,0 +1,69 @@
1 |
+
import os
|
2 |
+
from typing import List
|
3 |
+
from unittest import mock
|
4 |
+
import pytest
|
5 |
+
import openai
|
6 |
+
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
7 |
+
|
8 |
+
import llm_handler.openai_handler as openai_handler
|
9 |
+
|
10 |
+
|
11 |
+
# Clear openai.api_key before each test
|
12 |
+
@pytest.fixture(autouse=True)
|
13 |
+
def test_env_setup():
|
14 |
+
openai.api_key = None
|
15 |
+
|
16 |
+
|
17 |
+
# This allows us to clear the OPENAI_API_KEY before any test we want
|
18 |
+
@pytest.fixture()
|
19 |
+
def mock_settings_env_vars():
|
20 |
+
with mock.patch.dict(os.environ, clear=True):
|
21 |
+
yield
|
22 |
+
|
23 |
+
|
24 |
+
# Define some constants for our tests
|
25 |
+
_TEST_KEY: str = "TEST_KEY"
|
26 |
+
_TEST_MESSAGE: str = "Hello how are you?"
|
27 |
+
_TEST_MESSAGE_FOR_CHAT_COMPLETION: List[ChatCompletionMessageParam] = [
|
28 |
+
{
|
29 |
+
"role": "system", "content": "You are serving as a en endpoint to verify a test."
|
30 |
+
}, {
|
31 |
+
"role": "user", "content": "Respond with something to help us verify our code is working."
|
32 |
+
}
|
33 |
+
]
|
34 |
+
_TEXT_EMBEDDING_SMALL_3_LENGTH: int = 1536
|
35 |
+
|
36 |
+
|
37 |
+
def test_init_without_key(mock_settings_env_vars):
|
38 |
+
# Ensure key not set as env var and openai.api_key not set
|
39 |
+
with pytest.raises(KeyError):
|
40 |
+
os.environ[openai_handler.OpenAIHandler._ENV_KEY_NAME]
|
41 |
+
assert openai.api_key == None
|
42 |
+
# Ensure proper exception raised when instantiating handler without key as param or env var
|
43 |
+
with pytest.raises(ValueError) as excinfo:
|
44 |
+
openai_handler.OpenAIHandler()
|
45 |
+
assert f'{openai_handler.OpenAIHandler._ENV_KEY_NAME} not set' in str(excinfo.value)
|
46 |
+
assert openai.api_key == None
|
47 |
+
|
48 |
+
|
49 |
+
def test_init_with_key_as_param():
|
50 |
+
# Ensure key is set as env var, key value is unique from _TEST_KEY, and openai.api_key not set
|
51 |
+
assert not os.environ[openai_handler.OpenAIHandler._ENV_KEY_NAME] == _TEST_KEY
|
52 |
+
assert openai.api_key == None
|
53 |
+
# Ensure successful instantiation and openai.api_key properly set
|
54 |
+
handler = openai_handler.OpenAIHandler(openai_api_key=_TEST_KEY)
|
55 |
+
assert isinstance(handler, openai_handler.OpenAIHandler)
|
56 |
+
assert openai.api_key == _TEST_KEY
|
57 |
+
|
58 |
+
|
59 |
+
def test_init_with_key_as_env_var(mock_settings_env_vars):
|
60 |
+
# Ensure key not set as env var and openai.api_key not set
|
61 |
+
with pytest.raises(KeyError):
|
62 |
+
os.environ[openai_handler.OpenAIHandler._ENV_KEY_NAME]
|
63 |
+
assert openai.api_key == None
|
64 |
+
# Set key as env var
|
65 |
+
os.environ.setdefault(openai_handler.OpenAIHandler._ENV_KEY_NAME, _TEST_KEY)
|
66 |
+
# Ensure successful instantiation and openai.api_key properly set
|
67 |
+
handler = openai_handler.OpenAIHandler()
|
68 |
+
assert isinstance(handler, openai_handler.OpenAIHandler)
|
69 |
+
assert openai.api_key == _TEST_KEY
|
proto/chunk_pb2.py
ADDED
@@ -0,0 +1,30 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
3 |
+
# source: proto/chunk.proto
|
4 |
+
# Protobuf Python Version: 5.26.1
|
5 |
+
"""Generated protocol buffer code."""
|
6 |
+
from google.protobuf import descriptor as _descriptor
|
7 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
8 |
+
from google.protobuf import symbol_database as _symbol_database
|
9 |
+
from google.protobuf.internal import builder as _builder
|
10 |
+
# @@protoc_insertion_point(imports)
|
11 |
+
|
12 |
+
_sym_db = _symbol_database.Default()
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11proto/chunk.proto\x12\x06\x66inllm\"D\n\x08\x44ocument\x12\x11\n\tfile_path\x18\x01 \x01(\t\x12\x0f\n\x07\x61uthors\x18\x02 \x01(\t\x12\x14\n\x0cpublish_date\x18\x03 \x01(\t\"\xbe\x01\n\x05\x43hunk\x12\x10\n\x08\x63hunk_id\x18\x01 \x01(\t\x12\x12\n\nchunk_text\x18\x02 \x01(\t\x12%\n\nchunk_type\x18\x03 \x01(\x0e\x32\x11.finllm.ChunkType\x12\x13\n\x0b\x63hunk_index\x18\x04 \x01(\x03\x12\x19\n\x0fparent_chunk_id\x18\x05 \x01(\tH\x00\x12$\n\x08\x64ocument\x18\x06 \x01(\x0b\x32\x10.finllm.DocumentH\x00\x42\x12\n\x10parent_reference*k\n\tChunkType\x12\x16\n\x12\x43HUNK_TYPE_UNKNOWN\x10\x00\x12\x17\n\x13\x43HUNK_TYPE_SENTENCE\x10\x01\x12\x18\n\x14\x43HUNK_TYPE_PARAGRAPH\x10\x02\x12\x13\n\x0f\x43HUNK_TYPE_PAGE\x10\x03\x62\x06proto3')
|
18 |
+
|
19 |
+
_globals = globals()
|
20 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
21 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.chunk_pb2', _globals)
|
22 |
+
if not _descriptor._USE_C_DESCRIPTORS:
|
23 |
+
DESCRIPTOR._loaded_options = None
|
24 |
+
_globals['_CHUNKTYPE']._serialized_start=292
|
25 |
+
_globals['_CHUNKTYPE']._serialized_end=399
|
26 |
+
_globals['_DOCUMENT']._serialized_start=29
|
27 |
+
_globals['_DOCUMENT']._serialized_end=97
|
28 |
+
_globals['_CHUNK']._serialized_start=100
|
29 |
+
_globals['_CHUNK']._serialized_end=290
|
30 |
+
# @@protoc_insertion_point(module_scope)
|
proto/chunk_pb2.pyi
ADDED
@@ -0,0 +1,43 @@
1 |
+
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
|
2 |
+
from google.protobuf import descriptor as _descriptor
|
3 |
+
from google.protobuf import message as _message
|
4 |
+
from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union
|
5 |
+
|
6 |
+
DESCRIPTOR: _descriptor.FileDescriptor
|
7 |
+
|
8 |
+
class ChunkType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
|
9 |
+
__slots__ = ()
|
10 |
+
CHUNK_TYPE_UNKNOWN: _ClassVar[ChunkType]
|
11 |
+
CHUNK_TYPE_SENTENCE: _ClassVar[ChunkType]
|
12 |
+
CHUNK_TYPE_PARAGRAPH: _ClassVar[ChunkType]
|
13 |
+
CHUNK_TYPE_PAGE: _ClassVar[ChunkType]
|
14 |
+
CHUNK_TYPE_UNKNOWN: ChunkType
|
15 |
+
CHUNK_TYPE_SENTENCE: ChunkType
|
16 |
+
CHUNK_TYPE_PARAGRAPH: ChunkType
|
17 |
+
CHUNK_TYPE_PAGE: ChunkType
|
18 |
+
|
19 |
+
class Document(_message.Message):
|
20 |
+
__slots__ = ("file_path", "authors", "publish_date")
|
21 |
+
FILE_PATH_FIELD_NUMBER: _ClassVar[int]
|
22 |
+
AUTHORS_FIELD_NUMBER: _ClassVar[int]
|
23 |
+
PUBLISH_DATE_FIELD_NUMBER: _ClassVar[int]
|
24 |
+
file_path: str
|
25 |
+
authors: str
|
26 |
+
publish_date: str
|
27 |
+
def __init__(self, file_path: _Optional[str] = ..., authors: _Optional[str] = ..., publish_date: _Optional[str] = ...) -> None: ...
|
28 |
+
|
29 |
+
class Chunk(_message.Message):
|
30 |
+
__slots__ = ("chunk_id", "chunk_text", "chunk_type", "chunk_index", "parent_chunk_id", "document")
|
31 |
+
CHUNK_ID_FIELD_NUMBER: _ClassVar[int]
|
32 |
+
CHUNK_TEXT_FIELD_NUMBER: _ClassVar[int]
|
33 |
+
CHUNK_TYPE_FIELD_NUMBER: _ClassVar[int]
|
34 |
+
CHUNK_INDEX_FIELD_NUMBER: _ClassVar[int]
|
35 |
+
PARENT_CHUNK_ID_FIELD_NUMBER: _ClassVar[int]
|
36 |
+
DOCUMENT_FIELD_NUMBER: _ClassVar[int]
|
37 |
+
chunk_id: str
|
38 |
+
chunk_text: str
|
39 |
+
chunk_type: ChunkType
|
40 |
+
chunk_index: int
|
41 |
+
parent_chunk_id: str
|
42 |
+
document: Document
|
43 |
+
def __init__(self, chunk_id: _Optional[str] = ..., chunk_text: _Optional[str] = ..., chunk_type: _Optional[_Union[ChunkType, str]] = ..., chunk_index: _Optional[int] = ..., parent_chunk_id: _Optional[str] = ..., document: _Optional[_Union[Document, _Mapping]] = ...) -> None: ...
|
proto/entity_pb2.py
ADDED
@@ -0,0 +1,34 @@
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
3 |
+
# source: proto/entity.proto
|
4 |
+
# Protobuf Python Version: 5.26.1
|
5 |
+
"""Generated protocol buffer code."""
|
6 |
+
from google.protobuf import descriptor as _descriptor
|
7 |
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
8 |
+
from google.protobuf import symbol_database as _symbol_database
|
9 |
+
from google.protobuf.internal import builder as _builder
|
10 |
+
# @@protoc_insertion_point(imports)
|
11 |
+
|
12 |
+
_sym_db = _symbol_database.Default()
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12proto/entity.proto\x12\x06\x66inllm\"0\n\x06\x45ntity\x12\x11\n\tentity_id\x18\x01 \x01(\t\x12\x13\n\x0b\x65ntity_name\x18\x02 \x01(\t\"\x99\x01\n\x0cRelationship\x12\x17\n\x0frelationship_id\x18\x01 \x01(\t\x12\x12\n\nstart_date\x18\x02 \x01(\t\x12\x10\n\x08\x65nd_date\x18\x03 \x01(\t\x12\x13\n\x0bsource_text\x18\x04 \x01(\t\x12\x35\n\x12predicted_movement\x18\x05 \x01(\x0e\x32\x19.finllm.PredictedMovement\"\x88\x01\n\x12\x45ntityRelationship\x12#\n\x0b\x66rom_entity\x18\x01 \x01(\x0b\x32\x0e.finllm.Entity\x12*\n\x0crelationship\x18\x02 \x01(\x0b\x32\x14.finllm.Relationship\x12!\n\tto_entity\x18\x03 \x01(\x0b\x32\x0e.finllm.Entity\"P\n\x14\x45ntityKnowledgeGraph\x12\x38\n\x14\x65ntity_relationships\x18\x01 \x03(\x0b\x32\x1a.finllm.EntityRelationship*\x99\x01\n\x11PredictedMovement\x12\"\n\x1ePREDICTED_MOVEMENT_UNSPECIFIED\x10\x00\x12\x1e\n\x1aPREDICTED_MOVEMENT_NEUTRAL\x10\x01\x12\x1f\n\x1bPREDICTED_MOVEMENT_INCREASE\x10\x02\x12\x1f\n\x1bPREDICTED_MOVEMENT_DECREASE\x10\x03\x62\x06proto3')
|
18 |
+
|
19 |
+
_globals = globals()
|
20 |
+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
|
21 |
+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.entity_pb2', _globals)
|
22 |
+
if not _descriptor._USE_C_DESCRIPTORS:
|
23 |
+
DESCRIPTOR._loaded_options = None
|
24 |
+
_globals['_PREDICTEDMOVEMENT']._serialized_start=458
|
25 |
+
_globals['_PREDICTEDMOVEMENT']._serialized_end=611
|
26 |
+
_globals['_ENTITY']._serialized_start=30
|
27 |
+
_globals['_ENTITY']._serialized_end=78
|
28 |
+
_globals['_RELATIONSHIP']._serialized_start=81
|
29 |
+
_globals['_RELATIONSHIP']._serialized_end=234
|
30 |
+
_globals['_ENTITYRELATIONSHIP']._serialized_start=237
|
31 |
+
_globals['_ENTITYRELATIONSHIP']._serialized_end=373
|
32 |
+
_globals['_ENTITYKNOWLEDGEGRAPH']._serialized_start=375
|
33 |
+
_globals['_ENTITYKNOWLEDGEGRAPH']._serialized_end=455
|
34 |
+
# @@protoc_insertion_point(module_scope)
|
proto/entity_pb2.pyi
ADDED
@@ -0,0 +1,56 @@
from google.protobuf.internal import containers as _containers
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union

DESCRIPTOR: _descriptor.FileDescriptor

class PredictedMovement(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
    __slots__ = ()
    PREDICTED_MOVEMENT_UNSPECIFIED: _ClassVar[PredictedMovement]
    PREDICTED_MOVEMENT_NEUTRAL: _ClassVar[PredictedMovement]
    PREDICTED_MOVEMENT_INCREASE: _ClassVar[PredictedMovement]
    PREDICTED_MOVEMENT_DECREASE: _ClassVar[PredictedMovement]
PREDICTED_MOVEMENT_UNSPECIFIED: PredictedMovement
PREDICTED_MOVEMENT_NEUTRAL: PredictedMovement
PREDICTED_MOVEMENT_INCREASE: PredictedMovement
PREDICTED_MOVEMENT_DECREASE: PredictedMovement

class Entity(_message.Message):
    __slots__ = ("entity_id", "entity_name")
    ENTITY_ID_FIELD_NUMBER: _ClassVar[int]
    ENTITY_NAME_FIELD_NUMBER: _ClassVar[int]
    entity_id: str
    entity_name: str
    def __init__(self, entity_id: _Optional[str] = ..., entity_name: _Optional[str] = ...) -> None: ...

class Relationship(_message.Message):
    __slots__ = ("relationship_id", "start_date", "end_date", "source_text", "predicted_movement")
    RELATIONSHIP_ID_FIELD_NUMBER: _ClassVar[int]
    START_DATE_FIELD_NUMBER: _ClassVar[int]
    END_DATE_FIELD_NUMBER: _ClassVar[int]
    SOURCE_TEXT_FIELD_NUMBER: _ClassVar[int]
    PREDICTED_MOVEMENT_FIELD_NUMBER: _ClassVar[int]
    relationship_id: str
    start_date: str
    end_date: str
    source_text: str
    predicted_movement: PredictedMovement
    def __init__(self, relationship_id: _Optional[str] = ..., start_date: _Optional[str] = ..., end_date: _Optional[str] = ..., source_text: _Optional[str] = ..., predicted_movement: _Optional[_Union[PredictedMovement, str]] = ...) -> None: ...

class EntityRelationship(_message.Message):
    __slots__ = ("from_entity", "relationship", "to_entity")
    FROM_ENTITY_FIELD_NUMBER: _ClassVar[int]
    RELATIONSHIP_FIELD_NUMBER: _ClassVar[int]
    TO_ENTITY_FIELD_NUMBER: _ClassVar[int]
    from_entity: Entity
    relationship: Relationship
    to_entity: Entity
    def __init__(self, from_entity: _Optional[_Union[Entity, _Mapping]] = ..., relationship: _Optional[_Union[Relationship, _Mapping]] = ..., to_entity: _Optional[_Union[Entity, _Mapping]] = ...) -> None: ...

class EntityKnowledgeGraph(_message.Message):
    __slots__ = ("entity_relationships",)
    ENTITY_RELATIONSHIPS_FIELD_NUMBER: _ClassVar[int]
    entity_relationships: _containers.RepeatedCompositeFieldContainer[EntityRelationship]
    def __init__(self, entity_relationships: _Optional[_Iterable[_Union[EntityRelationship, _Mapping]]] = ...) -> None: ...
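These generated stubs define the wire-level message API that the domain objects in domain/entity_d.py wrap. A minimal construction sketch (illustrative only, not part of the commit; the field values are made up):

from proto.entity_pb2 import (Entity, EntityKnowledgeGraph, EntityRelationship, PredictedMovement,
                              Relationship)

# Build a single (analyst) -[predicts increase]-> (GDP) edge and wrap it in a graph message.
edge = EntityRelationship(
    from_entity=Entity(entity_id="1", entity_name="analyst A"),
    relationship=Relationship(relationship_id="2",
                              start_date="2024-01-01",
                              end_date="2024-12-31",
                              source_text="GDP is expected to rise.",
                              predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
    to_entity=Entity(entity_id="3", entity_name="GDP"))
graph = EntityKnowledgeGraph(entity_relationships=[edge])
print(graph.entity_relationships[0].relationship.predicted_movement)  # 2 (INCREASE)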
proto/pdf_document_pb2.py
ADDED
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: proto/pdf_document.proto
# Protobuf Python Version: 5.26.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x18proto/pdf_document.proto\x12\x06\x66inllm\"!\n\x0cPdfDocumentD\x12\x11\n\tfile_path\x18\x01 \x01(\tb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'proto.pdf_document_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_PDFDOCUMENTD']._serialized_start=36
  _globals['_PDFDOCUMENTD']._serialized_end=69
# @@protoc_insertion_point(module_scope)
proto/pdf_document_pb2.pyi
ADDED
@@ -0,0 +1,11 @@
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Optional as _Optional

DESCRIPTOR: _descriptor.FileDescriptor

class PdfDocumentD(_message.Message):
    __slots__ = ("file_path",)
    FILE_PATH_FIELD_NUMBER: _ClassVar[int]
    file_path: str
    def __init__(self, file_path: _Optional[str] = ...) -> None: ...
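A round-trip sketch for this message (illustrative only, not part of the commit; the file path is a made-up example):

from proto.pdf_document_pb2 import PdfDocumentD

# Construct, serialize to wire-format bytes, and parse back.
doc = PdfDocumentD(file_path="reports/outlook_2024.pdf")
data = doc.SerializeToString()
assert PdfDocumentD.FromString(data) == doc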
query_pipeline/evaluation_engine.py
ADDED
@@ -0,0 +1,219 @@
import json
from typing import Optional

from proto.entity_pb2 import PredictedMovement
from tabulate import tabulate

from domain.entity_d import (
    EntityD,
    EntityKnowledgeGraphD,
    EntityRelationshipD,
    RelationshipD,
)
from llm_handler.openai_handler import (
    ChatCompletionMessageParam,
    ChatModelVersion,
    OpenAIHandler,
)
from utils.dates import parse_date

FUZZY_MATCH_ENTITIES_PROMPT = '''
You are an expert in financial analysis. You will be given two lists of entities. Your task is to output a semantic mapping from the entities in list A to the entities in list B. This means an entity in list A that is semantically similar to an entity in list B should be mapped together. If there is no reasonable semantic match for an entity in list A, output an empty string. Output should be in the format of a JSON object. Ensure the entity keys are in the same order as the input list A.

Input:
List A: ["BofA", "Bank of America Corp", "GDP", "Inflation", "Yen"]
List B: ["Bank of America", "inflation", "Gross Domestic Product", "oil"]

Output:
{
    "BofA": "Bank of America",
    "Bank of America Corp": "Bank of America",
    "GDP": "Gross Domestic Product",
    "Inflation": "inflation",
    "Yen": ""
}
'''


class EvaluationEngine:

    _handler: OpenAIHandler
    _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
    _TEMPERATURE: float = 0.2

    def __init__(self,
                 ground_truth_kg: EntityKnowledgeGraphD,
                 openai_handler: Optional[OpenAIHandler] = None,
                 model_version: Optional[ChatModelVersion] = None):
        self._handler = openai_handler or OpenAIHandler()
        self._model_version = model_version or self._MODEL_VERSION

        # Set up an adjacency-list representation of the ground-truth knowledge graph,
        # keyed by the name of the entity each relationship points to.
        self.kg: dict[str, list[EntityRelationshipD]] = {}
        for entity_relationship in ground_truth_kg.entity_relationships:
            to_entity_name = entity_relationship.to_entity.entity_name
            relationships = self.kg.get(to_entity_name, [])
            relationships.append(entity_relationship)
            self.kg[to_entity_name] = relationships

    def _get_thesis_to_gt_entity_map(self, thesis_kg: EntityKnowledgeGraphD) -> dict[str, str]:
        thesis_entities = []
        for entity_relationship in thesis_kg.entity_relationships:
            thesis_entities.append(entity_relationship.to_entity.entity_name)

        # LLM call to return the matched entities.
        messages: list[ChatCompletionMessageParam] = [
            {
                "role": "system", "content": FUZZY_MATCH_ENTITIES_PROMPT
            }, {
                "role": "user",
                "content": f"List A: {thesis_entities}\nList B: {list(self.kg.keys())}"
            }
        ]
        completion_text = self._handler.get_chat_completion(messages=messages,
                                                            model=self._model_version,
                                                            temperature=self._TEMPERATURE,
                                                            response_format={"type": "json_object"})

        thesis_to_gt_entity_mapping: dict[str, str] = json.loads(completion_text)

        return thesis_to_gt_entity_mapping

    def _get_relationships_matching_timeperiod(
            self, gt_kg_to_node: str, relationship: RelationshipD) -> list[EntityRelationshipD]:
        matching_relationships = []
        thesis_relationship_start = parse_date(relationship.start_date)
        thesis_relationship_end = parse_date(relationship.end_date)
        for gt_relationship in self.kg[gt_kg_to_node]:
            gt_relationship_start = parse_date(gt_relationship.relationship.start_date)
            gt_relationship_end = parse_date(gt_relationship.relationship.end_date)

            if (gt_relationship_start <= thesis_relationship_start <= gt_relationship_end and
                    gt_relationship_start <= thesis_relationship_end <= gt_relationship_end):
                # The thesis relationship timeframe falls entirely within the gt relationship timeframe.
                matching_relationships.append(gt_relationship)

        return matching_relationships

    def evaluate_thesis(
            self, thesis_kg: EntityKnowledgeGraphD
    ) -> list[tuple[EntityRelationshipD, bool, Optional[EntityRelationshipD]]]:
        thesis_to_kg_map = self._get_thesis_to_gt_entity_map(thesis_kg)
        results = []
        for thesis_relationship in thesis_kg.entity_relationships:
            thesis_to_node = thesis_relationship.to_entity.entity_name
            kg_node = thesis_to_kg_map[thesis_to_node]
            if not kg_node:  # no matching entity in KG
                results.append((thesis_relationship, False, None))
                continue

            matching_relationships = self._get_relationships_matching_timeperiod(
                kg_node, thesis_relationship.relationship)

            for entity_relationship in matching_relationships:
                if entity_relationship.relationship.predicted_movement == thesis_relationship.relationship.predicted_movement:
                    results.append((thesis_relationship, True, entity_relationship))
                else:
                    results.append((thesis_relationship, False, entity_relationship))
            if len(matching_relationships) == 0:
                results.append((thesis_relationship, False, None))

        return results

    def evaluate_and_display_thesis(self, thesis_kg: EntityKnowledgeGraphD):
        results = self.evaluate_thesis(thesis_kg)

        int_to_str = {1: "Neutral", 2: 'Increase', 3: 'Decrease'}

        headers = ["Thesis Claim", "Supported by KG", "Related KG Relationship"]
        table_data = []
        for triplet in results:
            claim_entity = triplet[0].to_entity.entity_name
            claim_movement = int_to_str[triplet[0].relationship.predicted_movement]
            claim = f'{claim_entity} {claim_movement}'
            if triplet[2]:
                evidence = int_to_str[triplet[2].relationship.predicted_movement]
                evidence += f' ({triplet[2].from_entity.entity_name}) '
            else:
                evidence = "No evidence in KG"
            table_data.append([claim, triplet[1], evidence])
        return tabulate(table_data, tablefmt="html", headers=headers)


if __name__ == '__main__':
    # TODO: extract the cases into pytest tests
    kg = EntityKnowledgeGraphD(entity_relationships=[
        EntityRelationshipD(from_entity=EntityD(entity_id='3', entity_name="analyst A"),
                            relationship=RelationshipD(
                                relationship_id='2',
                                start_date='2021-01-01',
                                end_date='2024-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='5', entity_name="analyst B"),
                            relationship=RelationshipD(
                                relationship_id='3',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_DECREASE),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='7', entity_name="analyst C"),
                            relationship=RelationshipD(
                                relationship_id='4',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
                            to_entity=EntityD(entity_id='1', entity_name="GDP")),
        EntityRelationshipD(from_entity=EntityD(entity_id='9', entity_name="analyst D"),
                            relationship=RelationshipD(
                                relationship_id='5',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
                            to_entity=EntityD(entity_id='10', entity_name="USD")),
        EntityRelationshipD(  # out of time range for thesis
            from_entity=EntityD(entity_id='9', entity_name="analyst E"),
            relationship=RelationshipD(
                relationship_id='5',
                start_date='2024-01-01',
                end_date='2024-12-31',
                source_text='',
                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_NEUTRAL),
            to_entity=EntityD(entity_id='10', entity_name="USD")),
    ])

    thesis_claims = [
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="Gross Domestic Product")),
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="US$")),
        EntityRelationshipD(from_entity=EntityD(entity_id='2', entity_name="user"),
                            relationship=RelationshipD(
                                relationship_id='1',
                                start_date='2021-01-01',
                                end_date='2021-12-31',
                                source_text='',
                                predicted_movement=PredictedMovement.PREDICTED_MOVEMENT_INCREASE),
                            to_entity=EntityD(entity_id='1', entity_name="Yen")),
    ]
    thesis = EntityKnowledgeGraphD(entity_relationships=thesis_claims)

    eval_engine = EvaluationEngine(kg)

    eval_engine.evaluate_and_display_thesis(thesis)
query_pipeline/thesis_extractor.py
ADDED
@@ -0,0 +1,18 @@
from domain.chunk_d import ChunkD
from domain.entity_d import EntityKnowledgeGraphD
from extraction_pipeline.relationship_extractor.openai_relationship_extractor import (
    OpenAIRelationshipExtractor,)


class ThesisExtractor:
    _relationship_extractor: OpenAIRelationshipExtractor

    def __init__(self):
        self._relationship_extractor = OpenAIRelationshipExtractor()

    def extract_relationships(self, thesis: ChunkD) -> EntityKnowledgeGraphD:
        # Thesis chunk will have DocumentD for its parent_reference field
        # which will contain "user" as the author and more importantly
        # the time of the thesis query for use by the relationship extractor
        entity_relationships = list(self._relationship_extractor.process_element(thesis))
        return EntityKnowledgeGraphD(entity_relationships=entity_relationships)
requirements.txt
ADDED
@@ -0,0 +1,39 @@
annotated-types==0.7.0
anyio==4.4.0
certifi==2024.6.2
distro==1.9.0
exceptiongroup==1.2.1
grpcio==1.64.1
grpcio-tools==1.64.1
h11==0.14.0
httpcore==1.0.5
httpx==0.27.0
idna==3.7
importlib_metadata==7.1.0
iniconfig==2.0.0
neo4j==5.21.0
nodeenv==1.9.1
numpy==1.26.4
openai==1.33.0
packaging==24.0
pandas==2.2.2
platformdirs==4.2.2
pluggy==1.5.0
protobuf==5.27.1
pydantic==2.7.3
pydantic_core==2.18.4
PyMuPDF==1.24.5
PyMuPDFb==1.24.3
pyright==1.1.366
pytest==8.2.2
python-dateutil==2.9.0.post0
pytz==2024.1
six==1.16.0
sniffio==1.3.1
tabulate==0.9.0
tomli==2.0.1
tqdm==4.66.4
typing_extensions==4.12.2
tzdata==2024.1
yapf==0.40.2
zipp==3.19.2
storage/domain_dao.py
ADDED
@@ -0,0 +1,128 @@
from __future__ import annotations
from typing import List, Type, Protocol, TypeVar, Dict, Set
import os
import json
import uuid

from domain.domain_protocol import DomainProtocol

DomainT = TypeVar('DomainT', bound=DomainProtocol)
MAP_BIN = os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))), ".bin", "maps")


class DomainDAO(Protocol[DomainT]):

    def insert(self, domain_objs: List[DomainT]):
        ...

    def read_by_id(self, domain_id: str) -> DomainT:
        ...

    def read_all(self) -> Set[DomainT]:
        ...


class InMemDomainDAO(DomainDAO[DomainT]):

    _id_to_domain_obj: Dict[str, DomainT]

    def __init__(self):
        self._id_to_domain_obj = {}

    def insert(self, domain_objs: List[DomainT]):
        new_id_to_domain_obj = {domain_obj.id: domain_obj for domain_obj in domain_objs}
        if len(new_id_to_domain_obj) != len(domain_objs):
            raise ValueError("Duplicate IDs exist within incoming domain_objs")
        if duplicate_ids := set(new_id_to_domain_obj.keys()) & set(self._id_to_domain_obj.keys()):
            raise ValueError(f"Duplicate ids exist in DB: {duplicate_ids}")
        self._id_to_domain_obj.update(new_id_to_domain_obj)

    def read_by_id(self, domain_id: str) -> DomainT:
        if domain_obj := self._id_to_domain_obj.get(domain_id):
            return domain_obj
        raise ValueError(f"Domain obj with id {domain_id} not found")

    def read_all(self) -> Set[DomainT]:
        return set(self._id_to_domain_obj.values())

    @classmethod
    def load_from_file(cls, file_path: str, domain_cls: Type[DomainT]) -> InMemDomainDAO[DomainT]:
        if not os.path.isfile(file_path):
            raise ValueError(f"File not found: {file_path}")
        with open(file_path, 'r') as f:
            domain_objs = [domain_cls.from_json(line) for line in f]
        dao = cls()
        dao.insert(domain_objs)
        return dao

    def save_to_file(self, file_path: str):
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        domain_jsons = [domain_obj.to_json() for domain_obj in self._id_to_domain_obj.values()]
        with open(file_path, 'w') as f:
            f.write('\n'.join(domain_jsons) + '\n')


class CacheDomainDAO(DomainDAO[DomainT]):

    _id_to_domain_obj: Dict[str, DomainT]
    _save_path: str

    def __init__(self, save_path: str, domain_cls: Type[DomainT]):
        self._id_to_domain_obj = {}
        self._save_path = os.path.join(MAP_BIN, save_path)
        self._load_cache(domain_cls)

    def __enter__(self):
        return self

    def __call__(self, element: DomainT) -> DomainT:
        self.insert([element])
        return element

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._save_cache()

    def set(self, element: DomainT) -> uuid.UUID:
        id = uuid.uuid4()
        self._id_to_domain_obj[str(id)] = element
        self._save_cache()
        return id

    def _save_cache(self):
        os.makedirs(MAP_BIN, exist_ok=True)
        cache = {}
        if os.path.isfile(self._save_path):
            with open(self._save_path, 'r') as f:
                cache = json.load(f)
        domain_json_map = {
            id: domain_obj.to_json()
            for id, domain_obj in self._id_to_domain_obj.items()
        }
        cache.update(domain_json_map)
        with open(self._save_path, 'w') as f:
            json.dump(cache, f, indent=4)

    def _load_cache(self, domain_cls: Type[DomainT]):
        if not os.path.isfile(self._save_path):
            return
        with open(self._save_path, 'r') as f:
            domain_json_map = json.load(f)
        for id, domain_json in domain_json_map.items():
            self._id_to_domain_obj[id] = domain_cls.from_json(domain_json)

    def read_by_id(self, domain_id: str) -> DomainT:
        if domain_obj := self._id_to_domain_obj.get(domain_id):
            return domain_obj
        raise ValueError(f"Domain obj with id {domain_id} not found")

    def read_all(self) -> Set[DomainT]:
        return set(self._id_to_domain_obj.values())

    def insert(self, domain_objs: List[DomainT]):
        new_id_to_domain_obj = {domain_obj.id: domain_obj for domain_obj in domain_objs}
        if len(new_id_to_domain_obj) != len(domain_objs):
            raise ValueError("Duplicate IDs exist within incoming domain_objs")
        if duplicate_ids := set(new_id_to_domain_obj.keys()) & set(self._id_to_domain_obj.keys()):
            raise ValueError(f"Duplicate ids exist in DB: {duplicate_ids}")
        self._id_to_domain_obj.update(new_id_to_domain_obj)
        self._save_cache()
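A quick usage sketch for the two DAO flavours (illustrative only, not part of the commit; it reuses TimestampTestD from the tests as a stand-in for any DomainProtocol type, and the cache file name is arbitrary):

from domain.domain_protocol_test import TimestampTestD
from storage.domain_dao import CacheDomainDAO, InMemDomainDAO

# In-memory DAO: insert and read back by the object's own id.
dao = InMemDomainDAO[TimestampTestD]()
obj = TimestampTestD(nanos=1)
dao.insert([obj])
assert dao.read_by_id(obj.id) == obj

# Cache DAO: used as a context manager so the JSON cache under .bin/maps is
# flushed on exit; calling the DAO caches the element and passes it through.
with CacheDomainDAO("timestamps.json", TimestampTestD) as cache:
    cached = cache(obj)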
storage/domain_dao_test.py
ADDED
@@ -0,0 +1,71 @@
import logging
import unittest
import uuid
import os
from storage.domain_dao import InMemDomainDAO
from domain.domain_protocol_test import TimestampTestD


class InMemDomainDAOTest(unittest.TestCase):

    def setUp(self):
        self.in_mem_dao: InMemDomainDAO = InMemDomainDAO[TimestampTestD]()
        self.timestamp_d: TimestampTestD = TimestampTestD(nanos=1)

    def test_insert_domain_obj(self):
        self.in_mem_dao.insert([self.timestamp_d])
        expected_map = {self.timestamp_d.id: self.timestamp_d}
        self.assertEqual(self.in_mem_dao._id_to_domain_obj, expected_map)

    def test_insert_domain_obj_raise_on_duplicate_id_in_db(self):
        self.in_mem_dao.insert([self.timestamp_d])
        with self.assertRaises(ValueError) as context:
            self.in_mem_dao.insert([self.timestamp_d])
        self.assertIn("Duplicate ids exist in DB", str(context.exception))

    def test_insert_domain_obj_raise_on_duplicate_id_arguments(self):
        with self.assertRaises(ValueError) as context:
            self.in_mem_dao.insert([self.timestamp_d, self.timestamp_d])
        self.assertIn("Duplicate IDs exist within incoming domain_objs", str(context.exception))

    def test_read_by_id(self):
        self.in_mem_dao.insert([self.timestamp_d])
        timestamp_d = self.in_mem_dao.read_by_id(self.timestamp_d.id)
        self.assertEqual(timestamp_d, self.timestamp_d)

    def test_read_by_id_raise_not_found(self):
        with self.assertRaises(ValueError):
            self.in_mem_dao.read_by_id(self.timestamp_d.id)

    def test_read_all(self):
        timestamp_d_b = TimestampTestD(2)
        self.in_mem_dao.insert([self.timestamp_d, timestamp_d_b])
        expected_timestamps = {self.timestamp_d, timestamp_d_b}
        self.assertEqual(expected_timestamps, self.in_mem_dao.read_all())

    def test_load_from_file(self):
        file_path = f".bin/{uuid.uuid4()}.jsonl"
        with open(file_path, 'w') as f:
            f.write(self.timestamp_d.to_json() + '\n')
        dao = InMemDomainDAO[TimestampTestD].load_from_file(file_path, TimestampTestD)
        os.remove(file_path)
        self.assertEqual({self.timestamp_d}, dao.read_all())

    def test_load_from_file_fail_not_found(self):
        with self.assertRaises(ValueError):
            _ = InMemDomainDAO[TimestampTestD].load_from_file("file_path", TimestampTestD)

    def test_save_to_file(self):
        file_path = f".bin/{uuid.uuid4()}.jsonl"
        self.in_mem_dao.insert([self.timestamp_d])
        self.in_mem_dao.save_to_file(file_path)
        created_dao = InMemDomainDAO[TimestampTestD].load_from_file(file_path, TimestampTestD)
        os.remove(file_path)
        self.assertEqual(self.in_mem_dao.read_all(), created_dao.read_all())


# TODO: Add test for CacheDomainDAO

if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()
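One possible shape for the CacheDomainDAO test flagged in the TODO above (a sketch only, not part of the commit; it writes a throwaway cache file under .bin/maps and does not clean it up):

import unittest
import uuid

from domain.domain_protocol_test import TimestampTestD
from storage.domain_dao import CacheDomainDAO


class CacheDomainDAOTest(unittest.TestCase):

    def test_set_and_read_by_id(self):
        # Cache an element under a unique file name, then read it back by the generated id.
        cache_dao = CacheDomainDAO(f"{uuid.uuid4()}.json", TimestampTestD)
        timestamp_d = TimestampTestD(nanos=1)
        cache_id = cache_dao.set(timestamp_d)
        self.assertEqual(cache_dao.read_by_id(str(cache_id)), timestamp_d)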
storage/neo4j_dao.py
ADDED
@@ -0,0 +1,102 @@
import logging
import os
from typing import Optional

import neo4j

from domain.entity_d import (
    EntityD,
    EntityKnowledgeGraphD,
    EntityRelationshipD,
    RelationshipD,
)


class Neo4jError(Exception):
    ...


class Neo4jDomainDAO:
    """ To be used with a context manager to ensure the connection is closed after use. """

    def __enter__(self):
        uri = os.environ.get("NEO4J_URI", "")
        user = os.environ.get("NEO4J_USER", "")
        password = os.environ.get("NEO4J_PASSWORD", "")

        if not uri:
            raise ValueError("NEO4J_URI environment variable not set")
        if not user:
            raise ValueError("NEO4J_USER environment variable not set")
        if not password:
            raise ValueError("NEO4J_PASSWORD environment variable not set")

        try:
            self.driver = neo4j.GraphDatabase.driver(uri, auth=(user, password))
            self.driver.verify_connectivity()
        except Exception as e:
            logging.error(f"Failed to connect to Neo4j: {e}")
            raise Neo4jError("Failed to connect to Neo4j")

        return self

    def insert(self, knowledge_graph: EntityKnowledgeGraphD, pdf_file: str = ""):
        for entity_relationship in knowledge_graph.entity_relationships:
            create_cmds = entity_relationship.neo4j_create_cmds
            create_cmds_args = entity_relationship.neo4j_create_args

            for create_cmd, args in zip(create_cmds, create_cmds_args):
                args['pdf_file'] = pdf_file
                try:
                    self.driver.execute_query(
                        create_cmd,  # type: ignore
                        parameters_=args,  # type: ignore
                        database_='neo4j')  # type: ignore
                except Exception as e:
                    logging.warning(
                        f"Failed to insert entity relationship: {entity_relationship} due to {e}")

    def query(self, query, query_args):
        return self.driver.execute_query(query, parameters_=query_args,
                                         database_='neo4j')  # type: ignore

    def get_knowledge_graph(self) -> Optional[EntityKnowledgeGraphD]:
        records = []  # list[dict[str, Neo4jDict]]
        try:
            records, _, _ = self.driver.execute_query(
                "MATCH (from:Entity) -[r:Relationship]-> (to:Entity) RETURN from, properties(r), to",
                database_='neo4j')  # type: ignore
        except Exception as e:
            logging.exception(e)
            return None

        entity_relationships = []
        for record in records:
            er_dict = record.data()

            from_args = er_dict['from']
            from_entity = EntityD(entity_id='', entity_name=from_args['name'])

            to_args = er_dict['to']
            to_entity = EntityD(entity_id='', entity_name=to_args['name'])

            relationship_args = er_dict['properties(r)']
            relationship = RelationshipD(relationship_id='',
                                         start_date=relationship_args['start_date'],
                                         end_date=relationship_args['end_date'],
                                         source_text=relationship_args['source_text'],
                                         predicted_movement=RelationshipD.from_string(
                                             relationship_args['predicted_movement']))

            entity_relationships.append(
                EntityRelationshipD(from_entity=from_entity,
                                    relationship=relationship,
                                    to_entity=to_entity))

        return EntityKnowledgeGraphD(entity_relationships=entity_relationships)

    def __exit__(self, exception_type, exception_value, traceback):
        if traceback:
            logging.error("Neo4jDomainDAO error: %s | %s | %s",
                          exception_type,
                          exception_value,
                          traceback)
        self.driver.close()
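A usage sketch for writing an extracted graph into Neo4j (illustrative only, not part of the commit; it assumes the NEO4J_* environment variables are set and that the graph argument comes from the extraction pipeline):

from domain.entity_d import EntityKnowledgeGraphD
from storage.neo4j_dao import Neo4jDomainDAO


def store_graph(graph: EntityKnowledgeGraphD, source_pdf: str) -> None:
    # The context manager opens and verifies the driver on __enter__ and closes it on __exit__;
    # insert() tags every created relationship with the originating pdf file.
    with Neo4jDomainDAO() as dao:
        dao.insert(graph, pdf_file=source_pdf)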
utils/dates.py
ADDED
@@ -0,0 +1,7 @@
import datetime

DATE_FMT = "%Y-%m-%d"


def parse_date(date_str: str) -> datetime.datetime:
    return datetime.datetime.strptime(date_str, DATE_FMT)