|
from langchain import LLMChain, PromptTemplate |
|
from langchain.document_loaders import NotionDirectoryLoader |
|
from langchain.text_splitter import MarkdownTextSplitter, SpacyTextSplitter |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.vectorstores import FAISS |
|
from langchain.chains import RetrievalQA |
|
from langchain.chains.question_answering import load_qa_chain |
|
|
|
from langchain.document_loaders import NotionDirectoryLoader |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain.agents import initialize_agent, AgentType, Tool, ZeroShotAgent, AgentExecutor |
|
|
|
from models import llm |
|
|
|
|
|
class CustomEmbedding:
    """Build and query a FAISS vector index over a Notion markdown export.

    Class-level (shared) attributes:
        notionDirectoryLoader: loads markdown pages exported from Notion
            from a hard-coded local path.  # NOTE(review): path is
            machine-specific — consider making it configurable.
        embeddings: HuggingFace sentence embeddings used both to build the
            index and to embed queries at retrieval time.
    """

    notionDirectoryLoader = NotionDirectoryLoader(
        "/Users/peichao.dong/Documents/projects/dpc/ABstract/docs/pages")
    embeddings = HuggingFaceEmbeddings()

    def calculateEmbedding(self):
        """Load the Notion pages, split them into markdown chunks, and
        persist a FAISS index to ``./documents/abstract.faiss``.
        """
        documents = self.notionDirectoryLoader.load()

        # Markdown-aware splitting; chunk_overlap=0 keeps chunks disjoint.
        text_splitter = MarkdownTextSplitter(
            chunk_size=2048, chunk_overlap=0)
        texts = text_splitter.split_documents(documents)

        docsearch = FAISS.from_documents(texts, self.embeddings)
        docsearch.save_local(
            folder_path="./documents/abstract.faiss")

    def getFAQChain(self, llm=None):
        """Return a ConversationalRetrievalChain over the persisted index.

        Args:
            llm: optional language model. Defaults to
                ``models.llm(temperature=0.7)``, created lazily.
                (Previously the default was evaluated at class-definition
                time, which instantiated a model on import and shared one
                instance across every call.)
        """
        if llm is None:
            # Lazy import: the parameter name shadows the module-level
            # ``llm`` factory, and deferring construction avoids building
            # a model unless the default is actually needed.
            from models import llm as _llm_factory
            llm = _llm_factory(temperature=0.7)

        memory = ConversationBufferMemory(
            memory_key="chat_history", return_messages=True)
        docsearch = FAISS.load_local(
            "./documents/abstract.faiss", self.embeddings)

        # Condense a follow-up question plus chat history into a standalone
        # (Chinese) question so the retriever gets full context.
        _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a chinese standalone question.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
        question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)

        # "stuff" chain: concatenates all retrieved docs into one prompt.
        doc_chain = load_qa_chain(llm, chain_type="stuff")
        qa = ConversationalRetrievalChain(
            retriever=docsearch.as_retriever(search_kwargs={"k": 1}),
            question_generator=question_generator,
            combine_docs_chain=doc_chain,
            memory=memory)
        return qa

    def faq(self, input):
        """Answer a single question about the ABstract system.

        Builds a fresh chain (and loads the index) on every call.
        Returns the chain's ``answer`` string.
        """
        qa = self.getFAQChain()
        response = qa({"question": f"{input}"})
        return response["answer"]

    def getFAQAgent(self):
        """Return an AgentExecutor that wraps :meth:`faq` as a zero-shot
        agent tool with conversational memory.
        """
        tools = [Tool(
            name="ABstract system FAQ",
            func=self.faq,
            # Fixed typo ("anwer"): the description steers the agent's
            # tool selection, so it should read cleanly.
            description="Useful for answering questions about ABstract system")]
        memory = ConversationBufferMemory(memory_key="chat_history")

        prefix = """Have a conversation with a human, answering the following questions as best you can. You have access to the following tools:"""
        # Fixed typos in the instruction: "Chines" -> "Chinese" and removed
        # a stray quote character after "Begin!".
        suffix = """The final Answer should be in Chinese! Begin!

{chat_history}
Question: {input}
{agent_scratchpad}"""

        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            input_variables=["input", "chat_history", "agent_scratchpad"],
        )

        llm_chain = LLMChain(llm=llm(), prompt=prompt)
        agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
        faq_agent = AgentExecutor.from_agent_and_tools(
            agent=agent, tools=tools, verbose=True, memory=memory)
        return faq_agent
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build (or rebuild) the FAISS index when the module is run directly.
    CustomEmbedding().calculateEmbedding()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|