# Gateston Johns
# first real commit
# 9041389
import json
import logging
import uuid
from typing import ClassVar, Dict, Iterator, List, Optional
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from domain.chunk_d import (ChunkD, DocumentD)
from domain.entity_d import (
EntityD,
EntityRelationshipD,
RelationshipD,
)
from extraction_pipeline.relationship_extractor.entity_relationship_extractor import (
RelationshipExtractor,)
from extraction_pipeline.relationship_extractor.prompts import (
EXTRACT_RELATIONSHIPS_PROMPT,
NER_TAGGING_PROMPT,
)
from llm_handler.llm_interface import LLMInterface
from llm_handler.openai_handler import ChatModelVersion, OpenAIHandler
class OpenAIRelationshipExtractor(RelationshipExtractor):
    """Extract entity relationships from text chunks through a chat LLM.

    Each chunk is processed with two JSON-mode chat completions:

    1. ``NER_TAGGING_PROMPT`` is sent with the chunk text and the value under
       the ``"entities"`` key of the reply is collected.
    2. ``EXTRACT_RELATIONSHIPS_PROMPT`` is sent with the parent document's
       authors and publish date, the chunk text and the collected entities;
       every item under the ``"relationships"`` key of the reply is turned
       into an ``EntityRelationshipD``.
    """

    _handler: LLMInterface
    # Model used when the constructor is not given an explicit version.
    _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
    # Top-level keys read from the JSON objects returned by the model.
    _RELATIONSHIP_KEY: ClassVar[str] = "relationships"
    _ENTITIES_KEY: ClassVar[str] = "entities"
    # Relationship types this extractor knows how to convert.
    _RELATIONSHIPS_TYPES: ClassVar[List[str]] = ["PREDICTION"]
    # Sampling temperature for both completions. The attribute name is
    # misspelt (sic); it is kept so existing references keep working.
    _TEMPARATURE: ClassVar[float] = 0.2

    def __init__(self,
                 openai_handler: Optional[LLMInterface] = None,
                 model_version: Optional[ChatModelVersion] = None):
        """Create the extractor.

        Args:
            openai_handler: LLM handler to use; a new ``OpenAIHandler`` is
                created when omitted.
            model_version: Chat model to query; falls back to
                ``_MODEL_VERSION`` when omitted.
        """
        self._handler = openai_handler or OpenAIHandler()
        self._model_version = model_version or self._MODEL_VERSION

    def _get_json_completion(self, messages: List[ChatCompletionMessageParam],
                             log_label: str) -> Dict:
        """Run one JSON-mode chat completion and return the decoded object.

        Args:
            messages: Chat messages to send to the handler.
            log_label: Prefix for the INFO log line carrying the raw reply.

        Returns:
            The reply parsed with ``json.loads`` and converted to a ``dict``.
        """
        completion_text = self._handler.get_chat_completion(
            messages=messages,
            model=self._model_version,
            temperature=self._TEMPARATURE,
            response_format={"type": "json_object"})
        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        logging.info("%s: %s", log_label, completion_text)
        return dict(json.loads(completion_text))

    def _extract_entity_names(self, chunk_text: str) -> List[Dict[str, str]]:
        """Ask the model to tag the entities mentioned in ``chunk_text``.

        Returns:
            The value stored under ``"entities"`` in the model's JSON reply,
            or an empty list when that key is absent.
        """
        messages: List[ChatCompletionMessageParam] = [
            {
                "role": "system",
                "content": NER_TAGGING_PROMPT
            },
            {
                "role": "user",
                "content": f"Input:\n{chunk_text}"
            },
        ]
        payload = self._get_json_completion(messages, "entity extraction results")
        return payload.get(self._ENTITIES_KEY, [])

    def _extract_relationships(self, chunk: ChunkD,
                               entity_nodes: List[Dict[str, str]]) -> Iterator[EntityRelationshipD]:
        """Yield the relationships the model extracts from ``chunk``.

        Args:
            chunk: Chunk to analyse; its ``parent_reference`` must be a
                ``DocumentD`` (its ``authors`` and ``publish_date`` are sent
                to the model).
            entity_nodes: Entities previously tagged for this chunk; their
                ``str()`` form is appended to the user message.

        Yields:
            One ``EntityRelationshipD`` per item under ``"relationships"`` in
            the model's reply. Both entities of every relationship receive a
            freshly generated UUID.

        Raises:
            NotImplementedError: If the chunk's parent is not a ``DocumentD``.
            ValueError: If a relationship item is not a non-empty mapping, its
                first key is not one of ``_RELATIONSHIPS_TYPES``, or its
                attribute payload is not a mapping.

        Because this is a generator, nothing runs (and nothing is raised)
        until the result is iterated.
        """
        if not isinstance(chunk.parent_reference, DocumentD):
            raise NotImplementedError("Parent reference is not a DocumentD")
        analyst_names: str = chunk.parent_reference.authors
        document_datetime: str = chunk.parent_reference.publish_date

        messages: List[ChatCompletionMessageParam] = [
            {
                "role": "system",
                "content": EXTRACT_RELATIONSHIPS_PROMPT
            },
            {
                "role":
                    "user",
                "content":
                    f"Analyst: {analyst_names} \n Date: {document_datetime} \n Text Chunk: {chunk.chunk_text} \n {str(entity_nodes)}"
            },
        ]
        payload = self._get_json_completion(messages, "relationship results")
        relationships: List[Dict[str, Dict[str, str]]] = payload.get(self._RELATIONSHIP_KEY, [])

        for extracted_relationship in relationships:
            # Each item is expected to look like {TYPE: {attributes}}; only its
            # first key is inspected. Empty or non-mapping items are reported
            # with the same ValueError as unsupported types instead of leaking
            # an IndexError/AttributeError.
            key: Optional[str] = None
            if isinstance(extracted_relationship, dict):
                key = next(iter(extracted_relationship), None)
            if key not in self._RELATIONSHIPS_TYPES:
                raise ValueError(f"No valid relationships in {extracted_relationship}")

            relationship_kw_attr: Dict[str, str] = extracted_relationship[key]
            if not isinstance(relationship_kw_attr, dict):
                raise ValueError(f"No valid relationships in {extracted_relationship}")

            relationship_d = RelationshipD(
                relationship_id=str(uuid.uuid4()),
                start_date=relationship_kw_attr.get("start_date", ""),
                end_date=relationship_kw_attr.get("end_date", ""),
                source_text=chunk.chunk_text,
                predicted_movement=RelationshipD.from_string(
                    relationship_kw_attr.get("predicted_movement", "")))
            entity_l_d = EntityD(entity_id=str(uuid.uuid4()),
                                 entity_name=relationship_kw_attr.get("from_entity", ""))
            entity_r_d = EntityD(entity_id=str(uuid.uuid4()),
                                 entity_name=relationship_kw_attr.get("to_entity", ""))

            yield EntityRelationshipD(relationship=relationship_d,
                                      from_entity=entity_l_d,
                                      to_entity=entity_r_d)

    def _process_element(self, element: ChunkD) -> Iterator[EntityRelationshipD]:
        """Tag the entities of ``element`` and yield its extracted relationships."""
        entities_text: List[Dict[str, str]] = self._extract_entity_names(element.chunk_text)
        yield from self._extract_relationships(element, entities_text)