from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from vectorstores import MyFAISS
from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader
from configs.model_config import *
import datetime
import os
from textsplitter import ChineseTextSplitter
from typing import List
from utils import torch_gc
from tqdm import tqdm
from pypinyin import lazy_pinyin
from loader import UnstructuredPaddleImageLoader, UnstructuredPaddlePDFLoader
from models.base import BaseAnswer, AnswerResult
from models.loader.args import parser
from models.loader import LoaderCheckPoint
import models.shared as shared
from agent import bing_search
from langchain.docstore.document import Document
from functools import lru_cache
from textsplitter.zh_title_enhance import zh_title_enhance

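# Give HuggingFaceEmbeddings a __hash__ keyed on the model name so that embedding
# instances can be used as lru_cache keys in load_vector_store below.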
def _embeddings_hash(self):
    return hash(self.model_name)


HuggingFaceEmbeddings.__hash__ = _embeddings_hash

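# Cache up to CACHED_VS_NUM loaded vector stores so that repeated queries against the
# same knowledge base do not reload the FAISS index from disk every time.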
@lru_cache(CACHED_VS_NUM)
def load_vector_store(vs_path, embeddings):
    return MyFAISS.load_local(vs_path, embeddings)

def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
    """Return two lists: the full paths of all files under filepath, and the corresponding file names."""
    if ignore_dir_names is None:
        ignore_dir_names = []
    if ignore_file_names is None:
        ignore_file_names = []
    ret_list = []
    if isinstance(filepath, str):
        if not os.path.exists(filepath):
            print("Path does not exist")
            return None, None
        elif os.path.isfile(filepath) and os.path.basename(filepath) not in ignore_file_names:
            return [filepath], [os.path.basename(filepath)]
        elif os.path.isdir(filepath) and os.path.basename(filepath) not in ignore_dir_names:
            for file in os.listdir(filepath):
                fullfilepath = os.path.join(filepath, file)
                if os.path.isfile(fullfilepath) and os.path.basename(fullfilepath) not in ignore_file_names:
                    ret_list.append(fullfilepath)
                if os.path.isdir(fullfilepath) and os.path.basename(fullfilepath) not in ignore_dir_names:
                    ret_list.extend(tree(fullfilepath, ignore_dir_names, ignore_file_names)[0])
    return ret_list, [os.path.basename(p) for p in ret_list]

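# Choose a document loader and text splitter based on the file extension, split the file
# into Documents, optionally apply Chinese title enhancement, and record the result via
# write_check_file.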
def load_file(filepath, sentence_size=SENTENCE_SIZE, using_zh_title_enhance=ZH_TITLE_ENHANCE):
    if filepath.lower().endswith(".md"):
        loader = UnstructuredFileLoader(filepath, mode="elements")
        docs = loader.load()
    elif filepath.lower().endswith(".txt"):
        loader = TextLoader(filepath, autodetect_encoding=True)
        textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = loader.load_and_split(textsplitter)
    elif filepath.lower().endswith(".pdf"):
        loader = UnstructuredPaddlePDFLoader(filepath)
        textsplitter = ChineseTextSplitter(pdf=True, sentence_size=sentence_size)
        docs = loader.load_and_split(textsplitter)
    elif filepath.lower().endswith(".jpg") or filepath.lower().endswith(".png"):
        loader = UnstructuredPaddleImageLoader(filepath, mode="elements")
        textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = loader.load_and_split(text_splitter=textsplitter)
    elif filepath.lower().endswith(".csv"):
        loader = CSVLoader(filepath)
        docs = loader.load()
    else:
        loader = UnstructuredFileLoader(filepath, mode="elements")
        textsplitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
        docs = loader.load_and_split(text_splitter=textsplitter)
    if using_zh_title_enhance:
        docs = zh_title_enhance(docs)
    write_check_file(filepath, docs)
    return docs

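# Append a debug record (file path, chunk count, and each chunk) to tmp_files/load_file.txt
# next to the source file.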
def write_check_file(filepath, docs):
    folder_path = os.path.join(os.path.dirname(filepath), "tmp_files")
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
    fp = os.path.join(folder_path, 'load_file.txt')
    with open(fp, 'a+', encoding='utf-8') as fout:
        fout.write("filepath=%s,len=%s" % (filepath, len(docs)))
        fout.write('\n')
        for i in docs:
            fout.write(str(i))
            fout.write('\n')

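# Fill the prompt template's {context} and {question} placeholders with the retrieved
# documents and the user query.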
def generate_prompt(related_docs: List[Document],
                    query: str,
                    prompt_template: str = PROMPT_TEMPLATE) -> str:
    context = "\n".join([doc.page_content for doc in related_docs])
    prompt = prompt_template.replace("{question}", query).replace("{context}", context)
    return prompt


# Convert Bing search results into Document objects, keeping the link and title as metadata.
def search_result2docs(search_results):
    docs = []
    for result in search_results:
        doc = Document(page_content=result.get("snippet", ""),
                       metadata={"source": result.get("link", ""),
                                 "filename": result.get("title", "")})
        docs.append(doc)
    return docs

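# Local knowledge-base question answering: wraps embedding/LLM initialisation, FAISS vector
# store construction and maintenance, and retrieval-augmented answering (local documents or
# Bing search results).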
class LocalDocQA:
    llm: BaseAnswer = None
    embeddings: object = None
    top_k: int = VECTOR_SEARCH_TOP_K
    chunk_size: int = CHUNK_SIZE
    chunk_conent: bool = True
    score_threshold: int = VECTOR_SEARCH_SCORE_THRESHOLD

    def init_cfg(self,
                 embedding_model: str = EMBEDDING_MODEL,
                 embedding_device=EMBEDDING_DEVICE,
                 llm_model: BaseAnswer = None,
                 top_k=VECTOR_SEARCH_TOP_K,
                 ):
        self.llm = llm_model
        # Note: the embedding model is loaded from a fixed local path here, so the
        # embedding_model argument is currently not used.
        self.embeddings = HuggingFaceEmbeddings(model_name="C:/Users/Administrator/text2vec-large-chinese",
                                                model_kwargs={'device': embedding_device})
        self.top_k = top_k

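    # Build a FAISS vector store from a file, a directory, or a list of files, or add the
    # new documents to an existing store at vs_path. Returns the vector store path and the
    # list of files that were loaded successfully.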
    def init_knowledge_vector_store(self,
                                    filepath: str or List[str],
                                    vs_path: str or os.PathLike = None,
                                    sentence_size=SENTENCE_SIZE):
        loaded_files = []
        failed_files = []
        if isinstance(filepath, str):
            if not os.path.exists(filepath):
                print("Path does not exist")
                return None, loaded_files
            elif os.path.isfile(filepath):
                file = os.path.split(filepath)[-1]
                try:
                    docs = load_file(filepath, sentence_size)
                    logger.info(f"{file} loaded successfully")
                    loaded_files.append(filepath)
                except Exception as e:
                    logger.error(e)
                    logger.info(f"{file} failed to load")
                    return None, loaded_files
            elif os.path.isdir(filepath):
                docs = []
                for fullfilepath, file in tqdm(zip(*tree(filepath, ignore_dir_names=['tmp_files'])),
                                               desc="Loading files"):
                    try:
                        docs += load_file(fullfilepath, sentence_size)
                        loaded_files.append(fullfilepath)
                    except Exception as e:
                        logger.error(e)
                        failed_files.append(file)

                if len(failed_files) > 0:
                    logger.info("The following files failed to load:")
                    for file in failed_files:
                        logger.info(f"{file}\n")

        else:
            docs = []
            for file in filepath:
                try:
                    docs += load_file(file)
                    logger.info(f"{file} loaded successfully")
                    loaded_files.append(file)
                except Exception as e:
                    logger.error(e)
                    logger.info(f"{file} failed to load")
        if len(docs) > 0:
            logger.info("Files loaded, generating the vector store.")
            if vs_path and os.path.isdir(vs_path) and "index.faiss" in os.listdir(vs_path):
                vector_store = load_vector_store(vs_path, self.embeddings)
                vector_store.add_documents(docs)
                torch_gc()
            else:
                if not vs_path:
                    vs_path = os.path.join(KB_ROOT_PATH,
                                           f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""",
                                           "vector_store")
                vector_store = MyFAISS.from_documents(docs, self.embeddings)
                torch_gc()

            vector_store.save_local(vs_path)
            return vs_path, loaded_files
        else:
            logger.info("None of the files were loaded successfully. Please check the required packages or upload different files.")
            return None, loaded_files

    def one_knowledge_add(self, vs_path, one_title, one_conent, one_content_segmentation, sentence_size):
        try:
            if not vs_path or not one_title or not one_conent:
                logger.info("Failed to add knowledge. Please check that the knowledge base name, title and content are correct.")
                return None, [one_title]
            docs = [Document(page_content=one_conent + "\n", metadata={"source": one_title})]
            if not one_content_segmentation:
                text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
                docs = text_splitter.split_documents(docs)
            if os.path.isdir(vs_path) and os.path.isfile(os.path.join(vs_path, "index.faiss")):
                vector_store = load_vector_store(vs_path, self.embeddings)
                vector_store.add_documents(docs)
            else:
                vector_store = MyFAISS.from_documents(docs, self.embeddings)
                torch_gc()
            vector_store.save_local(vs_path)
            return vs_path, [one_title]
        except Exception as e:
            logger.error(e)
            return None, [one_title]

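    # Retrieve the top_k most similar chunks for the query, build a prompt from them
    # (falling back to the bare query if nothing is retrieved), and stream the LLM's
    # answers together with the source documents.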
    def get_knowledge_based_answer(self, query, vs_path, chat_history=[], streaming: bool = STREAMING):
        vector_store = load_vector_store(vs_path, self.embeddings)
        vector_store.chunk_size = self.chunk_size
        vector_store.chunk_conent = self.chunk_conent
        vector_store.score_threshold = self.score_threshold
        related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
        torch_gc()
        if len(related_docs_with_score) > 0:
            prompt = generate_prompt(related_docs_with_score, query)
        else:
            prompt = query

        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                      streaming=streaming):
            resp = answer_result.llm_output["answer"]
            history = answer_result.history
            history[-1][0] = query
            response = {"query": query,
                        "result": resp,
                        "source_documents": related_docs_with_score}
            yield response, history

    def get_knowledge_based_conent_test(self, query, vs_path, chunk_conent,
                                        score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
                                        vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_size=CHUNK_SIZE):
        vector_store = load_vector_store(vs_path, self.embeddings)
        vector_store.chunk_conent = chunk_conent
        vector_store.score_threshold = score_threshold
        vector_store.chunk_size = chunk_size
        related_docs_with_score = vector_store.similarity_search_with_score(query, k=vector_search_top_k)
        if not related_docs_with_score:
            response = {"query": query,
                        "source_documents": []}
            return response, ""
        torch_gc()
        prompt = "\n".join([doc.page_content for doc in related_docs_with_score])
        response = {"query": query,
                    "source_documents": related_docs_with_score}
        return response, prompt

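    # Same streaming flow as get_knowledge_based_answer, but uses Bing search results as
    # the context instead of a local vector store.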
    def get_search_result_based_answer(self, query, chat_history=[], streaming: bool = STREAMING):
        results = bing_search(query)
        result_docs = search_result2docs(results)
        prompt = generate_prompt(result_docs, query)

        for answer_result in self.llm.generatorAnswer(prompt=prompt, history=chat_history,
                                                      streaming=streaming):
            resp = answer_result.llm_output["answer"]
            history = answer_result.history
            history[-1][0] = query
            response = {"query": query,
                        "result": resp,
                        "source_documents": result_docs}
            yield response, history

    def delete_file_from_vector_store(self,
                                      filepath: str or List[str],
                                      vs_path):
        vector_store = load_vector_store(vs_path, self.embeddings)
        status = vector_store.delete_doc(filepath)
        return status

    def update_file_from_vector_store(self,
                                      filepath: str or List[str],
                                      vs_path,
                                      docs: List[Document]):
        vector_store = load_vector_store(vs_path, self.embeddings)
        status = vector_store.update_doc(filepath, docs)
        return status

    def list_file_from_vector_store(self,
                                    vs_path,
                                    fullpath=False):
        vector_store = load_vector_store(vs_path, self.embeddings)
        docs = vector_store.list_docs()
        if fullpath:
            return docs
        else:
            return [os.path.split(doc)[-1] for doc in docs]

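# Basic smoke test: load a local LLM checkpoint, ask a sample question through the
# search-based pipeline, and print the streamed answer together with its sources.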
if __name__ == "__main__":
    args = parser.parse_args(args=['--model-dir', '/media/checkpoint/', '--model', 'chatglm-6b', '--no-remote-model'])

    args_dict = vars(args)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()
    llm_model_ins.set_history_len(LLM_HISTORY_LEN)

    local_doc_qa = LocalDocQA()
    local_doc_qa.init_cfg(llm_model=llm_model_ins)
    query = "What embedding model does this project use, and how much GPU memory does it consume?"
    vs_path = "/media/gpt4-pdf-chatbot-langchain/dev-langchain-ChatGLM/vector_store/test"
    last_print_len = 0

    for resp, history in local_doc_qa.get_search_result_based_answer(query=query,
                                                                     chat_history=[],
                                                                     streaming=True):
        print(resp["result"][last_print_len:], end="", flush=True)
        last_print_len = len(resp["result"])
    source_text = [f"""Source [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http")
                   else os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
                   for inum, doc in
                   enumerate(resp["source_documents"])]
    logger.info("\n\n" + "\n\n".join(source_text))