import os
import sys
import re
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
import tempfile
from dotenv import load_dotenv, find_dotenv
from embedding.call_embedding import get_embedding
from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import PyMuPDFLoader
from langchain.vectorstores import Chroma
# Basic configuration
DEFAULT_DB_PATH = "./knowledge_db"
DEFAULT_PERSIST_PATH = "./vector_db"
def get_files(dir_path):
    """Recursively collect the paths of all files under dir_path."""
    file_list = []
    for filepath, dirnames, filenames in os.walk(dir_path):
        for filename in filenames:
            file_list.append(os.path.join(filepath, filename))
    return file_list
def file_loader(file, loaders):
    """Create a loader for a single file (or recurse into a directory) and append it to loaders."""
    if isinstance(file, tempfile._TemporaryFileWrapper):
        file = file.name
    if not os.path.isfile(file):
        # Not a regular file: treat it as a directory and recurse into its entries.
        for f in os.listdir(file):
            file_loader(os.path.join(file, f), loaders)
        return
    file_type = file.split(".")[-1]
    if file_type == "pdf":
        loaders.append(PyMuPDFLoader(file))
    elif file_type == "md":
        # Skip markdown files whose names contain "不存在" or "风控".
        pattern = r"不存在|风控"
        match = re.search(pattern, file)
        if not match:
            loaders.append(UnstructuredMarkdownLoader(file))
    elif file_type == "txt":
        loaders.append(UnstructuredFileLoader(file))
    return
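# Illustrative usage sketch (the driving loop below is an assumption about how this
# module is typically used, not code that runs at import time):
#
#     loaders = []
#     for path in get_files(DEFAULT_DB_PATH):
#         file_loader(path, loaders)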
def create_db_info(
    files=DEFAULT_DB_PATH, embeddings="openai", persist_directory=DEFAULT_PERSIST_PATH
):
    """Build the vector database when a supported embedding backend is requested."""
    if embeddings in ("openai", "m3e", "zhipuai"):
        create_db(files, persist_directory, embeddings)
    return ""
def create_db(
    files=DEFAULT_DB_PATH, persist_directory=DEFAULT_PERSIST_PATH, embeddings="zhipuai"
):
    """
    Load the supported files (pdf / md / txt), split them into chunks, embed the
    chunks, and build a vector database.
    Args:
        files: path(s) of the files or directories to load.
        persist_directory: directory where the vector database is persisted.
        embeddings: the model (a name or an embedding object) used to produce the embeddings.
    Returns:
        vectordb: the created vector database.
    """
    if files is None:
        return "can't load empty file"
    if not isinstance(files, list):
        files = [files]
    loaders = []
    for file in files:
        file_loader(file, loaders)
    docs = []
    for loader in loaders:
        if loader is not None:
            docs.extend(loader.load())
    # Split the documents into overlapping chunks.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)
    split_docs = text_splitter.split_documents(docs)
    if isinstance(embeddings, str):
        embeddings = get_embedding(embedding=embeddings)
    # Build the vector store and persist it to disk under persist_directory.
    vectordb = Chroma.from_documents(
        documents=split_docs,
        embedding=embeddings,
        persist_directory=persist_directory,
    )
    vectordb.persist()
    return vectordb
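# Illustrative usage sketch (the query string is a placeholder, not part of this module);
# similarity_search is LangChain's Chroma retrieval API:
#
#     vectordb = create_db(files=DEFAULT_DB_PATH, embeddings="openai")
#     results = vectordb.similarity_search("your question here", k=3)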
def presit_knowledge_db(vectordb):
    """
    Persist the vector database to disk.
    Args:
        vectordb: the vector database to persist.
    """
    vectordb.persist()
def load_knowledge_db(path, embeddings):
    """
    Load a persisted vector database.
    Args:
        path: directory of the persisted vector database.
        embeddings: the embedding model used by the vector database.
    Returns:
        vectordb: the loaded vector database.
    """
    vectordb = Chroma(persist_directory=path, embedding_function=embeddings)
    return vectordb
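# Illustrative usage sketch (assumes the database was previously built by create_db with
# the default persist_directory and the same embedding backend; the query is a placeholder):
#
#     embedding = get_embedding(embedding="zhipuai")
#     vectordb = load_knowledge_db(DEFAULT_PERSIST_PATH, embedding)
#     docs = vectordb.similarity_search("query text", k=3)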
if __name__ == "__main__":
    create_db(embeddings="zhipuai")