# chat_with_langchain / qa_chain / Chat_QA_chain_self.py
# Uploaded by cosco ("Upload 304 files", commit 184a47b, verified)
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.vectorstores import Chroma
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import ChatOpenAI
import sys
sys.path.append('../llm-universe/project')
from qa_chain.model_to_llm import model_to_llm
from qa_chain.get_vectordb import get_vectordb
import re
class Chat_QA_chain_self:
    """
    Question-answering chain over a vector store that keeps a conversation history.

    Constructor arguments:
    - model: name of the model to call (passed to ``model_to_llm``)
    - temperature: default temperature; controls the randomness of generation
    - top_k: default number of most-similar documents returned by retrieval
    - chat_history: list of (question, answer) tuples. When omitted, a fresh
      empty list is created per instance. A list passed in is used and
      mutated in place.
    - file_path: path of the files the vector database is built from
    - persist_path: persistence path of the vector database
    - appid: Spark (iFlytek) app id
    - api_key: key required by Spark, Baidu Wenxin, OpenAI and Zhipu
    - Spark_api_secret: Spark API secret
    - Wenxin_secret_key: Wenxin secret key
    - embedding: name of the embedding model to use
    - embedding_key: secret key of the embedding model (Zhipu or OpenAI)

    The number of retained turns is not a constructor argument; call
    ``change_history_length`` to obtain the most recent turns.
    """

    def __init__(self, model: str, temperature: float = 0.0, top_k: int = 4,
                 chat_history: list = None, file_path: str = None,
                 persist_path: str = None, appid: str = None, api_key: str = None,
                 Spark_api_secret: str = None, Wenxin_secret_key: str = None,
                 embedding: str = "openai", embedding_key: str = None):
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        # A `[]` default would be a single list shared by every instance created
        # without an explicit history, leaking conversations between them.
        self.chat_history = [] if chat_history is None else chat_history
        self.file_path = file_path
        self.persist_path = persist_path
        self.appid = appid
        self.api_key = api_key
        self.Spark_api_secret = Spark_api_secret
        self.Wenxin_secret_key = Wenxin_secret_key
        self.embedding = embedding
        self.embedding_key = embedding_key
        # Vector store object from the project helper; it must provide
        # `as_retriever` (used in `answer`).
        self.vectordb = get_vectordb(self.file_path, self.persist_path,
                                     self.embedding, self.embedding_key)

    def clear_history(self):
        """Empty the conversation history in place. Returns None."""
        self.chat_history.clear()

    def change_history_length(self, history_len: int = 1):
        """
        Return the most recent ``history_len`` turns of the conversation.

        The stored history is not modified; a new slice is returned.

        Args:
        - history_len: number of most recent turns to return. Values <= 0
          give an empty list. Values larger than the history length give the
          whole history.

        Returns: list with the last ``history_len`` (question, answer) tuples.
        """
        # Clamp at 0: a negative start index would silently count from the end
        # and drop turns when history_len exceeds the history length.
        start = max(len(self.chat_history) - history_len, 0)
        return self.chat_history[start:]

    def answer(self, question: str = None, temperature: float = None, top_k: int = None):
        """
        Core method: run the conversational retrieval chain on ``question``.

        Args:
        - question: the user's question. An empty or missing question is not
          sent to the model, and the current history is returned unchanged.
        - temperature: sampling temperature; defaults to ``self.temperature``.
        - top_k: number of documents to retrieve; defaults to ``self.top_k``.

        Returns: the updated chat history (list of (question, answer) tuples).
        The answer to this question is its last element.
        """
        if not question:
            # Same return shape as the normal path below.
            return self.chat_history
        if temperature is None:
            temperature = self.temperature
        if top_k is None:
            top_k = self.top_k

        llm = model_to_llm(self.model, temperature, self.appid, self.api_key,
                           self.Spark_api_secret, self.Wenxin_secret_key)
        retriever = self.vectordb.as_retriever(search_type="similarity",
                                               search_kwargs={'k': top_k})
        qa = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
        )

        # The result dict is expected to hold 'question', 'chat_history' and 'answer'.
        result = qa({"question": question, "chat_history": self.chat_history})
        answer = result['answer']
        # NOTE(review): r"\\n" matches the literal two-character text backslash+"n",
        # not real newline characters. Confirm this is the intended conversion.
        answer = re.sub(r"\\n", '<br/>', answer)
        self.chat_history.append((question, answer))
        return self.chat_history