import json
import os
from typing import Dict, List, Optional, Type, Union

from langchain.document_loaders import (PyPDFLoader, TextLoader,
                                        UnstructuredFileLoader)
from langchain.embeddings import ModelScopeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS, VectorStore


class Retrieval:
    """Retrieve the top-k most similar documents for a query from a
    vector store built over plain texts or langchain ``Document`` objects."""

    def __init__(self,
                 embedding: Optional[Embeddings] = None,
                 vs_cls: Optional[Type[VectorStore]] = None,
                 top_k: int = 5,
                 vs_params: Optional[Dict] = None):
        self.embedding = embedding or ModelScopeEmbeddings(
            model_id='damo/nlp_gte_sentence-embedding_chinese-base')
        self.top_k = top_k
        self.vs_cls = vs_cls or FAISS
        # Normalize here instead of using a mutable default argument.
        self.vs_params = vs_params or {}
        self.vs = None

    def construct(self, docs):
        """Build the vector store from a non-empty list of texts or Documents."""
        assert len(docs) > 0
        if isinstance(docs[0], str):
            self.vs = self.vs_cls.from_texts(docs, self.embedding,
                                             **self.vs_params)
        elif isinstance(docs[0], Document):
            self.vs = self.vs_cls.from_documents(docs, self.embedding,
                                                 **self.vs_params)
        else:
            raise TypeError(
                'docs must be a list of str or langchain Document objects')

    def retrieve(self, query: str) -> List[str]:
        res = self.vs.similarity_search(query, k=self.top_k)
        # Restore original page order for paginated sources (e.g. PDFs);
        # guard against an empty result before indexing into it.
        if res and 'page' in res[0].metadata:
            res.sort(key=lambda doc: doc.metadata['page'])
        return [r.page_content for r in res]


class ToolRetrieval(Retrieval):
    """Retrieval over tool descriptions stored as JSON strings; results
    are keyed by each tool's ``name`` field."""

    def __init__(self,
                 embedding: Optional[Embeddings] = None,
                 vs_cls: Optional[Type[VectorStore]] = None,
                 top_k: int = 5,
                 vs_params: Optional[Dict] = None):
        super().__init__(embedding, vs_cls, top_k, vs_params)

    def retrieve(self, query: str) -> Dict[str, str]:
        res = self.vs.similarity_search(query, k=self.top_k)

        final_res = {}
        for r in res:
            content = r.page_content
            # Each stored document is a JSON string describing one tool;
            # key the result by the tool's name.
            name = json.loads(content)['name']
            final_res[name] = content

        return final_res


class KnowledgeRetrieval(Retrieval):
    """Retrieval over knowledge documents loaded from local files."""

    def __init__(self,
                 docs,
                 embedding: Optional[Embeddings] = None,
                 vs_cls: Optional[Type[VectorStore]] = None,
                 top_k: int = 5,
                 vs_params: Optional[Dict] = None):
        super().__init__(embedding, vs_cls, top_k, vs_params)
        self.construct(docs)

    @classmethod
    def from_file(cls,
                  file_path: Union[str, list],
                  embedding: Optional[Embeddings] = None,
                  vs_cls: Optional[Type[VectorStore]] = None,
                  top_k: int = 5,
                  vs_params: Optional[Dict] = None):

        textsplitter = CharacterTextSplitter()
        all_files = []
        if isinstance(file_path, str) and os.path.isfile(file_path):
            all_files.append(file_path)
        elif isinstance(file_path, list):
            all_files = file_path
        elif isinstance(file_path, str) and os.path.isdir(file_path):
            # Recursively collect every file under the directory.
            for root, _, files in os.walk(file_path):
                for f in files:
                    all_files.append(os.path.join(root, f))
        else:
            raise ValueError(
                'file_path must be a file, a directory or a list of paths')

        docs = []
        for f in all_files:
            if f.lower().endswith('.txt'):
                loader = TextLoader(f, autodetect_encoding=True)
                docs += loader.load_and_split(textsplitter)
            elif f.lower().endswith('.md'):
                loader = UnstructuredFileLoader(f, mode='elements')
                docs += loader.load()
            elif f.lower().endswith('.pdf'):
                loader = PyPDFLoader(f)
                docs += loader.load_and_split(textsplitter)
            else:
                print(f'Unsupported file type: {f}; support is coming soon.')

        if len(docs) == 0:
            return None
        else:
            return cls(docs, embedding, vs_cls, top_k, vs_params)
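

# --- Usage sketch (illustrative only; not part of the module's API) -----
# A minimal example of how these classes might be wired together. The tool
# descriptions, file path, and queries below are hypothetical placeholders;
# building an index triggers a ModelScope embedding-model download.
if __name__ == '__main__':
    # Tool retrieval: documents are JSON strings carrying a 'name' field.
    tool_docs = [
        json.dumps({'name': 'weather', 'description': 'query the weather'}),
        json.dumps({'name': 'search', 'description': 'run a web search'}),
    ]
    tool_retrieval = ToolRetrieval(top_k=1)
    tool_retrieval.construct(tool_docs)
    print(tool_retrieval.retrieve('what is the weather like today?'))

    # Knowledge retrieval from a local file; from_file returns None when
    # no supported files were found under the given path.
    kr = KnowledgeRetrieval.from_file('docs/faq.txt')  # hypothetical path
    if kr is not None:
        for chunk in kr.retrieve('how do I configure the agent?'):
            print(chunk)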