#!/usr/bin/env python3
"""Manage the project database."""

# %% IMPORTS

import argparse
import logging
import re
import sys
import typing as T

import lib

# %% LOGGING

logging.basicConfig(
    level=logging.DEBUG,
    format="[%(asctime)s][%(levelname)s] %(message)s",
)

# %% PARSING

PARSER = argparse.ArgumentParser(description=__doc__)
# Read markdown sources as UTF-8 explicitly instead of the platform's locale
# encoding, so imports decode identically on every machine.
PARSER.add_argument(
    "files", type=argparse.FileType("r", encoding="utf-8"), nargs="+"
)
PARSER.add_argument("--database", type=str, default=lib.DATABASE_PATH)
PARSER.add_argument("--collection", type=str, default=lib.DATABASE_COLLECTION)

# %% FUNCTIONS


def segment_text(text: str, pattern: str) -> T.Iterator[tuple[str, str]]:
    """Segment the text into (title, content) pairs split on a heading pattern.

    ``pattern`` is matched in MULTILINE mode (``^`` anchors at every line
    start) and must contain exactly one capturing group, the title. Any text
    before the first match is discarded.
    """
    splits = re.split(pattern, text, flags=re.MULTILINE)
    # With one capturing group, re.split returns
    # [preamble, title1, content1, title2, content2, ...]; skip the preamble.
    return zip(splits[1::2], splits[2::2])


def import_file(
    file: T.TextIO,
    collection: lib.Collection,
    encoding_function: T.Callable,
    max_output_tokens: int = lib.ENCODING_OUTPUT_LIMIT,
) -> tuple[int, int]:
    """Import a markdown file to a database collection.

    The file is split on level-1 (``# ``) headings, then each H1 section on
    level-2 (``## ``) headings. Every H2 section becomes one document whose id
    and metadata record the filename and both headings. Text that does not
    sit under an H2 heading (before the first H1, or between an H1 and its
    first H2) is not imported.

    Args:
        file: Open text file to read; its ``name`` is used in ids/metadata.
        collection: Target collection; documents are added one at a time
            via ``collection.add``.
        encoding_function: Tokenizer; ``len`` of its result is taken as the
            section's token count.
        max_output_tokens: Exclusive upper bound on tokens per H2 section.

    Returns:
        ``(n_chars, n_tokens)`` summed over the raw content of all imported
        H2 sections.

    Raises:
        ValueError: If a section's content has ``max_output_tokens`` tokens
            or more. Sections processed before it have already been added.
    """
    n_chars = 0
    n_tokens = 0
    text = file.read()
    filename = file.name
    for h1, h1_text in segment_text(text=text, pattern=r"^# (.+)"):
        logging.debug('\t- H1: "%s" (%d)', h1, len(h1_text))
        for h2, content in segment_text(text=h1_text, pattern=r"^## (.+)"):
            content_chars = len(content)
            content_tokens = len(encoding_function(content))
            logging.debug('\t\t- H2: "%s" (%d)', h2, content_chars)
            # Validate explicitly: an ``assert`` is stripped under ``python -O``
            # and would let oversized sections through silently.
            if content_tokens >= max_output_tokens:
                raise ValueError(
                    f"Content is too long ({content_tokens}): #{h1} ##{h2}"
                )
            # NOTE(review): this id is only unique if the same H1/H2 heading
            # pair does not repeat within one file -- confirm for the sources.
            id_ = f"{filename} # {h1} ## {h2}"
            document = f"# {h1}\n\n## {h2}\n\n{content.strip()}"
            metadata = {"filename": filename, "h1": h1, "h2": h2}
            collection.add(ids=id_, documents=document, metadatas=metadata)
            n_tokens += content_tokens
            n_chars += content_chars
    return n_chars, n_tokens


def main(args: list[str] | None = None) -> int:
    """Main function of the script.

    Reset the database client, create the target collection, and import
    every markdown file given on the command line into it.

    Args:
        args: Command-line arguments; ``None`` makes argparse use
            ``sys.argv[1:]``.

    Returns:
        Process exit code (0 on success).
    """
    # parsing
    opts = PARSER.parse_args(args)
    # database
    database_path = opts.database
    logging.info("Database path: %s", database_path)
    client = lib.get_database_client(path=database_path)
    logging.info("- Resetting database client: %s", client.reset())
    # encoding
    encoding_function = lib.get_encoding_function()
    logging.info("Encoding function: %s", encoding_function)
    # embedding
    embedding_function = lib.get_embedding_function()
    logging.info("Embedding function: %s", embedding_function)
    # collection
    database_collection = opts.collection
    logging.info("Database collection: %s", database_collection)
    collection = client.create_collection(
        name=database_collection, embedding_function=embedding_function
    )
    # files
    for i, file in enumerate(opts.files):
        # argparse.FileType opens the files but never closes them; close each
        # one as soon as it has been imported.
        with file:
            logging.info("Importing file %d: %s", i, file.name)
            n_chars, n_tokens = import_file(
                file=file,
                collection=collection,
                encoding_function=encoding_function,
            )
        logging.info(
            "- Docs imported from file %s: %d chars | %d tokens",
            file.name,
            n_chars,
            n_tokens,
        )
    # return
    return 0


# %% ENTRYPOINTS

if __name__ == "__main__":
    sys.exit(main())