project/app/parsers/splitter.py
import hashlib
import urllib.parse
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Tuple
import pandas as pd
from loguru import logger
from app.config.models.configs import Document, Config
from app.parsers.markdown import markdown_splitter
HASH_BLOCKSIZE = 65536
class DocumentSplitter:
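    """Splits configured document collections into chunked Document objects.

    For every configured document path and every configured chunk size, the
    markdown files found under that path are split into passages, and
    hash-to-filename and hash-to-document-id mappings are collected alongside
    the resulting documents.
    """
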
def __init__(self, config: Config) -> None:
self.document_path_settings = config.embeddings.document_settings
self.chunk_sizes = config.embeddings.chunk_sizes
def split(
self,
        limit: Optional[int] = None,
) -> Tuple[List[Document], pd.DataFrame, pd.DataFrame]:
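        """Split all configured documents into chunks.

        Returns a tuple of (documents, filename-to-hash mappings as a
        DataFrame, document-id-to-hash mappings as a DataFrame). If ``limit``
        is given, each of the three results is truncated to that many entries.
        """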
all_docs = []
hash_filename_mappings = []
hash_docid_mappings = []
for setting in self.document_path_settings:
passage_prefix = setting.passage_prefix
docs_path = Path(setting.doc_path)
extension = "md"
for chunk_size in self.chunk_sizes:
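                # Recursively collect all markdown files under the configured path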
                paths = list(docs_path.glob(f"**/*.{extension}"))
additional_parser_settings = setting.additional_parser_settings.get(
extension, dict()
)
(
docs,
hf_mappings,
hd_mappings,
) = self._get_documents_from_custom_splitter(
document_paths=paths,
splitter_func=markdown_splitter,
max_size=chunk_size,
passage_prefix=passage_prefix,
**additional_parser_settings,
)
all_docs.extend(docs)
hash_filename_mappings.extend(hf_mappings)
hash_docid_mappings.extend(hd_mappings)
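
        # Combine the per-file mappings into single DataFrames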
all_hash_filename_mappings = pd.DataFrame(hash_filename_mappings)
all_hash_docid_mappings = pd.concat(hash_docid_mappings, axis=0)
if limit:
all_docs = all_docs[:limit]
all_hash_filename_mappings = all_hash_filename_mappings[:limit]
all_hash_docid_mappings = all_hash_docid_mappings[:limit]
return all_docs, all_hash_filename_mappings, all_hash_docid_mappings
def _get_documents_from_custom_splitter(
self,
document_paths: List[Path],
        splitter_func: Callable,
        max_size: int,
passage_prefix: str,
**additional_kwargs,
) -> Tuple[List[Document], List[dict], List[pd.DataFrame]]:
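        """Split each markdown file with ``splitter_func`` and wrap the chunks.

        Returns the generated ``Document`` objects, a list of
        ``{filename, filehash}`` dicts, and a list of per-file DataFrames
        mapping chunk document ids to the source file hash.
        """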
all_docs = []
hash_filename_mappings = []
hash_docid_mappings = []
for path in document_paths:
filepath = str(path)
            filename = path.stem
if path.suffix != ".md":
continue
additional_kwargs.update({"filename": filepath})
docs_data = splitter_func(path, max_size, **additional_kwargs)
file_hash = get_md5_hash(path)
            # Percent-encode the path before storing it as the document source
            source = urllib.parse.quote(str(path))
            logger.info(source)
docs = [
Document(
page_content=passage_prefix + d["text"],
                    metadata={
                        **d["metadata"],
                        "source": source,
                        "chunk_size": max_size,
                        "document_id": str(uuid.uuid1()),
                        "label": filename,
                    },
)
for d in docs_data
]
            # Replace a missing page number with -1 so the metadata value is never None
            for d in docs:
                if "page" in d.metadata and d.metadata["page"] is None:
                    d.metadata["page"] = -1
all_docs.extend(docs)
hash_filename_mappings.append(dict(filename=filepath, filehash=file_hash))
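            # One row per chunk, linking the chunk's document_id to the file hash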
df_hash_docid = (
pd.DataFrame()
.assign(docid=[d.metadata["document_id"] for d in docs])
.assign(filehash=file_hash)
)
hash_docid_mappings.append(df_hash_docid)
logger.info(f"Got {len(all_docs)} nodes.")
return all_docs, hash_filename_mappings, hash_docid_mappings
def get_md5_hash(file_path: Path) -> str:
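    """Return the MD5 hex digest of a file, read in HASH_BLOCKSIZE chunks."""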
hasher = hashlib.md5()
with open(file_path, "rb") as file:
buf = file.read(HASH_BLOCKSIZE)
while buf:
hasher.update(buf)
buf = file.read(HASH_BLOCKSIZE)
return hasher.hexdigest()
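

# Minimal usage sketch (assumes a populated Config object is available from the
# application's own configuration loading; every name below other than
# DocumentSplitter and its split() method is illustrative):
#
#     config = load_config("config.yaml")  # hypothetical loader
#     splitter = DocumentSplitter(config)
#     docs, filename_map, docid_map = splitter.split(limit=100)
#     print(len(docs), filename_map.shape, docid_map.shape)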