import os
import sys
import re

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

import tempfile
from dotenv import load_dotenv, find_dotenv
from embedding.call_embedding import get_embedding
from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyMuPDFLoader
from langchain.vectorstores import Chroma

# Basic configuration
DEFAULT_DB_PATH = "./knowledge_db"
DEFAULT_PERSIST_PATH = "./vector_db"


def get_files(dir_path):
    """Recursively collect all file paths under dir_path."""
    file_list = []
    for filepath, dirnames, filenames in os.walk(dir_path):
        for filename in filenames:
            file_list.append(os.path.join(filepath, filename))
    return file_list


def file_loader(file, loaders):
    """Append a document loader for `file` to `loaders`.

    Directories are traversed recursively; unsupported file types are skipped.
    """
    if isinstance(file, tempfile._TemporaryFileWrapper):
        file = file.name
    if not os.path.isfile(file):
        # `file` is a directory: recurse into each of its entries
        for f in os.listdir(file):
            file_loader(os.path.join(file, f), loaders)
        return
    file_type = file.split(".")[-1]
    if file_type == "pdf":
        loaders.append(PyMuPDFLoader(file))
    elif file_type == "md":
        # Skip Markdown files whose names contain these exclusion keywords
        pattern = r"不存在|风控"
        match = re.search(pattern, file)
        if not match:
            loaders.append(UnstructuredMarkdownLoader(file))
    elif file_type == "txt":
        loaders.append(UnstructuredFileLoader(file))
    return


def create_db_info(
    files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH
):
    if embeddings in ("openai", "m3e", "zhipuai"):
        vectordb = create_db(files, persist_directory, embeddings)
    return ""


def create_db(
    files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="zhipuai"
):
    """
    Load the source files, split the documents, embed them, and build a vector database.

    Args:
        files: path(s) of the files to load.
        persist_directory: directory where the vector database is persisted.
        embeddings: the model used to generate the embeddings, given as a model name or an embeddings object.

    Returns:
        vectordb: the created vector database.
    """
    if files is None:
        return "can't load empty file"
    if not isinstance(files, list):
        files = [files]
    loaders = []
    for file in files:
        file_loader(file, loaders)
    docs = []
    for loader in loaders:
        if loader is not None:
            docs.extend(loader.load())
    # Split the documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)
    split_docs = text_splitter.split_documents(docs)
    if isinstance(embeddings, str):
        embeddings = get_embedding(embedding=embeddings)
    # Build the database; persist_directory lets us save it to disk
    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory,
    )
    vectordb.persist()
    return vectordb


def presit_knowledge_db(vectordb):
    """
    Persist the vector database to disk.

    Args:
        vectordb: the vector database to persist.
    """
    vectordb.persist()


def load_knowledge_db(path, embeddings):
    """
    Load a persisted vector database.

    Args:
        path: path of the vector database to load.
        embeddings: the embedding model used by the vector database.

    Returns:
        vectordb: the loaded database.
    """
    vectordb = Chroma(persist_directory=path, embedding_function=embeddings)
    return vectordb


if __name__ == "__main__":
    create_db(embeddings="zhipuai")
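    # A minimal retrieval sketch, not part of the original pipeline: it assumes
    # get_embedding("zhipuai") returns a LangChain embeddings object compatible
    # with Chroma, that create_db above persisted the database to
    # DEFAULT_PERSIST_PATH, and the query string is a hypothetical example.
    embedding_model = get_embedding(embedding="zhipuai")
    vectordb = load_knowledge_db(DEFAULT_PERSIST_PATH, embedding_model)
    # Retrieve the top-3 chunks most similar to the query
    results = vectordb.similarity_search("What is a large language model?", k=3)
    for doc in results:
        print(doc.page_content[:100])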