safetybot/.ipynb_checkpoints/ingest-checkpoint.py
"""Ingest a directory of documentation files into a vector store and store the relevant artifacts in Weights & Biases"""
import argparse
import json
import logging
import os
import pathlib
from typing import List, Tuple
import langchain
import wandb
from langchain.cache import SQLiteCache
from langchain.docstore.document import Document
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import MarkdownTextSplitter
from langchain.vectorstores import Chroma
# Cache LLM calls in a local SQLite database so repeated identical requests
# are served from disk instead of hitting the API again.
langchain.llm_cache = SQLiteCache(database_path="langchain.db")

logger = logging.getLogger(__name__)

def load_documents(data_dir: str) -> List[Document]:
    """Load documents from a directory of markdown files.

    Args:
        data_dir (str): The directory containing the markdown files

    Returns:
        List[Document]: A list of documents
    """
md_files = list(map(str, pathlib.Path(data_dir).glob("*.md")))
documents = [
UnstructuredMarkdownLoader(file_path=file_path).load()[0]
for file_path in md_files
]
return documents

def chunk_documents(
    documents: List[Document], chunk_size: int = 500, chunk_overlap: int = 0
) -> List[Document]:
    """Split documents into chunks.

    Args:
        documents (List[Document]): A list of documents to split into chunks
        chunk_size (int, optional): The target size of each chunk, in characters. Defaults to 500.
        chunk_overlap (int, optional): The number of characters of overlap between consecutive chunks. Defaults to 0.

    Returns:
        List[Document]: A list of chunked documents.
    """
markdown_text_splitter = MarkdownTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
split_documents = markdown_text_splitter.split_documents(documents)
return split_documents
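
# Illustrative behavior: with chunk_size=500 and chunk_overlap=0, a ~1,200-character
# markdown file yields roughly three chunks, split preferentially at markdown
# boundaries (headings, paragraphs) rather than mid-sentence.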

def create_vector_store(
    documents: List[Document],
    vector_store_path: str = "./vector_store",
) -> Chroma:
    """Create a ChromaDB vector store from a list of documents.

    Args:
        documents (List[Document]): A list of documents to add to the vector store
        vector_store_path (str, optional): The path to the vector store. Defaults to "./vector_store".

    Returns:
        Chroma: A ChromaDB vector store containing the documents.
    """
    # OpenAIEmbeddings needs an OpenAI API key; fail fast with a clear error
    # rather than on the first embedding request.
    api_key = os.environ.get("OPENAI_API_KEY", None)
    if api_key is None:
        raise ValueError("The OPENAI_API_KEY environment variable must be set")
    embedding_function = OpenAIEmbeddings(openai_api_key=api_key)
vector_store = Chroma.from_documents(
documents=documents,
embedding=embedding_function,
persist_directory=vector_store_path,
)
vector_store.persist()
return vector_store

def log_dataset(documents: List[Document], run: "wandb.run"):
    """Log a dataset to wandb.

    Args:
        documents (List[Document]): A list of documents to log to a wandb artifact
        run (wandb.run): The wandb run to log the artifact to.
    """
document_artifact = wandb.Artifact(name="documentation_dataset", type="dataset")
with document_artifact.new_file("documents.json") as f:
for document in documents:
f.write(document.json() + "\n")
run.log_artifact(document_artifact)

def log_index(vector_store_dir: str, run: "wandb.run"):
    """Log a vector store to wandb.

    Args:
        vector_store_dir (str): The directory containing the vector store to log
        run (wandb.run): The wandb run to log the artifact to.
    """
index_artifact = wandb.Artifact(name="vector_store", type="search_index")
index_artifact.add_dir(vector_store_dir)
run.log_artifact(index_artifact)

def log_prompt(prompt: dict, run: "wandb.run"):
    """Log a prompt to wandb.

    Args:
        prompt (dict): The prompt to log
        run (wandb.run): The wandb run to log the artifact to.
    """
prompt_artifact = wandb.Artifact(name="chat_prompt", type="prompt")
with prompt_artifact.new_file("prompt.json") as f:
f.write(json.dumps(prompt))
run.log_artifact(prompt_artifact)

def ingest_data(
    docs_dir: str,
    chunk_size: int,
    chunk_overlap: int,
    vector_store_path: str,
) -> Tuple[List[Document], Chroma]:
    """Ingest a directory of markdown files into a vector store.

    Args:
        docs_dir (str): The directory containing the markdown files
        chunk_size (int): The target size of each document chunk, in characters
        chunk_overlap (int): The number of characters of overlap between consecutive chunks
        vector_store_path (str): The directory to persist the Chroma vector store to

    Returns:
        Tuple[List[Document], Chroma]: The chunked documents and the vector store built from them.
    """
# load the documents
documents = load_documents(docs_dir)
# split the documents into chunks
split_documents = chunk_documents(documents, chunk_size, chunk_overlap)
# create document embeddings and store them in a vector store
vector_store = create_vector_store(split_documents, vector_store_path)
return split_documents, vector_store

def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--docs_dir",
        type=str,
        required=True,
        help="The directory containing the markdown documentation to ingest",
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=500,
        help="The number of characters to include in each document chunk",
    )
    parser.add_argument(
        "--chunk_overlap",
        type=int,
        default=0,
        help="The number of characters to overlap between document chunks",
    )
parser.add_argument(
"--vector_store",
type=str,
default="./vector_store",
help="The directory to save or load the Chroma db to/from",
)
parser.add_argument(
"--prompt_file",
type=pathlib.Path,
default="./chat_prompt.json",
help="The path to the chat prompt to use",
)
parser.add_argument(
"--wandb_project",
default="llmapps",
type=str,
help="The wandb project to use for storing artifacts",
)
return parser

def main():
    parser = get_parser()
    args = parser.parse_args()
    run = wandb.init(project=args.wandb_project, config=args)
    documents, vector_store = ingest_data(
        docs_dir=args.docs_dir,
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        vector_store_path=args.vector_store,
    )
    log_dataset(documents, run)
    log_index(args.vector_store, run)
    # Close the prompt file promptly instead of leaving the handle open.
    with args.prompt_file.open("r") as f:
        log_prompt(json.load(f), run)
    run.finish()

if __name__ == "__main__":
main()
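
# A downstream consumer could retrieve the logged artifacts with the standard
# wandb Artifacts API, e.g. (sketch; entity/project prefixes omitted):
#
#   run = wandb.init(project="llmapps")
#   index_artifact = run.use_artifact("vector_store:latest", type="search_index")
#   vector_store_dir = index_artifact.download()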