Spaces:
Sleeping
Sleeping
import json | |
import logging | |
import uuid | |
from typing import ClassVar, Dict, Iterator, List, Optional | |
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam | |
from domain.chunk_d import (ChunkD, DocumentD) | |
from domain.entity_d import ( | |
EntityD, | |
EntityRelationshipD, | |
RelationshipD, | |
) | |
from extraction_pipeline.relationship_extractor.entity_relationship_extractor import ( | |
RelationshipExtractor,) | |
from extraction_pipeline.relationship_extractor.prompts import ( | |
EXTRACT_RELATIONSHIPS_PROMPT, | |
NER_TAGGING_PROMPT, | |
) | |
from llm_handler.llm_interface import LLMInterface | |
from llm_handler.openai_handler import ChatModelVersion, OpenAIHandler | |
class OpenAIRelationshipExtractor(RelationshipExtractor):
    """Extract entity relationships from text chunks with an OpenAI chat model.

    Processing one chunk takes two chat-completion calls: the first tags the
    entities mentioned in the chunk text, the second receives those entities
    together with the chunk and returns the relationships between them. Both
    calls request a JSON-object response, which is parsed with ``json.loads``.
    """

    _handler: LLMInterface
    # Model used when the caller does not pass ``model_version``.
    _MODEL_VERSION: ChatModelVersion = ChatModelVersion.GPT_4_O
    # Top-level keys read from the JSON objects returned by the model.
    _RELATIONSHIP_KEY: ClassVar[str] = "relationships"
    _ENTITIES_KEY: ClassVar[str] = "entities"
    # Relationship kinds that are accepted; anything else raises ValueError.
    _RELATIONSHIPS_TYPES: ClassVar[List[str]] = ["PREDICTION"]
    # NOTE: "temperature" is misspelled; the name is kept as-is so any
    # existing reference to this attribute keeps working.
    _TEMPARATURE: ClassVar[float] = 0.2

    def __init__(self,
                 openai_handler: Optional[LLMInterface] = None,
                 model_version: Optional[ChatModelVersion] = None):
        """Create an extractor.

        Args:
            openai_handler: LLM handler used for chat completions. A new
                ``OpenAIHandler`` is created when omitted (or falsy).
            model_version: Chat model to query. Falls back to
                ``_MODEL_VERSION`` when omitted (or falsy).
        """
        self._handler = openai_handler or OpenAIHandler()
        self._model_version = model_version or self._MODEL_VERSION

    def _extract_entity_names(self, chunk_text: str) -> List[Dict[str, str]]:
        """Ask the model to tag the entities mentioned in ``chunk_text``.

        Args:
            chunk_text: Raw text of the chunk to analyse.

        Returns:
            The value stored under the ``"entities"`` key of the model's JSON
            response, or an empty list when that key is absent.
        """
        messages: List[ChatCompletionMessageParam] = [
            {
                "role": "system",
                "content": NER_TAGGING_PROMPT,
            },
            {
                "role": "user",
                "content": f"Input:\n{chunk_text}",
            },
        ]
        completion_text = self._handler.get_chat_completion(
            messages=messages,
            model=self._model_version,
            temperature=self._TEMPARATURE,
            response_format={"type": "json_object"})
        logging.info("entity extraction results: %s", completion_text)
        return dict(json.loads(completion_text)).get(self._ENTITIES_KEY, [])

    def _extract_relationships(self, chunk: ChunkD,
                               entity_nodes: List[Dict[str, str]]) -> Iterator[EntityRelationshipD]:
        """Ask the model for relationships in ``chunk`` and yield them.

        Args:
            chunk: Chunk to analyse. Its ``parent_reference`` must be a
                ``DocumentD``; the document's ``authors`` and ``publish_date``
                are sent to the model together with the chunk text.
            entity_nodes: Entities previously tagged for this chunk; their
                ``str()`` form is appended to the user message.

        Yields:
            One ``EntityRelationshipD`` per item found under the
            ``"relationships"`` key of the model's JSON response. The
            relationship and both entities receive freshly generated UUIDs.

        Raises:
            NotImplementedError: If ``chunk.parent_reference`` is not a
                ``DocumentD``.
            ValueError: If a returned relationship item is empty or its first
                key is not one of ``_RELATIONSHIPS_TYPES``.
        """
        if not isinstance(chunk.parent_reference, DocumentD):
            raise NotImplementedError("Parent reference is not a DocumentD")
        analyst_names: str = chunk.parent_reference.authors
        document_datetime: str = chunk.parent_reference.publish_date

        messages: List[ChatCompletionMessageParam] = [
            {
                "role": "system",
                "content": EXTRACT_RELATIONSHIPS_PROMPT,
            },
            {
                "role": "user",
                "content":
                    f"Analyst: {analyst_names} \n Date: {document_datetime} \n Text Chunk: {chunk.chunk_text} \n {str(entity_nodes)}",
            },
        ]
        completion_text = self._handler.get_chat_completion(
            messages=messages,
            model=self._model_version,
            temperature=self._TEMPARATURE,
            response_format={"type": "json_object"})
        logging.info("relationship results: %s", completion_text)

        # Keep the parsed payload in its own name instead of rebinding the
        # raw JSON string to a dict.
        parsed_completion = dict(json.loads(completion_text))
        relationships: List[Dict[str, Dict[str, str]]] = parsed_completion.get(
            self._RELATIONSHIP_KEY, [])

        for extracted_relationship in relationships:
            # Only the first key of each item is inspected. An empty item
            # yields None here, so it fails with the same ValueError as an
            # unknown relationship type rather than a bare IndexError.
            key = next(iter(extracted_relationship), None)
            if key not in self._RELATIONSHIPS_TYPES:
                raise ValueError(f"No valid relationships in {extracted_relationship}")

            relationship_kw_attr: Dict[str, str] = extracted_relationship[key]
            relationship_d = RelationshipD(
                relationship_id=str(uuid.uuid4()),
                start_date=relationship_kw_attr.get("start_date", ""),
                end_date=relationship_kw_attr.get("end_date", ""),
                source_text=chunk.chunk_text,
                predicted_movement=RelationshipD.from_string(
                    relationship_kw_attr.get("predicted_movement", "")))
            entity_l_d = EntityD(entity_id=str(uuid.uuid4()),
                                 entity_name=relationship_kw_attr.get("from_entity", ""))
            entity_r_d = EntityD(entity_id=str(uuid.uuid4()),
                                 entity_name=relationship_kw_attr.get("to_entity", ""))

            yield EntityRelationshipD(relationship=relationship_d,
                                      from_entity=entity_l_d,
                                      to_entity=entity_r_d)

    def _process_element(self, element: ChunkD) -> Iterator[EntityRelationshipD]:
        """Tag the entities in ``element`` and yield the relationships found.

        Args:
            element: Chunk to process.

        Yields:
            Every ``EntityRelationshipD`` produced by
            ``_extract_relationships`` for this chunk.
        """
        entities_text: List[Dict[str, str]] = self._extract_entity_names(element.chunk_text)
        yield from self._extract_relationships(element, entities_text)