"""Dialogue question-answering chains that share one conversation memory.

Builds a ZeroShotAgent whose tools are (1) a RetrievalQA chain over an
indexed dialogue transcript and (2) a summary chain with a read-only view
of the shared conversation memory.
"""
# Third-party: LangChain components for agents, memory, chains and vector search.
from langchain.base_language import BaseLanguageModel
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from langchain.memory import ConversationBufferMemory, ReadOnlySharedMemory
from langchain.chains import LLMChain, RetrievalQA
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

# Project-local imports.
from loader import DialogueLoader
from chains.dialogue_answering.prompts import (
    DIALOGUE_PREFIX,
    DIALOGUE_SUFFIX,
    SUMMARY_PROMPT,
)
class DialogueWithSharedMemoryChains:
    """Agent pipeline for answering questions about a loaded dialogue.

    Construction wires up four stages:
      1. Load dialogue documents via ``DialogueLoader`` and index them in a
         Chroma vector store.
      2. Expose a ``RetrievalQA`` chain over that index as an agent tool.
      3. Expose a summary ``LLMChain`` as a second tool; it gets only a
         read-only view of the shared memory so tool runs cannot mutate it.
      4. Assemble a ``ZeroShotAgent`` executor that owns the writable
         ``ConversationBufferMemory``.

    Attributes are declared at class level as typed defaults and overwritten
    per-instance in ``__init__``.
    """

    # LLM driving the zero-shot ReAct agent.
    zero_shot_react_llm: BaseLanguageModel = None
    # LLM used by the retrieval-QA and summary chains.
    ask_llm: BaseLanguageModel = None
    # Embedding model instance (built in _init_cfg).
    embeddings: HuggingFaceEmbeddings = None
    # HuggingFace model name used for embeddings.
    embedding_model: str = None
    # Top-k for vector search (read from params; not otherwise used here).
    vector_search_top_k: int = 6
    # Path handed to DialogueLoader.
    dialogue_path: str = None
    dialogue_loader: DialogueLoader = None
    # 'cuda' or 'cpu', chosen from params['use_cuda'].
    device: str = None

    def __init__(self, zero_shot_react_llm: BaseLanguageModel = None,
                 ask_llm: BaseLanguageModel = None,
                 params: dict = None):
        """Build the full chain stack.

        :param zero_shot_react_llm: LLM for the ReAct agent.
        :param ask_llm: LLM for the QA and summary chains.
        :param params: optional config dict; recognized keys are
            'embedding_model', 'vector_search_top_k', 'dialogue_path'
            and 'use_cuda'.
        """
        self.zero_shot_react_llm = zero_shot_react_llm
        self.ask_llm = ask_llm
        params = params or {}
        self.embedding_model = params.get('embedding_model', 'GanymedeNil/text2vec-large-chinese')
        self.vector_search_top_k = params.get('vector_search_top_k', 6)
        self.dialogue_path = params.get('dialogue_path', '')
        self.device = 'cuda' if params.get('use_cuda', False) else 'cpu'

        self.dialogue_loader = DialogueLoader(self.dialogue_path)
        # Order matters: embeddings must exist before the vector index,
        # and both chains must exist before the agent is assembled.
        self._init_cfg()
        self._init_state_of_history()
        self.memory_chain, self.memory = self._agents_answer()
        self.agent_chain = self._create_agent_chain()

    def _init_cfg(self):
        """Instantiate the HuggingFace embedding model on the chosen device."""
        model_kwargs = {
            'device': self.device
        }
        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model, model_kwargs=model_kwargs)

    def _init_state_of_history(self):
        """Index the dialogue and build a RetrievalQA chain over it."""
        documents = self.dialogue_loader.load()
        # NOTE(review): chunk_size=3 is unusually small for a character
        # splitter — confirm this is intended and not a leftover debug value.
        text_splitter = CharacterTextSplitter(chunk_size=3, chunk_overlap=1)
        texts = text_splitter.split_documents(documents)
        docsearch = Chroma.from_documents(texts, self.embeddings, collection_name="state-of-history")
        self.state_of_history = RetrievalQA.from_chain_type(llm=self.ask_llm, chain_type="stuff",
                                                            retriever=docsearch.as_retriever())

    def _agents_answer(self):
        """Create the shared memory and the summary chain.

        :return: (summary LLMChain, writable ConversationBufferMemory).
        """
        memory = ConversationBufferMemory(memory_key="chat_history")
        readonly_memory = ReadOnlySharedMemory(memory=memory)
        memory_chain = LLMChain(
            llm=self.ask_llm,
            prompt=SUMMARY_PROMPT,
            verbose=True,
            memory=readonly_memory,  # use the read-only memory to prevent the tool from modifying the memory
        )
        return memory_chain, memory

    def _create_agent_chain(self):
        """Assemble the ZeroShotAgent executor with both tools.

        :return: AgentExecutor bound to the shared writable memory.
        """
        dialogue_participants = self.dialogue_loader.dialogue.participants_to_export()
        tools = [
            Tool(
                name="State of Dialogue History System",
                func=self.state_of_history.run,
                description=f"Dialogue with {dialogue_participants} - The answers in this section are very useful "
                            f"when searching for chat content between {dialogue_participants}. Input should be a "
                            f"complete question. "
            ),
            Tool(
                name="Summary",
                func=self.memory_chain.run,
                description="useful for when you summarize a conversation. The input to this tool should be a string, "
                            "representing who will read this summary. "
            )
        ]
        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=DIALOGUE_PREFIX,
            suffix=DIALOGUE_SUFFIX,
            input_variables=["input", "chat_history", "agent_scratchpad"]
        )
        llm_chain = LLMChain(llm=self.zero_shot_react_llm, prompt=prompt)
        agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
        agent_chain = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=self.memory)
        return agent_chain