# Last commit by: Gateston Johns
# Commit: 9041389 ("first real commit"); file size 3.17 kB.
import dataclasses
import json
from typing import ClassVar, Dict, Iterable, Iterator, List, Optional, Tuple, Union
import pymupdf
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from domain.chunk_d import DocumentD
from extraction_pipeline.document_metadata_extractor.document_metadata_extractor import (
DocumentMetadataExtractor,)
from extraction_pipeline.document_metadata_extractor.prompts import (
DOCUMENT_METADATA_PROMPT,)
from llm_handler.llm_interface import LLMInterface
from llm_handler.openai_handler import ChatModelVersion, OpenAIHandler
from utils.dates import parse_date
class CreationDateError(Exception):
    """Signals that a document's creation (publish) date is missing or cannot be parsed."""
class AuthorsError(Exception):
    """Signals that a document's authors are missing or not in a usable form."""
class OpenAIDocumentMetadataExtractor(DocumentMetadataExtractor):
    """Extract authors and publish date from a PDF's first page using a chat LLM.

    The text of the first PDF page is sent to the model together with
    ``DOCUMENT_METADATA_PROMPT``. A JSON-object response is requested, and its
    ``authors`` / ``publish_date`` entries are validated and copied onto the
    incoming ``DocumentD``.
    """

    _handler: LLMInterface
    _model_version: ChatModelVersion
    _MODEL_VERSION: ClassVar[ChatModelVersion] = ChatModelVersion.GPT_4_O
    _AUTHORS_KEY: ClassVar[str] = "authors"
    _PUBLISH_DATE_KEY: ClassVar[str] = "publish_date"
    # Sampling temperature for the chat completion request. The attribute name
    # is misspelled but kept as-is because subclasses may reference or override it.
    _TEMPARATURE: ClassVar[float] = 0.2

    def __init__(self,
                 openai_handler: Optional[LLMInterface] = None,
                 model_version: Optional[ChatModelVersion] = None):
        """Create the extractor.

        Args:
            openai_handler: LLM client used for chat completions. A new
                ``OpenAIHandler`` is constructed when omitted.
            model_version: Chat model to query. Defaults to ``_MODEL_VERSION``.
        """
        # Compare against None explicitly so a falsy-but-valid injected object
        # is not silently replaced by the default.
        self._handler = openai_handler if openai_handler is not None else OpenAIHandler()
        self._model_version = (model_version
                               if model_version is not None else self._MODEL_VERSION)

    @staticmethod
    def _format_authors(authors: Union[str, List[str]]) -> str:
        """Render the model's ``authors`` value as a comma-separated string.

        A bare string is taken as already formatted. Joining it would split it
        character by character. For lists, blank entries are dropped and
        surrounding whitespace is stripped from each name.
        """
        if isinstance(authors, str):
            return authors.strip()
        return ", ".join(name.strip() for name in authors if name.strip())

    def _validate_text(self, completion_text: Dict[str, Union[str, List[str]]]) -> None:
        """Check that the decoded model response holds usable authors and a publish date.

        Raises:
            AuthorsError: if ``authors`` is missing or empty, is neither a string
                nor a list of strings, or contains only blank names.
            CreationDateError: if ``publish_date`` is missing or empty, or
                ``parse_date`` rejects it with ``ValueError``.
        """
        authors = completion_text.get(self._AUTHORS_KEY)
        if not authors:
            raise AuthorsError("No authors found.")
        if isinstance(authors, list):
            if not all(isinstance(name, str) for name in authors):
                raise AuthorsError(f"Authors must be strings, got: {authors!r}")
        elif not isinstance(authors, str):
            raise AuthorsError(f"Unexpected authors value: {authors!r}")
        if not self._format_authors(authors):
            raise AuthorsError("No authors found.")

        publish_date = completion_text.get(self._PUBLISH_DATE_KEY)
        if not publish_date:
            raise CreationDateError("No creation date found.")
        publish_date_str = str(publish_date)
        try:
            parse_date(publish_date_str)
        except ValueError as e:
            raise CreationDateError(
                f"Failed to parse publish date '{publish_date_str}': {e}") from e

    @staticmethod
    def _read_first_page_text(file_path) -> str:
        """Return the plain text of the first page of the PDF at *file_path*.

        Raises:
            ValueError: if the document has no pages.
        """
        # The context manager closes the document instead of leaking the handle.
        with pymupdf.open(file_path) as pdf_document:
            first_page = next(iter(pdf_document.pages()), None)
            if first_page is None:
                # A bare next() here would raise StopIteration. Inside the
                # calling generator, PEP 479 turns that into an opaque RuntimeError.
                raise ValueError(f"PDF has no pages: {file_path}")
            return first_page.get_text()

    def _request_metadata(self, first_page_text: str) -> Dict[str, Union[str, List[str]]]:
        """Ask the chat model for document metadata and decode its JSON reply.

        Raises:
            ValueError: if the reply is not valid JSON (``json.JSONDecodeError``
                subclasses ``ValueError``) or is valid JSON but not an object.
        """
        messages: List[ChatCompletionMessageParam] = [
            {"role": "system", "content": DOCUMENT_METADATA_PROMPT},
            {"role": "user", "content": f"Input:\n{first_page_text}"},
        ]
        completion_text_raw = self._handler.get_chat_completion(
            messages=messages,
            model=self._model_version,
            temperature=self._TEMPARATURE,
            response_format={"type": "json_object"})
        completion = json.loads(completion_text_raw)
        # dict(...) on a JSON array can build a bogus mapping or fail with a
        # confusing message, so require an object explicitly.
        if not isinstance(completion, dict):
            raise ValueError(
                f"Expected a JSON object from the model, got {type(completion).__name__}")
        return completion

    def _process_element(self, element: DocumentD) -> Iterable[DocumentD]:
        """Yield one copy of *element* with ``authors`` and ``publish_date`` filled in.

        Raises:
            ValueError: if the PDF has no pages or the model reply is not a JSON object.
            AuthorsError: if the reply lacks usable authors.
            CreationDateError: if the reply lacks a parseable publish date.
        """
        first_page_text = self._read_first_page_text(element.file_path)
        completion_text = self._request_metadata(first_page_text)
        self._validate_text(completion_text)
        # Both keys are guaranteed present and non-empty by _validate_text.
        authors = self._format_authors(completion_text[self._AUTHORS_KEY])
        publish_date = str(completion_text[self._PUBLISH_DATE_KEY])
        yield dataclasses.replace(element, authors=authors, publish_date=publish_date)