# ai/retriever.py
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader, JSONLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.docstore.in_memory import InMemoryDocstore
import faiss
import os
import glob
from typing import List, Dict
from embedding import Embedding
class KnowledgeBaseManager:
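    """Manage named FAISS knowledge bases.

    Each knowledge base is embedded with the local ``Embedding`` wrapper and
    persisted under ``base_path`` as ``<name>.faiss`` / ``<name>.pkl`` files.
    """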
def __init__(self, base_path="./knowledge_bases", embedding_dim=512, batch_size=16):
self.base_path = base_path
self.embedding_dim = embedding_dim
self.batch_size = batch_size
self.embeddings = Embedding()
self.knowledge_bases: Dict[str, FAISS] = {}
os.makedirs(self.base_path, exist_ok=True)
faiss_files = glob.glob(os.path.join(base_path, '*.faiss'))
        # Collect the index names (file names without the extension)
file_names_without_extension = [os.path.splitext(os.path.basename(file))[0] for file in faiss_files]
for name in file_names_without_extension:
self.load_knowledge_base(name)
    def create_knowledge_base(self, name: str):
        """Create an empty FAISS knowledge base named ``name`` and persist it."""
        if name in self.knowledge_bases:
            print(f"Knowledge base '{name}' already exists.")
            return
        index = faiss.IndexFlatL2(self.embedding_dim)
        kb = FAISS(self.embeddings, index, InMemoryDocstore(), {})
        self.knowledge_bases[name] = kb
        self.save_knowledge_base(name)
        print(f"Knowledge base '{name}' created.")
    def delete_knowledge_base(self, name: str):
        """Remove the knowledge base ``name`` from memory and delete its files on disk."""
        if name in self.knowledge_bases:
            del self.knowledge_bases[name]
            # save_local writes both an index (.faiss) and a docstore (.pkl); remove both.
            for ext in (".faiss", ".pkl"):
                path = os.path.join(self.base_path, f"{name}{ext}")
                if os.path.exists(path):
                    os.remove(path)
            print(f"Knowledge base '{name}' deleted.")
        else:
            print(f"Knowledge base '{name}' does not exist.")
def load_knowledge_base(self, name: str):
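        """Load a previously saved FAISS index named ``name`` from ``base_path``."""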
kb_path = os.path.join(self.base_path, f"{name}.faiss")
if os.path.exists(kb_path):
self.knowledge_bases[name] = FAISS.load_local(self.base_path, self.embeddings, name, allow_dangerous_deserialization=True)
print(f"Knowledge base '{name}' loaded.")
else:
print(f"Knowledge base '{name}' does not exist.")
def save_knowledge_base(self, name: str):
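        """Persist the in-memory FAISS index named ``name`` to ``base_path``."""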
if name in self.knowledge_bases:
self.knowledge_bases[name].save_local(self.base_path, name)
print(f"Knowledge base '{name}' saved.")
else:
print(f"Knowledge base '{name}' does not exist.")
    # Example of the Document objects produced by load_documents:
    # Document(page_content='渠道版', metadata={'source': './files/input/PS004.pdf', 'page': 0}),
    # Document(page_content='2/20.', metadata={'source': './files/input/PS004.pdf', 'page': 1})
def add_documents_to_kb(self, name: str, file_paths: List[str]):
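        """Load, split, embed, and persist ``file_paths`` into the knowledge base, creating it if needed."""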
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist; creating it.")
            self.create_knowledge_base(name)
        kb = self.knowledge_bases[name]
        documents = self.load_documents(file_paths)
        print(f"Loaded {len(documents)} documents.")
        # print(documents)
pages = self.split_documents(documents)
print(f"Split documents into {len(pages)} pages.")
# print(pages)
doc_ids = []
for i in range(0, len(pages), self.batch_size):
batch = pages[i:i+self.batch_size]
doc_ids.extend(kb.add_documents(batch))
self.save_knowledge_base(name)
return doc_ids
def load_documents(self, file_paths: List[str]):
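        """Load every file in ``file_paths`` with a format-specific loader and return the Documents."""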
documents = []
for file_path in file_paths:
loader = self.get_loader(file_path)
documents.extend(loader.load())
return documents
    def get_loader(self, file_path: str):
        """Return a LangChain document loader appropriate for the file's extension."""
        if file_path.endswith('.txt'):
            return TextLoader(file_path, encoding="utf-8")
        elif file_path.endswith('.json'):
            # JSONLoader requires a jq schema (and the jq package); "." loads the whole
            # document, and text_content=False allows non-string values.
            return JSONLoader(file_path, jq_schema=".", text_content=False)
        elif file_path.endswith('.pdf'):
            return PyPDFLoader(file_path)
        else:
            raise ValueError(f"Unsupported file format: {file_path}")
def split_documents(self, documents):
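        """Split documents into chunks of up to 512 characters.

        The separator list includes full-width/ideographic punctuation so that
        Chinese text is also split at sentence boundaries.
        """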
text_splitter = RecursiveCharacterTextSplitter(separators=[
"\n\n",
"\n",
" ",
".",
",",
"\u200b", # Zero-width space
"\uff0c", # Fullwidth comma
"\u3001", # Ideographic comma
"\uff0e", # Fullwidth full stop
"\u3002", # Ideographic full stop
"",
],
chunk_size=512, chunk_overlap=0)
return text_splitter.split_documents(documents)
def retrieve_documents(self, names: List[str], query: str):
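        """Run an MMR retrieval for ``query`` against each named knowledge base."""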
results = []
for name in names:
if name not in self.knowledge_bases:
print(f"Knowledge base '{name}' does not exist.")
continue
            # score_threshold only applies to the "similarity_score_threshold" search
            # type; with MMR the relevant knobs are k and fetch_k.
            retriever = self.knowledge_bases[name].as_retriever(
                search_type="mmr",
                search_kwargs={"k": 3, "fetch_k": 20}
            )
            docs = retriever.get_relevant_documents(query)
            results.extend([{"name": name, "content": doc.page_content, "meta": doc.metadata} for doc in docs])
return results
def get_bases(self):
data = self.knowledge_bases.keys()
return list(data)
    def get_df_bases(self):
        import pandas as pd
        data = self.knowledge_bases.keys()
        return pd.DataFrame(list(data), columns=['Knowledge Bases'])
knowledgeBase = KnowledgeBaseManager()
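

# A minimal usage sketch, assuming this module is run from a directory containing the
# sample file referenced in the comments above (./files/input/PS004.pdf); the knowledge
# base name and query string are illustrative only.
if __name__ == "__main__":
    kb_name = "demo"
    knowledgeBase.create_knowledge_base(kb_name)
    knowledgeBase.add_documents_to_kb(kb_name, ["./files/input/PS004.pdf"])
    for hit in knowledgeBase.retrieve_documents([kb_name], "渠道版"):
        print(hit["name"], hit["meta"], hit["content"][:80])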