import re
import urllib.parse
from collections import namedtuple
from enum import Enum
from pathlib import Path
from typing import Generator, List, Optional, Tuple, Union

from loguru import logger

FORMATTING_SEQUENCES = {"*", "**", "***", "_", "__", "~~", "||"}
CODE_BLOCK_SEQUENCES = {"`", "``", "```"}
ALL_SEQUENCES = FORMATTING_SEQUENCES | CODE_BLOCK_SEQUENCES
MAX_FORMATTING_SEQUENCE_LENGTH = max(len(seq) for seq in ALL_SEQUENCES)


class SplitCandidates(Enum):
    SPACE = 1
    NEWLINE = 2
    LAST_CHAR = 3


SPLIT_CANDIDATES_PREFERENCE = [
    SplitCandidates.NEWLINE,
    SplitCandidates.SPACE,
    SplitCandidates.LAST_CHAR,
]

BLOCK_SPLIT_CANDIDATES = [r"\n#\s+", r"\n##\s+", r"\n###\s+"]

CODE_BLOCK_LEVEL = 10

MarkdownChunk = namedtuple("MarkdownChunk", "string level")


class SplitCandidateInfo:
    """State of a potential split point: its position and the formatting
    sequences that are still open at that position."""

    last_seen: Optional[int]
    active_sequences: List[str]
    active_sequences_length: int

    def __init__(self):
        self.last_seen = None
        self.active_sequences = []
        self.active_sequences_length = 0

    def process_sequence(self, seq: str, is_in_code_block: bool) -> bool:
        """Update the stack of active formatting sequences with seq.

        Returns True if we are inside a code block after processing seq,
        False otherwise.
        """
        if is_in_code_block:
            # Inside a code block, only the matching closing backtick
            # sequence is significant; it ends the block.
            if self.active_sequences and seq == self.active_sequences[-1]:
                last_seq = self.active_sequences.pop()
                self.active_sequences_length -= len(last_seq)
                return False
            return True
        elif seq in CODE_BLOCK_SEQUENCES:
            # A backtick sequence outside a code block opens one.
            self.active_sequences.append(seq)
            self.active_sequences_length += len(seq)
            return True
        else:
            # A formatting sequence either closes the matching open sequence
            # (together with everything opened after it) or opens a new one.
            for k in range(len(self.active_sequences) - 1, -1, -1):
                if seq == self.active_sequences[k]:
                    sequences_being_removed = self.active_sequences[k:]
                    self.active_sequences = self.active_sequences[:k]
                    self.active_sequences_length -= sum(
                        len(s) for s in sequences_being_removed
                    )
                    return False
            self.active_sequences.append(seq)
            self.active_sequences_length += len(seq)
            return False

    def copy_from(self, other: "SplitCandidateInfo"):
        self.last_seen = other.last_seen
        self.active_sequences = other.active_sequences.copy()
        self.active_sequences_length = other.active_sequences_length


def physical_split(markdown: str, max_chunk_size: int) -> Generator[str, None, None]:
    """Split markdown into chunks of at most max_chunk_size characters,
    preferring newline splits over space splits over mid-word splits, and
    re-opening at the start of the next chunk any formatting sequences that
    a split leaves dangling."""
    if max_chunk_size <= MAX_FORMATTING_SEQUENCE_LENGTH:
        raise ValueError(
            f"max_chunk_size must be greater than {MAX_FORMATTING_SEQUENCE_LENGTH}"
        )

    split_candidates = {
        SplitCandidates.SPACE: SplitCandidateInfo(),
        SplitCandidates.NEWLINE: SplitCandidateInfo(),
        SplitCandidates.LAST_CHAR: SplitCandidateInfo(),
    }
    is_in_code_block = False

    chunk_start_from, chunk_char_count, chunk_prefix = 0, 0, ""

    def split_chunk():
        for split_variant in SPLIT_CANDIDATES_PREFERENCE:
            split_candidate = split_candidates[split_variant]
            if split_candidate.last_seen is None:
                continue
            chunk_end = split_candidate.last_seen + (
                1 if split_variant == SplitCandidates.LAST_CHAR else 0
            )
            # Close any dangling formatting sequences at the end of this chunk
            # and re-open them as the prefix of the next one.
            chunk = (
                chunk_prefix
                + markdown[chunk_start_from:chunk_end]
                + "".join(reversed(split_candidate.active_sequences))
            )
            next_chunk_prefix = "".join(split_candidate.active_sequences)
            next_chunk_char_count = len(next_chunk_prefix)
            next_chunk_start_from = chunk_end + (
                0 if split_variant == SplitCandidates.LAST_CHAR else 1
            )
            split_candidates[SplitCandidates.NEWLINE] = SplitCandidateInfo()
            split_candidates[SplitCandidates.SPACE] = SplitCandidateInfo()
            return (
                chunk,
                next_chunk_start_from,
                next_chunk_char_count,
                next_chunk_prefix,
            )

    i = 0
    while i < len(markdown):
        # Longest-match lookahead for a formatting or code-block sequence.
        for j in range(MAX_FORMATTING_SEQUENCE_LENGTH, 0, -1):
            seq = markdown[i : i + j]
            if seq in ALL_SEQUENCES:
                last_char_split_candidate_len = (
                    chunk_char_count
                    + split_candidates[
                        SplitCandidates.LAST_CHAR
                    ].active_sequences_length
                    + len(seq)
                )
                if last_char_split_candidate_len >= max_chunk_size:
                    (
                        next_chunk,
                        chunk_start_from,
                        chunk_char_count,
                        chunk_prefix,
                    ) = split_chunk()
                    yield next_chunk
                is_in_code_block = split_candidates[
                    SplitCandidates.LAST_CHAR
                ].process_sequence(seq, is_in_code_block)
                i += len(seq)
                chunk_char_count += len(seq)
                split_candidates[SplitCandidates.LAST_CHAR].last_seen = i - 1
                break

        if i >= len(markdown):
            break

        split_candidates[SplitCandidates.LAST_CHAR].last_seen = i
        chunk_char_count += 1

        if markdown[i] == "\n":
            split_candidates[SplitCandidates.NEWLINE].copy_from(
                split_candidates[SplitCandidates.LAST_CHAR]
            )
        elif markdown[i] == " ":
            split_candidates[SplitCandidates.SPACE].copy_from(
                split_candidates[SplitCandidates.LAST_CHAR]
            )

        last_char_split_candidate_len = (
            chunk_char_count
            + split_candidates[SplitCandidates.LAST_CHAR].active_sequences_length
        )
        if last_char_split_candidate_len == max_chunk_size:
            next_chunk, chunk_start_from, chunk_char_count, chunk_prefix = split_chunk()
            yield next_chunk

        i += 1

    if chunk_start_from < len(markdown):
        yield chunk_prefix + markdown[chunk_start_from:]


def get_logical_blocks_recursively(
    markdown: str, max_chunk_size: int, all_sections: list, split_candidate_index=0
) -> List[MarkdownChunk]:
    """Split markdown by increasingly deep heading levels (#, ##, ###);
    fall back to physical_split when no heading level is left to try."""
    if split_candidate_index >= len(BLOCK_SPLIT_CANDIDATES):
        for chunk in physical_split(markdown, max_chunk_size):
            all_sections.append(
                MarkdownChunk(string=chunk, level=split_candidate_index)
            )
        return all_sections

    chunks = []
    add_index = 0
    for add_index, split_candidate in enumerate(
        BLOCK_SPLIT_CANDIDATES[split_candidate_index:]
    ):
        chunks = re.split(split_candidate, markdown)
        if len(chunks) > 1:
            break

    for i, chunk in enumerate(chunks):
        level = split_candidate_index + add_index
        if i > 0:
            level += 1
        prefix = "\n\n" + "#" * level + " "
        if not chunk.strip():
            continue
        if len(chunk) <= max_chunk_size:
            all_sections.append(MarkdownChunk(string=prefix + chunk, level=level - 1))
        else:
            get_logical_blocks_recursively(
                chunk,
                max_chunk_size,
                all_sections,
                split_candidate_index=split_candidate_index + add_index + 1,
            )
    return all_sections


def markdown_splitter(
    path: Union[str, Path], max_chunk_size: int, **additional_splitter_settings
) -> List[dict]:
    """Split a markdown file into chunks of at most max_chunk_size characters,
    treating fenced code blocks as indivisible units where possible."""
    try:
        with open(path, "r") as f:
            markdown = f.read()
    except OSError:
        return []

    if len(markdown) < max_chunk_size:
        return [{"text": markdown, "metadata": {"heading": ""}}]

    sections = [MarkdownChunk(string="", level=0)]

    markdown, additional_metadata = preprocess_markdown(
        markdown, additional_splitter_settings
    )

    # Split into alternating non-code and code segments
    chunks = markdown.split("```")

    for i, chunk in enumerate(chunks):
        if i % 2 == 0:  # Every even element (0-indexed) is a non-code segment
            logical_blocks = get_logical_blocks_recursively(
                chunk, max_chunk_size=max_chunk_size, all_sections=[]
            )
            sections += logical_blocks
        else:  # Process the code section
            rows = chunk.split("\n")
            lang = rows[0]  # Get the language name
            code = rows[1:]

            # Provide a hint to the LLM
            all_code_rows = (
                [
                    f"\nFollowing is a code section in {lang}, delimited by triple backticks:",
                    f"```{lang}",
                ]
                + code
                + ["```"]
            )
            all_code_str = "\n".join(all_code_rows)

            # Merge the code into the previous logical block if there is enough space
            if len(sections[-1].string) + len(all_code_str) < max_chunk_size:
                sections[-1] = MarkdownChunk(
                    string=sections[-1].string + all_code_str,
                    level=sections[-1].level,
                )
            # If the code block is larger than the max size, physically split it
            elif len(all_code_str) >= max_chunk_size:
                code_chunks = physical_split(
                    all_code_str, max_chunk_size=max_chunk_size
                )
                for cchunk in code_chunks:
                    # Assign a language header to the code chunk, if it doesn't have one
                    if f"```{lang}" not in cchunk:
                        cchunk_rows = cchunk.split("```")
                        cchunk = f"```{lang}\n" + cchunk_rows[1] + "```"
                    sections.append(
                        MarkdownChunk(string=cchunk, level=CODE_BLOCK_LEVEL)
                    )
            # Otherwise, add it as a single chunk
            else:
                sections.append(
                    MarkdownChunk(string=all_code_str, level=CODE_BLOCK_LEVEL)
                )

    all_out = postprocess_sections(
        sections,
        max_chunk_size,
        additional_splitter_settings,
        additional_metadata,
        path,
    )
    return all_out


def preprocess_markdown(markdown: str, additional_settings: dict) -> Tuple[str, dict]:
    """Apply optional preprocessing steps and extract document-level metadata."""
    preprocess_remove_images = additional_settings.get("remove_images", False)
    preprocess_remove_extra_newlines = additional_settings.get(
        "remove_extra_newlines", True
    )
    preprocess_find_metadata = additional_settings.get("find_metadata", dict())

    if preprocess_remove_images:
        markdown = remove_images(markdown)

    if preprocess_remove_extra_newlines:
        markdown = remove_extra_newlines(markdown)

    additional_metadata = {}
    if preprocess_find_metadata:
        if not isinstance(preprocess_find_metadata, dict):
            raise TypeError(
                f"find_metadata settings should be of type dict. Got {type(preprocess_find_metadata)}"
            )
        for label, search_string in preprocess_find_metadata.items():
            logger.info(f"Looking for metadata: {search_string}")
            metadata = find_metadata(markdown, search_string)
            if metadata:
                logger.info(f"\tFound metadata for {label} - {metadata}")
                additional_metadata[label] = metadata

    return markdown, additional_metadata


def postprocess_sections(
    sections: List[MarkdownChunk],
    max_chunk_size: int,
    additional_settings: dict,
    additional_metadata: dict,
    path: Union[str, Path],
) -> List[dict]:
    """Convert chunks to output dicts, attaching heading and document metadata."""
    all_out = []

    skip_first = additional_settings.get("skip_first", False)
    merge_headers = additional_settings.get("merge_sections", False)

    # Remove all empty sections
    sections = [s for s in sections if s.string]

    if sections and skip_first:
        # Remove the first section
        sections = sections[1:]

    if sections and merge_headers:
        # Merge sections
        sections = merge_sections(sections, max_chunk_size=max_chunk_size)

    current_heading = ""
    sections_metadata = {"Document name": Path(path).name}
    for s in sections:
        stripped_string = s.string.strip()
        doc_metadata = {}
        if len(stripped_string) > 0:
            heading = ""
            if stripped_string.startswith("#"):  # heading detected
                heading = stripped_string.split("\n")[0].replace("#", "").strip()
                stripped_heading = heading.replace("#", "").replace(" ", "").strip()
                if not stripped_heading:
                    heading = ""
                if s.level == 0:
                    current_heading = heading
                # Isolate the heading
                doc_metadata["heading"] = urllib.parse.quote(heading)
            else:
                doc_metadata["heading"] = ""
            final_section = add_section_metadata(
                stripped_string,
                section_metadata={
                    **sections_metadata,
                    **{"Subsection of": current_heading},
                    **additional_metadata,
                },
            )
            all_out.append({"text": final_section, "metadata": doc_metadata})

    return all_out


def remove_images(page_md: str) -> str:
    """Remove markdown image tags of the form ![alt](url "title")."""
    return re.sub(r"""!\[[^\]]*\]\((.*?)\s*("(?:.*[^"])")?\s*\)""", "", page_md)


def remove_extra_newlines(page_md: str) -> str:
    """Collapse runs of three or more newlines into two."""
    page_md = re.sub(r"\n{3,}", "\n\n", page_md)
    return page_md


def add_section_metadata(s: str, section_metadata: dict) -> str:
    """Prepend a metadata header and wrap the section in five-star delimiters."""
    metadata_s = ""
    for k, v in section_metadata.items():
        if v:
            metadata_s += f"{k}: {v}\n"
    metadata = (
        "Metadata applicable to the next chunk of text delimited by five stars:\n"
        f">> METADATA START\n{metadata_s}>> METADATA END\n\n"
    )
    return metadata + "*****\n" + s + "\n*****"


def find_metadata(page_md: str, search_string: str) -> str:
    """Return the rest of the first line that follows search_string, if any."""
    pattern = rf"{search_string}(.*)"
    match = re.search(pattern, page_md)
    if match:
        return match.group(1)
    return ""


def merge_sections(
    sections: List[MarkdownChunk], max_chunk_size: int
) -> List[MarkdownChunk]:
    """Greedily merge a section into the previous one while the result fits
    into max_chunk_size and the heading level keeps increasing."""
    current_section = sections[0]
    all_out = []
    prev_level = 0
    for s in sections[1:]:
        if (
            len(current_section.string + s.string) > max_chunk_size
            or s.level <= prev_level
        ):
            # Too big to merge, or back at the same/shallower heading level:
            # flush the current section and start a new one.
            all_out.append(current_section)
            current_section = s
            prev_level = 0
        else:
            current_section = MarkdownChunk(
                string=current_section.string + s.string,
                level=current_section.level,
            )
            # Code blocks keep the previous level so they merge transparently.
            prev_level = s.level if s.level != CODE_BLOCK_LEVEL else prev_level
    all_out.append(current_section)
    return all_out
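

# A minimal usage sketch, assuming a markdown file is available locally; the
# path, chunk size, and settings below are illustrative, not part of the API.
# Keyword arguments flow into additional_splitter_settings (see
# preprocess_markdown and postprocess_sections for the recognized keys).
if __name__ == "__main__":
    chunks = markdown_splitter(
        "README.md",  # hypothetical input file
        max_chunk_size=1024,
        remove_images=True,
        merge_sections=True,
    )
    for c in chunks:
        print(c["metadata"]["heading"], "->", len(c["text"]), "chars")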