import json
import logging
import uuid
from typing import Any, ClassVar, Dict, Iterator, List, Optional

from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

from domain.chunk_d import ChunkD, DocumentD
from domain.entity_d import (
    EntityD,
    EntityRelationshipD,
    RelationshipD,
)
from extraction_pipeline.relationship_extractor.entity_relationship_extractor import (
    RelationshipExtractor,
)
from extraction_pipeline.relationship_extractor.prompts import (
    EXTRACT_RELATIONSHIPS_PROMPT,
    NER_TAGGING_PROMPT,
)
from llm_handler.llm_interface import LLMInterface
from llm_handler.openai_handler import ChatModelVersion, OpenAIHandler

logger = logging.getLogger(__name__)


class OpenAIRelationshipExtractor(RelationshipExtractor):
    """Extract entity relationships from document chunks with an OpenAI chat model.

    Each chunk is processed with two chat completions, both requested in JSON-object
    mode:

    1. An entity-tagging pass (``NER_TAGGING_PROMPT``) over the chunk text.
    2. A relationship pass (``EXTRACT_RELATIONSHIPS_PROMPT``) that receives the
       parent document's authors and publish date, the chunk text, and the
       entities returned by pass 1.

    Every relationship returned by pass 2 is converted into an
    ``EntityRelationshipD`` with freshly generated UUIDs for the relationship and
    for both endpoint entities.
    """

    _handler: LLMInterface
    _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
    # Top-level keys read from the model's JSON responses.
    _RELATIONSHIP_KEY: ClassVar[str] = "relationships"
    _ENTITIES_KEY: ClassVar[str] = "entities"
    # Relationship type tags accepted from the model; any other tag is rejected.
    _RELATIONSHIPS_TYPES: ClassVar[List[str]] = ["PREDICTION"]
    # Sampling temperature for both passes. (Name is misspelled; kept as-is so
    # existing subclasses/overrides keep working.)
    _TEMPARATURE: ClassVar[float] = 0.2

    def __init__(self, openai_handler: Optional[LLMInterface] = None,
                 model_version: Optional[ChatModelVersion] = None):
        """Create the extractor.

        Args:
            openai_handler: LLM client used for chat completions; a new
                ``OpenAIHandler()`` is created when omitted (or falsy).
            model_version: Chat model to query; defaults to ``_MODEL_VERSION``.
        """
        self._handler = openai_handler or OpenAIHandler()
        self._model_version = model_version or self._MODEL_VERSION

    def _get_json_completion(self, messages: List[ChatCompletionMessageParam],
                             log_label: str) -> Dict[str, Any]:
        """Run a JSON-mode chat completion and return the parsed top-level object.

        Args:
            messages: Chat messages to send.
            log_label: Prefix for the INFO log line that records the raw reply.

        Raises:
            ValueError: If the reply is not valid JSON (``json.JSONDecodeError`` is a
                ``ValueError`` subclass) or its top level is not a JSON object.
        """
        completion_text = self._handler.get_chat_completion(
            messages=messages,
            model=self._model_version,
            temperature=self._TEMPARATURE,
            response_format={"type": "json_object"},
        )
        logger.info("%s: %s", log_label, completion_text)
        parsed = json.loads(completion_text)
        if not isinstance(parsed, dict):
            # JSON mode should always yield an object; fail with a clear message
            # instead of an opaque TypeError from dict(...) if it does not.
            raise ValueError(f"Expected a JSON object from the model, got: {completion_text}")
        return parsed

    def _extract_entity_names(self, chunk_text: str) -> List[Dict[str, str]]:
        """Ask the model to tag the entities mentioned in ``chunk_text``.

        Returns:
            The value under ``_ENTITIES_KEY`` in the model's reply, or ``[]`` when
            that key is missing or null.
        """
        messages: List[ChatCompletionMessageParam] = [
            {"role": "system", "content": NER_TAGGING_PROMPT},
            {"role": "user", "content": f"Input:\n{chunk_text}"},
        ]
        response = self._get_json_completion(messages, "entity extraction results")
        return response.get(self._ENTITIES_KEY) or []

    def _extract_relationships(self, chunk: ChunkD,
                               entity_nodes: List[Dict[str, str]]) -> Iterator[EntityRelationshipD]:
        """Ask the model for relationships in ``chunk`` and yield them as domain objects.

        Args:
            chunk: The chunk to analyse; its ``parent_reference`` must be a
                ``DocumentD`` (its ``authors`` and ``publish_date`` go into the prompt).
            entity_nodes: Entities from the tagging pass, embedded in the prompt.

        Yields:
            One ``EntityRelationshipD`` per relationship in the model's reply.

        Raises:
            NotImplementedError: If the chunk's parent is not a ``DocumentD``.
            ValueError: If the reply is not a JSON object, or any relationship item
                is malformed or has an unsupported type tag. Items before the bad
                one have already been yielded at that point.
        """
        parent = chunk.parent_reference
        if not isinstance(parent, DocumentD):
            raise NotImplementedError("Parent reference is not a DocumentD")
        analyst_names = parent.authors
        document_datetime = parent.publish_date

        messages: List[ChatCompletionMessageParam] = [
            {"role": "system", "content": EXTRACT_RELATIONSHIPS_PROMPT},
            {
                "role": "user",
                "content": f"Analyst: {analyst_names} \n Date: {document_datetime} \n Text Chunk: {chunk.chunk_text} \n {str(entity_nodes)}",
            },
        ]
        response = self._get_json_completion(messages, "relationship results")
        relationships: List[Dict[str, Dict[str, str]]] = response.get(self._RELATIONSHIP_KEY) or []

        for extracted_relationship in relationships:
            yield self._to_entity_relationship(extracted_relationship, chunk.chunk_text)

    def _to_entity_relationship(self, extracted_relationship: Any,
                                source_text: str) -> EntityRelationshipD:
        """Convert one model-emitted relationship item into an ``EntityRelationshipD``.

        The item is expected to be a dict whose first key is a type tag from
        ``_RELATIONSHIPS_TYPES`` and whose value is a dict of attributes
        (``start_date``, ``end_date``, ``predicted_movement``, ``from_entity``,
        ``to_entity``). Missing attributes default to ``""``.

        Raises:
            ValueError: If the item is not a non-empty dict, its type tag is not
                supported, or its attribute payload is not a dict.
        """
        attributes: Any = None
        if isinstance(extracted_relationship, dict) and extracted_relationship:
            type_tag = next(iter(extracted_relationship))
            if type_tag in self._RELATIONSHIPS_TYPES:
                attributes = extracted_relationship[type_tag]
        if not isinstance(attributes, dict):
            raise ValueError(f"No valid relationships in {extracted_relationship}")

        relationship_d = RelationshipD(
            relationship_id=str(uuid.uuid4()),
            start_date=attributes.get("start_date", ""),
            end_date=attributes.get("end_date", ""),
            source_text=source_text,
            predicted_movement=RelationshipD.from_string(attributes.get("predicted_movement", "")),
        )
        # Endpoint entities get fresh IDs per relationship; no de-duplication
        # against other relationships happens here.
        from_entity = EntityD(entity_id=str(uuid.uuid4()),
                              entity_name=attributes.get("from_entity", ""))
        to_entity = EntityD(entity_id=str(uuid.uuid4()),
                            entity_name=attributes.get("to_entity", ""))
        return EntityRelationshipD(relationship=relationship_d,
                                   from_entity=from_entity,
                                   to_entity=to_entity)

    def _process_element(self, element: ChunkD) -> Iterator[EntityRelationshipD]:
        """Run the entity-tagging pass, then the relationship pass, for one chunk."""
        entities_text: List[Dict[str, str]] = self._extract_entity_names(element.chunk_text)
        yield from self._extract_relationships(element, entities_text)