# File size: 6,162 Bytes
# 494b300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import TextLoader, JSONLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.docstore.in_memory import InMemoryDocstore
import faiss
import os
import glob
from typing import Any,List,Dict
from embedding import Embedding


class KnowledgeBaseManager:
    """Manage a collection of named FAISS vector stores persisted on disk.

    Each knowledge base is stored under ``base_path`` using LangChain's
    ``save_local`` layout (``<name>.faiss`` plus ``<name>.pkl``).  Every
    knowledge base found on disk is loaded into memory at construction
    time, so lookups in ``self.knowledge_bases`` reflect the persisted set.
    """

    def __init__(self, base_path="./knowledge_bases", embedding_dim=512, batch_size=16):
        """Initialize the manager and load all persisted knowledge bases.

        Args:
            base_path: Directory holding the ``*.faiss`` / ``*.pkl`` files.
            embedding_dim: Dimension of vectors produced by ``Embedding()``.
                Must match the embedding model or FAISS adds will fail.
            batch_size: Number of document chunks embedded per
                ``add_documents`` call (limits per-request embedding load).
        """
        self.base_path = base_path
        self.embedding_dim = embedding_dim
        self.batch_size = batch_size
        self.embeddings = Embedding()
        self.knowledge_bases: Dict[str, FAISS] = {}
        os.makedirs(self.base_path, exist_ok=True)

        # Auto-load every previously persisted knowledge base; the KB name
        # is the *.faiss file name without its extension.
        for faiss_file in glob.glob(os.path.join(base_path, '*.faiss')):
            name = os.path.splitext(os.path.basename(faiss_file))[0]
            self.load_knowledge_base(name)

    def create_knowledge_base(self, name: str):
        """Create, register, and persist an empty knowledge base *name*.

        No-op (with a message) if a knowledge base of that name already
        exists in memory.
        """
        # Check for duplicates BEFORE allocating the index — the previous
        # version built the FAISS index first and discarded it on duplicates.
        if name in self.knowledge_bases:
            print(f"Knowledge base '{name}' already exists.")
            return

        index = faiss.IndexFlatL2(self.embedding_dim)
        kb = FAISS(self.embeddings, index, InMemoryDocstore(), {})
        self.knowledge_bases[name] = kb
        self.save_knowledge_base(name)
        print(f"Knowledge base '{name}' created.")

    def delete_knowledge_base(self, name: str):
        """Remove *name* from memory and delete its on-disk files."""
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            return

        del self.knowledge_bases[name]
        # save_local writes both <name>.faiss and <name>.pkl; remove both
        # so no orphan pickle is left behind (the old code deleted only
        # the .faiss file).
        for ext in (".faiss", ".pkl"):
            path = os.path.join(self.base_path, f"{name}{ext}")
            if os.path.exists(path):
                os.remove(path)
        print(f"Knowledge base '{name}' deleted.")

    def load_knowledge_base(self, name: str):
        """Load the persisted knowledge base *name* into memory.

        NOTE(review): ``allow_dangerous_deserialization=True`` unpickles the
        docstore — only safe for files this application wrote itself.
        """
        kb_path = os.path.join(self.base_path, f"{name}.faiss")
        if os.path.exists(kb_path):
            self.knowledge_bases[name] = FAISS.load_local(
                self.base_path, self.embeddings, name,
                allow_dangerous_deserialization=True,
            )
            print(f"Knowledge base '{name}' loaded.")
        else:
            print(f"Knowledge base '{name}' does not exist.")

    def save_knowledge_base(self, name: str):
        """Persist the in-memory knowledge base *name* to ``base_path``."""
        if name in self.knowledge_bases:
            self.knowledge_bases[name].save_local(self.base_path, name)
            print(f"Knowledge base '{name}' saved.")
        else:
            print(f"Knowledge base '{name}' does not exist.")

    def add_documents_to_kb(self, name: str, file_paths: List[str]):
        """Load, split, embed, and persist documents into knowledge base *name*.

        The knowledge base is created on the fly if it does not exist.
        Documents (e.g. ``Document(page_content=..., metadata={'source': ...,
        'page': ...})``) are split into chunks and added in batches of
        ``self.batch_size``.

        Returns:
            The list of document ids assigned by FAISS for the added chunks.
        """
        if name not in self.knowledge_bases:
            print(f"Knowledge base '{name}' does not exist.")
            self.create_knowledge_base(name)

        kb = self.knowledge_bases[name]
        documents = self.load_documents(file_paths)
        print(f"Loaded {len(documents)} documents.")
        pages = self.split_documents(documents)
        print(f"Split documents into {len(pages)} pages.")

        # Batch the adds so each embedding request stays bounded.
        doc_ids = []
        for start in range(0, len(pages), self.batch_size):
            batch = pages[start:start + self.batch_size]
            doc_ids.extend(kb.add_documents(batch))

        self.save_knowledge_base(name)
        return doc_ids

    def load_documents(self, file_paths: List[str]):
        """Load every file in *file_paths* into LangChain ``Document`` objects."""
        documents = []
        for file_path in file_paths:
            loader = self.get_loader(file_path)
            documents.extend(loader.load())
        return documents

    def get_loader(self, file_path: str):
        """Return the loader matching the file extension.

        Raises:
            ValueError: If the extension is not .txt, .json, or .pdf.
        """
        if file_path.endswith('.txt'):
            return TextLoader(file_path)
        elif file_path.endswith('.json'):
            return JSONLoader(file_path)
        elif file_path.endswith('.pdf'):
            return PyPDFLoader(file_path)
        else:
            raise ValueError("Unsupported file format")

    def split_documents(self, documents):
        """Split *documents* into <=512-char chunks, preferring CJK-aware breaks."""
        # Separators ordered strongest-to-weakest; includes fullwidth /
        # ideographic punctuation so Chinese text splits at sentence ends.
        text_splitter = RecursiveCharacterTextSplitter(separators=[
                                                    "\n\n",
                                                    "\n",
                                                    " ",
                                                    ".",
                                                    ",",
                                                    "\u200b",  # Zero-width space
                                                    "\uff0c",  # Fullwidth comma
                                                    "\u3001",  # Ideographic comma
                                                    "\uff0e",  # Fullwidth full stop
                                                    "\u3002",  # Ideographic full stop
                                                    "",
                                                ],
                                                chunk_size=512, chunk_overlap=0)
        return text_splitter.split_documents(documents)

    def retrieve_documents(self, names: List[str], query: str):
        """Retrieve up to 3 MMR-ranked chunks per knowledge base for *query*.

        Unknown names are skipped with a message.  Returns a list of dicts
        with keys ``name``, ``content``, and ``meta``.
        """
        results = []
        for name in names:
            if name not in self.knowledge_bases:
                print(f"Knowledge base '{name}' does not exist.")
                continue

            # NOTE(review): "score_threshold" is only honored by the
            # "similarity_score_threshold" search type; MMR search ignores
            # it.  Kept for compatibility — confirm intended search type.
            retriever = self.knowledge_bases[name].as_retriever(
                search_type="mmr",
                search_kwargs={"score_threshold": 0.5, "k": 3}
            )
            docs = retriever.get_relevant_documents(query)
            results.extend(
                {"name": name, "content": doc.page_content, "meta": doc.metadata}
                for doc in docs
            )

        return results

    def get_bases(self):
        """Return the names of all loaded knowledge bases."""
        return list(self.knowledge_bases.keys())

    def get_df_bases(self):
        """Return the knowledge-base names as a single-column DataFrame."""
        import pandas as pd
        return pd.DataFrame(self.get_bases(), columns=['列表'])

# Module-level singleton: importing this module has filesystem side effects
# (creates ./knowledge_bases and loads every persisted *.faiss index).
knowledgeBase = KnowledgeBaseManager()