"""Ingest a directory of documentation files into a vector store and store the relevant artifacts in Weights & Biases"""
import argparse
import json
import logging
import os
import pathlib
from typing import List, Tuple
import langchain
import wandb
from langchain.cache import SQLiteCache
from langchain.docstore.document import Document
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import MarkdownTextSplitter
from langchain.vectorstores import Chroma
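# Cache LLM calls in a local SQLite file so repeated calls with the same
# inputs are served from the cache instead of hitting the API again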
langchain.llm_cache = SQLiteCache(database_path="langchain.db")
logger = logging.getLogger(__name__)
def load_documents(data_dir: str) -> List[Document]:
"""Load documents from a directory of markdown files
Args:
data_dir (str): The directory containing the markdown files
Returns:
List[Document]: A list of documents
"""
md_files = list(map(str, pathlib.Path(data_dir).glob("*.md")))
documents = [
UnstructuredMarkdownLoader(file_path=file_path).load()[0]
for file_path in md_files
]
return documents
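
# Illustrative usage of load_documents (the "docs_sample" directory name is an
# example, not part of this script):
#   docs = load_documents("docs_sample")
#   print(f"Loaded {len(docs)} documents; first source: {docs[0].metadata['source']}")
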
def chunk_documents(
    documents: List[Document], chunk_size: int = 500, chunk_overlap: int = 0
) -> List[Document]:
    """Split documents into chunks
    Args:
        documents (List[Document]): A list of documents to split into chunks
        chunk_size (int, optional): The maximum size of each chunk in characters. Defaults to 500.
        chunk_overlap (int, optional): The number of characters to overlap between chunks. Defaults to 0.
    Returns:
        List[Document]: A list of chunked documents.
    """
markdown_text_splitter = MarkdownTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
split_documents = markdown_text_splitter.split_documents(documents)
return split_documents
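
# Illustrative usage of chunk_documents, continuing the sketch above. Note that
# MarkdownTextSplitter measures chunk_size in characters by default:
#   chunks = chunk_documents(docs, chunk_size=500, chunk_overlap=0)
#   print(f"{len(docs)} documents -> {len(chunks)} chunks")
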
def create_vector_store(
    documents: List[Document],
vector_store_path: str = "./vector_store",
) -> Chroma:
"""Create a ChromaDB vector store from a list of documents
Args:
        documents (List[Document]): A list of documents to add to the vector store
vector_store_path (str, optional): The path to the vector store. Defaults to "./vector_store".
Returns:
Chroma: A ChromaDB vector store containing the documents.
"""
api_key = os.environ.get("OPENAI_API_KEY", None)
embedding_function = OpenAIEmbeddings(openai_api_key=api_key)
vector_store = Chroma.from_documents(
documents=documents,
embedding=embedding_function,
persist_directory=vector_store_path,
)
vector_store.persist()
return vector_store
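
# Illustrative usage of create_vector_store (requires OPENAI_API_KEY in the
# environment; the query string is just an example):
#   store = create_vector_store(chunks, vector_store_path="./vector_store")
#   hits = store.similarity_search("How do I log a wandb artifact?", k=3)
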
def log_dataset(documents: List[Document], run: "wandb.run"):
"""Log a dataset to wandb
Args:
documents (List[Document]): A list of documents to log to a wandb artifact
run (wandb.run): The wandb run to log the artifact to.
"""
document_artifact = wandb.Artifact(name="documentation_dataset", type="dataset")
with document_artifact.new_file("documents.json") as f:
for document in documents:
f.write(document.json() + "\n")
run.log_artifact(document_artifact)
def log_index(vector_store_dir: str, run: "wandb.run"):
"""Log a vector store to wandb
Args:
vector_store_dir (str): The directory containing the vector store to log
run (wandb.run): The wandb run to log the artifact to.
"""
index_artifact = wandb.Artifact(name="vector_store", type="search_index")
index_artifact.add_dir(vector_store_dir)
run.log_artifact(index_artifact)
def log_prompt(prompt: dict, run: "wandb.run"):
"""Log a prompt to wandb
Args:
        prompt (dict): The chat prompt (as a dict) to log
run (wandb.run): The wandb run to log the artifact to.
"""
prompt_artifact = wandb.Artifact(name="chat_prompt", type="prompt")
with prompt_artifact.new_file("prompt.json") as f:
f.write(json.dumps(prompt))
run.log_artifact(prompt_artifact)
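
# Illustrative usage of the three logging helpers above (assumes an active run;
# the prompt dict keys shown are an assumed example, log_prompt accepts any dict):
#   run = wandb.init(project="llmapps")
#   log_dataset(chunks, run)
#   log_index("./vector_store", run)
#   log_prompt({"system_template": "You are a helpful assistant.", "human_template": "{question}"}, run)
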
def ingest_data(
docs_dir: str,
chunk_size: int,
chunk_overlap: int,
vector_store_path: str,
) -> Tuple[List[Document], Chroma]:
"""Ingest a directory of markdown files into a vector store
Args:
docs_dir (str):
chunk_size (int):
chunk_overlap (int):
vector_store_path (str):
"""
# load the documents
documents = load_documents(docs_dir)
# split the documents into chunks
split_documents = chunk_documents(documents, chunk_size, chunk_overlap)
# create document embeddings and store them in a vector store
vector_store = create_vector_store(split_documents, vector_store_path)
return split_documents, vector_store
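
# Illustrative end-to-end usage of ingest_data (paths are examples):
#   chunks, store = ingest_data(
#       docs_dir="docs_sample",
#       chunk_size=500,
#       chunk_overlap=0,
#       vector_store_path="./vector_store",
#   )
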
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--docs_dir",
type=str,
required=True,
help="The directory containing the wandb documentation",
)
parser.add_argument(
"--chunk_size",
type=int,
default=500,
help="The number of tokens to include in each document chunk",
)
parser.add_argument(
"--chunk_overlap",
type=int,
default=0,
help="The number of tokens to overlap between document chunks",
)
parser.add_argument(
"--vector_store",
type=str,
default="./vector_store",
help="The directory to save or load the Chroma db to/from",
)
parser.add_argument(
"--prompt_file",
type=pathlib.Path,
default="./chat_prompt.json",
help="The path to the chat prompt to use",
)
parser.add_argument(
"--wandb_project",
default="llmapps",
type=str,
help="The wandb project to use for storing artifacts",
)
return parser
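
# Example invocation (the script filename "ingest.py" is assumed; adjust paths
# to your setup):
#   python ingest.py --docs_dir ./docs_sample --chunk_size 500 --chunk_overlap 0 \
#       --vector_store ./vector_store --prompt_file ./chat_prompt.json \
#       --wandb_project llmapps
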
def main():
parser = get_parser()
args = parser.parse_args()
run = wandb.init(project=args.wandb_project, config=args)
documents, vector_store = ingest_data(
docs_dir=args.docs_dir,
chunk_size=args.chunk_size,
chunk_overlap=args.chunk_overlap,
vector_store_path=args.vector_store,
)
log_dataset(documents, run)
log_index(args.vector_store, run)
    with args.prompt_file.open("r") as f:
        log_prompt(json.load(f), run)
run.finish()
if __name__ == "__main__":
main()