File size: 5,332 Bytes
9041389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import json
import logging
import uuid
from typing import ClassVar, Dict, Iterator, List, Optional

from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

from domain.chunk_d import (ChunkD, DocumentD)
from domain.entity_d import (
    EntityD,
    EntityRelationshipD,
    RelationshipD,
)
from extraction_pipeline.relationship_extractor.entity_relationship_extractor import (
    RelationshipExtractor,)
from extraction_pipeline.relationship_extractor.prompts import (
    EXTRACT_RELATIONSHIPS_PROMPT,
    NER_TAGGING_PROMPT,
)
from llm_handler.llm_interface import LLMInterface
from llm_handler.openai_handler import ChatModelVersion, OpenAIHandler


class OpenAIRelationshipExtractor(RelationshipExtractor):
    """Extract entity relationships from document chunks with an OpenAI chat model.

    Each chunk is processed with two chat-completion calls, both requesting a
    JSON-object response:

    1. ``NER_TAGGING_PROMPT`` tags the entities mentioned in the chunk text.
    2. ``EXTRACT_RELATIONSHIPS_PROMPT`` receives the chunk text, the parent
       document's authors and publish date, and the tagged entities, and
       returns typed relationships that are converted to ``EntityRelationshipD``.
    """

    _handler: LLMInterface
    _MODEL_VERSION: ClassVar[ChatModelVersion] = ChatModelVersion.GPT_4_O
    # Top-level keys expected in the model's JSON responses.
    _RELATIONSHIP_KEY: ClassVar[str] = "relationships"
    _ENTITIES_KEY: ClassVar[str] = "entities"
    # Relationship type keys this extractor knows how to convert.
    _RELATIONSHIPS_TYPES: ClassVar[List[str]] = ["PREDICTION"]
    # NOTE: the name is misspelled ("TEMPARATURE"); it is kept unchanged because
    # subclasses or callers may reference it.
    _TEMPARATURE: ClassVar[float] = 0.2

    def __init__(self,
                 openai_handler: Optional[LLMInterface] = None,
                 model_version: Optional[ChatModelVersion] = None):
        """Create the extractor.

        Args:
            openai_handler: Chat-completion backend; a new ``OpenAIHandler`` is
                created when omitted (or falsy).
            model_version: Model to query; defaults to ``_MODEL_VERSION``.
        """
        self._handler = openai_handler or OpenAIHandler()
        self._model_version = model_version or self._MODEL_VERSION

    def _get_json_completion(self, messages: List[ChatCompletionMessageParam],
                             log_label: str) -> dict:
        """Request a JSON-object completion, log the raw text, and parse it.

        Args:
            messages: Chat messages to send to the model.
            log_label: Prefix of the INFO log line that records the raw response.

        Returns:
            The parsed top-level JSON object.

        Raises:
            ValueError: If the response is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass) or is not a JSON object.
        """
        completion_text = self._handler.get_chat_completion(messages=messages,
                                                            model=self._model_version,
                                                            temperature=self._TEMPARATURE,
                                                            response_format={"type": "json_object"})
        logging.info("%s: %s", log_label, completion_text)
        parsed = json.loads(completion_text)
        if not isinstance(parsed, dict):
            raise ValueError(f"Expected a JSON object from the model, got: {completion_text}")
        return parsed

    def _extract_entity_names(self, chunk_text: str) -> List[Dict[str, str]]:
        """Tag the entities mentioned in ``chunk_text``.

        Returns:
            The value under the ``"entities"`` key of the model's JSON response,
            or an empty list when that key is missing or null.
        """
        messages: List[ChatCompletionMessageParam] = [
            {"role": "system", "content": NER_TAGGING_PROMPT},
            {"role": "user", "content": f"Input:\n{chunk_text}"},
        ]
        response = self._get_json_completion(messages, "entity extraction results")
        # `or []` also covers an explicit JSON null, which `.get(key, [])` would
        # pass through as None.
        return response.get(self._ENTITIES_KEY) or []

    def _relationship_attributes(self, extracted_relationship: object) -> Dict[str, str]:
        """Return the attribute dict of one ``{TYPE: {...attributes}}`` entry.

        Only the entry's first key is treated as its relationship type.

        Raises:
            ValueError: If the entry is not a non-empty dict, its type is not in
                ``_RELATIONSHIPS_TYPES``, or its attributes are not a dict.
        """
        if not isinstance(extracted_relationship, dict) or not extracted_relationship:
            raise ValueError(f"No valid relationships in {extracted_relationship}")
        relationship_type = next(iter(extracted_relationship))
        if relationship_type not in self._RELATIONSHIPS_TYPES:
            raise ValueError(f"No valid relationships in {extracted_relationship}")
        relationship_kw_attr = extracted_relationship[relationship_type]
        if not isinstance(relationship_kw_attr, dict):
            raise ValueError(f"No valid relationships in {extracted_relationship}")
        return relationship_kw_attr

    def _extract_relationships(self, chunk: ChunkD,
                               entity_nodes: List[Dict[str, str]]) -> Iterator[EntityRelationshipD]:
        """Ask the model for the relationships in ``chunk`` and yield them.

        Args:
            chunk: Chunk to analyse. Its ``parent_reference`` must be a
                ``DocumentD``, whose ``authors`` and ``publish_date`` are put
                into the prompt.
            entity_nodes: Entities previously tagged in the chunk; included in
                the prompt via ``str()``.

        Yields:
            One ``EntityRelationshipD`` per relationship entry in the response.
            The relationship and both entities get fresh UUID4 ids.

        Raises:
            NotImplementedError: If ``chunk.parent_reference`` is not a ``DocumentD``.
            ValueError: If the response is not a JSON object or a relationship
                entry is malformed or of an unsupported type. Entries preceding
                the bad one have already been yielded by then.
        """
        parent = chunk.parent_reference
        if not isinstance(parent, DocumentD):
            raise NotImplementedError("Parent reference is not a DocumentD")
        analyst_names = parent.authors
        document_datetime = parent.publish_date

        messages: List[ChatCompletionMessageParam] = [
            {"role": "system", "content": EXTRACT_RELATIONSHIPS_PROMPT},
            {
                "role": "user",
                "content": f"Analyst: {analyst_names} \n Date: {document_datetime} \n Text Chunk: {chunk.chunk_text} \n {str(entity_nodes)}",
            },
        ]
        response = self._get_json_completion(messages, "relationship results")
        # `or []` treats an explicit JSON null like a missing key instead of
        # failing when iterated.
        relationships = response.get(self._RELATIONSHIP_KEY) or []

        for extracted_relationship in relationships:
            relationship_kw_attr = self._relationship_attributes(extracted_relationship)
            relationship_d = RelationshipD(
                relationship_id=str(uuid.uuid4()),
                start_date=relationship_kw_attr.get("start_date", ""),
                end_date=relationship_kw_attr.get("end_date", ""),
                source_text=chunk.chunk_text,
                predicted_movement=RelationshipD.from_string(
                    relationship_kw_attr.get("predicted_movement", "")))

            entity_l_d = EntityD(entity_id=str(uuid.uuid4()),
                                 entity_name=relationship_kw_attr.get("from_entity", ""))

            entity_r_d = EntityD(entity_id=str(uuid.uuid4()),
                                 entity_name=relationship_kw_attr.get("to_entity", ""))

            yield EntityRelationshipD(relationship=relationship_d,
                                      from_entity=entity_l_d,
                                      to_entity=entity_r_d)

    def _process_element(self, element: ChunkD) -> Iterator[EntityRelationshipD]:
        """Tag entities in ``element`` and yield the relationships found among them."""
        entities_text: List[Dict[str, str]] = self._extract_entity_names(element.chunk_text)
        yield from self._extract_relationships(element, entities_text)