import json
import os
from typing import Dict, Iterable, List, Union

from langchain.document_loaders import (PyPDFLoader, TextLoader,
                                         UnstructuredFileLoader)
from langchain.embeddings import ModelScopeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS, VectorStore


class Retrieval:
    """Embed documents into a vector store and return the top-k entries
    most similar to a query."""

    def __init__(self,
                 embedding: Embeddings = None,
                 vs_cls: VectorStore = None,
                 top_k: int = 5,
                 vs_params: Dict = {}):
        self.embedding = embedding or ModelScopeEmbeddings(
            model_id='damo/nlp_gte_sentence-embedding_chinese-base')
        self.top_k = top_k
        self.vs_cls = vs_cls or FAISS
        self.vs_params = vs_params
        self.vs = None

    def construct(self, docs):
        # Build the vector store from either raw strings or Document objects.
        assert len(docs) > 0
        if isinstance(docs[0], str):
            self.vs = self.vs_cls.from_texts(docs, self.embedding,
                                             **self.vs_params)
        elif isinstance(docs[0], Document):
            self.vs = self.vs_cls.from_documents(docs, self.embedding,
                                                 **self.vs_params)

    def retrieve(self, query: str) -> List[str]:
        res = self.vs.similarity_search(query, k=self.top_k)
        if 'page' in res[0].metadata:
            # Restore original page order when the chunks carry page metadata.
            res.sort(key=lambda doc: doc.metadata['page'])
        return [r.page_content for r in res]


class ToolRetrieval(Retrieval):
    """Retrieve tool descriptions; each stored entry is a JSON string that
    contains at least a 'name' field."""

    def __init__(self,
                 embedding: Embeddings = None,
                 vs_cls: VectorStore = None,
                 top_k: int = 5,
                 vs_params: Dict = {}):
        super().__init__(embedding, vs_cls, top_k, vs_params)

    def retrieve(self, query: str) -> Dict[str, str]:
        res = self.vs.similarity_search(query, k=self.top_k)

        # Map tool name -> raw JSON description for the retrieved tools.
        final_res = {}
        for r in res:
            content = r.page_content
            name = json.loads(content)['name']
            final_res[name] = content
        return final_res


class KnowledgeRetrieval(Retrieval):
    """Retrieve knowledge chunks loaded from local files (.txt, .md, .pdf)."""

    def __init__(self,
                 docs,
                 embedding: Embeddings = None,
                 vs_cls: VectorStore = None,
                 top_k: int = 5,
                 vs_params: Dict = {}):
        super().__init__(embedding, vs_cls, top_k, vs_params)
        self.construct(docs)

    @classmethod
    def from_file(cls,
                  file_path: Union[str, list],
                  embedding: Embeddings = None,
                  vs_cls: VectorStore = None,
                  top_k: int = 5,
                  vs_params: Dict = {}):
        textsplitter = CharacterTextSplitter()

        # Collect the files to index: a single file, a list of files,
        # or every file under a directory.
        all_files = []
        if isinstance(file_path, str) and os.path.isfile(file_path):
            all_files.append(file_path)
        elif isinstance(file_path, list):
            all_files = file_path
        elif os.path.isdir(file_path):
            for root, dirs, files in os.walk(file_path):
                for f in files:
                    all_files.append(os.path.join(root, f))
        else:
            raise ValueError(
                'file_path must be a file path, a directory path, '
                'or a list of file paths')

        # Load and split each file with the loader matching its extension.
        docs = []
        for f in all_files:
            if f.lower().endswith('.txt'):
                loader = TextLoader(f, autodetect_encoding=True)
                docs += loader.load_and_split(textsplitter)
            elif f.lower().endswith('.md'):
                loader = UnstructuredFileLoader(f, mode='elements')
                docs += loader.load()
            elif f.lower().endswith('.pdf'):
                loader = PyPDFLoader(f)
                docs += loader.load_and_split(textsplitter)
            else:
                print(f'unsupported file type: {f}, support will be added soon')

        if len(docs) == 0:
            return None
        else:
            return cls(docs, embedding, vs_cls, top_k, vs_params)
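

# Example usage (a minimal sketch; the file path and query below are
# illustrative assumptions, not taken from this file):
#
#   retrieval = KnowledgeRetrieval.from_file('./docs/manual.pdf', top_k=3)
#   if retrieval is not None:
#       chunks = retrieval.retrieve('how to configure the agent')
#       print(chunks)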