Spaces:
Sleeping
Sleeping
File size: 6,162 Bytes
494b300 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader, JSONLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.docstore.in_memory import InMemoryDocstore
import faiss
import os
import glob
from typing import Any,List,Dict
from embedding import Embedding
class KnowledgeBaseManager:
    """Manage a collection of named FAISS vector stores persisted on disk.

    Each knowledge base is persisted under ``base_path`` as ``<name>.faiss``
    plus the ``<name>.pkl`` docstore sidecar written by ``FAISS.save_local``.
    Every index found on disk is loaded eagerly at construction time.
    """
    def __init__(self, base_path="./knowledge_bases", embedding_dim=512, batch_size=16):
        """Create the manager and load every ``*.faiss`` index found in base_path.

        Args:
            base_path: Directory where indexes are persisted.
            embedding_dim: Dimensionality of the embedding vectors — must match
                the output size of the ``Embedding`` model (assumed 512 here;
                TODO confirm against the Embedding implementation).
            batch_size: Number of document chunks embedded per add_documents call.
        """
        self.base_path = base_path
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        self.embeddings = Embedding()
        self.knowledge_bases: Dict[str, FAISS] = {}
        os.makedirs(self.base_path, exist_ok=True)
        faiss_files = glob.glob(os.path.join(base_path, '*.faiss'))
        # Derive base names without the extension; that name keys the registry.
        file_names_without_extension = [os.path.splitext(os.path.basename(file))[0] for file in faiss_files]
        for name in file_names_without_extension:
            self.load_knowledge_base(name)
    def create_knowledge_base(self, name: str):
        """Create and persist an empty knowledge base; no-op if it already exists."""
        # Check existence first so we do not build an index only to discard it.
        if name in self.knowledge_bases:
            print(f"Knowledge base '{name}' already exists.")
            return
        index = faiss.IndexFlatL2(self.embedding_dim)
        kb = FAISS(self.embeddings, index, InMemoryDocstore(), {})
        self.knowledge_bases[name] = kb
        self.save_knowledge_base(name)
        print(f"Knowledge base '{name}' created.")
    def delete_knowledge_base(self, name: str):
        """Drop a knowledge base from memory and delete its on-disk files."""
        if name in self.knowledge_bases:
            del self.knowledge_bases[name]
            os.remove(os.path.join(self.base_path, f"{name}.faiss"))
            # save_local also writes a <name>.pkl docstore sidecar; remove it
            # too so no orphan file is left behind (it may be absent).
            pkl_path = os.path.join(self.base_path, f"{name}.pkl")
            if os.path.exists(pkl_path):
                os.remove(pkl_path)
            print(f"Knowledge base '{name}' deleted.")
        else:
            print(f"Knowledge base '{name}' does not exist.")
    def load_knowledge_base(self, name: str):
        """Load a persisted knowledge base from disk into the registry."""
        kb_path = os.path.join(self.base_path, f"{name}.faiss")
        if os.path.exists(kb_path):
            # allow_dangerous_deserialization is required because the docstore
            # sidecar is pickled; only load indexes this application wrote.
            self.knowledge_bases[name] = FAISS.load_local(self.base_path, self.embeddings, name, allow_dangerous_deserialization=True)
            print(f"Knowledge base '{name}' loaded.")
        else:
            print(f"Knowledge base '{name}' does not exist.")
    def save_knowledge_base(self, name: str):
        """Persist an in-memory knowledge base as <base_path>/<name>.faiss/.pkl."""
        if name in self.knowledge_bases:
            self.knowledge_bases[name].save_local(self.base_path, name)
            print(f"Knowledge base '{name}' saved.")
        else:
            print(f"Knowledge base '{name}' does not exist.")
    def add_documents_to_kb(self, name: str, file_paths: List[str]):
        """Load, split, embed and persist files into the named knowledge base.

        Creates the base on demand if it does not exist yet.

        Args:
            name: Knowledge base to add to.
            file_paths: Paths to .txt/.json/.pdf files.

        Returns:
            The list of document ids assigned by the vector store.
        """
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            self.create_knowledge_base(name)
        kb = self.knowledge_bases[name]
        documents = self.load_documents(file_paths)
        print(f"Loaded {len(documents)} documents.")
        pages = self.split_documents(documents)
        print(f"Split documents into {len(pages)} pages.")
        doc_ids = []
        # Embed in batches to bound memory use / per-request payload size.
        for i in range(0, len(pages), self.batch_size):
            batch = pages[i:i + self.batch_size]
            doc_ids.extend(kb.add_documents(batch))
        self.save_knowledge_base(name)
        return doc_ids
    def load_documents(self, file_paths: List[str]):
        """Load every file via its extension-matched loader; return all Documents."""
        documents = []
        for file_path in file_paths:
            loader = self.get_loader(file_path)
            documents.extend(loader.load())
        return documents
    def get_loader(self, file_path: str):
        """Return the document loader matching the file's extension.

        Raises:
            ValueError: If the extension is not .txt, .json or .pdf.
        """
        if file_path.endswith('.txt'):
            return TextLoader(file_path)
        elif file_path.endswith('.json'):
            return JSONLoader(file_path)
        elif file_path.endswith('.pdf'):
            return PyPDFLoader(file_path)
        raise ValueError(f"Unsupported file format: {file_path}")
    def split_documents(self, documents):
        """Split Documents into <=512-char chunks, preferring CJK-aware breaks."""
        text_splitter = RecursiveCharacterTextSplitter(separators=[
            "\n\n",
            "\n",
            " ",
            ".",
            ",",
            "\u200b",  # Zero-width space
            "\uff0c",  # Fullwidth comma
            "\u3001",  # Ideographic comma
            "\uff0e",  # Fullwidth full stop
            "\u3002",  # Ideographic full stop
            "",
            ],
            chunk_size=512, chunk_overlap=0)
        return text_splitter.split_documents(documents)
    def retrieve_documents(self, names: List[str], query: str):
        """Run an MMR retrieval over several knowledge bases; skip missing ones.

        Returns a list of dicts with keys "name", "content" and "meta".

        NOTE(review): "score_threshold" has no effect with search_type="mmr"
        (it applies to "similarity_score_threshold"); kept as-is so current
        retrieval behavior is unchanged — confirm intent with the author.
        """
        results = []
        for name in names:
            if name not in self.knowledge_bases:
                print(f"Knowledge base '{name}' does not exist.")
                continue
            retriever = self.knowledge_bases[name].as_retriever(
                search_type="mmr",
                search_kwargs={"score_threshold": 0.5, "k": 3}
            )
            docs = retriever.get_relevant_documents(query)
            results.extend([{"name": name, "content": doc.page_content, "meta": doc.metadata} for doc in docs])
        return results
    def get_bases(self):
        """Return the names of all loaded knowledge bases as a list."""
        data = self.knowledge_bases.keys()
        return list(data)
    def get_df_bases(self):
        """Return the knowledge base names as a one-column DataFrame (for UI display)."""
        # Imported lazily so pandas is only required when the UI needs a table.
        import pandas as pd
        data = self.knowledge_bases.keys()
        return pd.DataFrame(list(data), columns=['列表'])
# Module-level singleton shared by importers of this module. Instantiating it
# here has import-time side effects: it creates ./knowledge_bases if missing
# and eagerly loads every persisted *.faiss index found there.
knowledgeBase = KnowledgeBaseManager()
|