File size: 4,106 Bytes
cbbcc75
 
 
90729fc
 
 
cbbcc75
 
 
 
 
2a0c033
90729fc
87d8ff9
 
90729fc
 
 
2a0c033
90729fc
 
 
 
2a0c033
 
cbbcc75
2a0c033
 
90729fc
 
 
 
2a0c033
90729fc
cbbcc75
 
87d8ff9
cbbcc75
90729fc
2a0c033
cbbcc75
5c65101
cbbcc75
 
 
 
 
 
 
 
3772cf0
2a0c033
cbbcc75
 
 
 
90729fc
2a0c033
 
 
 
90729fc
2a0c033
 
 
 
5c65101
 
2a0c033
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90729fc
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from langchain import LLMChain, PromptTemplate
from langchain.document_loaders import NotionDirectoryLoader
from langchain.text_splitter import MarkdownTextSplitter, SpacyTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.chains.question_answering import load_qa_chain

from langchain.document_loaders import NotionDirectoryLoader
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
from langchain.agents import initialize_agent, AgentType, Tool, ZeroShotAgent, AgentExecutor

from models import llm


class CustomEmbedding:
    """Build, persist and query a FAISS vector index over a Notion export.

    The index is created from markdown pages loaded from a local Notion
    directory dump, saved to ``./documents/abstract.faiss``, and later
    reloaded for retrieval-based question answering — either as a
    conversational chain (:meth:`getFAQChain` / :meth:`faq`) or wrapped
    in a zero-shot agent (:meth:`getFAQAgent`).
    """

    # NOTE(review): hard-coded absolute path to one developer's machine —
    # consider making this configurable (env var / constructor argument).
    notionDirectoryLoader = NotionDirectoryLoader(
        "/Users/peichao.dong/Documents/projects/dpc/ABstract/docs/pages")
    embeddings = HuggingFaceEmbeddings()

    def calculateEmbedding(self):
        """Load the Notion documents, split them into markdown chunks and
        persist a FAISS index to ``./documents/abstract.faiss``.
        """
        documents = self.notionDirectoryLoader.load()
        text_splitter = MarkdownTextSplitter(
            chunk_size=2048, chunk_overlap=0)
        texts = text_splitter.split_documents(documents)

        docsearch = FAISS.from_documents(texts, self.embeddings)
        docsearch.save_local(
            folder_path="./documents/abstract.faiss")

    def getFAQChain(self, llm=None):
        """Return a ConversationalRetrievalChain over the persisted index.

        Args:
            llm: optional language model. When omitted, a default model with
                ``temperature=0.7`` is created lazily per call.

        Returns:
            A ConversationalRetrievalChain with conversation memory that
            condenses follow-up questions into standalone Chinese questions
            before retrieval.
        """
        if llm is None:
            # Imported and constructed lazily: the previous default
            # (`llm=llm(temperature=0.7)` in the signature) was evaluated
            # once when the class body executed, building a model at import
            # time and sharing a single instance across every call.
            from models import llm as default_llm
            llm = default_llm(temperature=0.7)

        memory = ConversationBufferMemory(
            memory_key="chat_history", return_messages=True)
        docsearch = FAISS.load_local(
            "./documents/abstract.faiss", self.embeddings)

        # Prompt used to rewrite a follow-up question (given the chat
        # history) into a self-contained Chinese question for retrieval.
        _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a chinese standalone question.

        Chat History:
        {chat_history}
        Follow Up Input: {question}
        Standalone question:"""
        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
        question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)

        # "stuff" chain: all retrieved docs are stuffed into one prompt.
        doc_chain = load_qa_chain(llm, chain_type="stuff")
        qa = ConversationalRetrievalChain(
            retriever=docsearch.as_retriever(search_kwargs={"k": 1}),
            question_generator=question_generator,
            combine_docs_chain=doc_chain,
            memory=memory)
        return qa

    def faq(self, input):
        """Answer a single question against the indexed documents.

        Args:
            input: the user's question (any language; answers come back per
                the chain's prompt, which condenses to Chinese).

        Returns:
            The chain's answer string.
        """
        qa = self.getFAQChain()
        response = qa({"question": f"{input}"})
        return response["answer"]

    def getFAQAgent(self):
        """Return an AgentExecutor exposing :meth:`faq` as a single tool.

        The agent keeps its own conversation memory and is instructed to
        produce its final answer in Chinese.
        """
        # Typo fixes in the tool description ("answer") and suffix
        # ("Chinese") — these strings are read by the LLM, so the typos
        # degraded instruction quality.
        tools = [Tool(name="ABstract system FAQ", func=self.faq,
                      description="Useful for answer questions about ABstract system")]
        memory = ConversationBufferMemory(memory_key="chat_history")

        prefix = """Have a conversation with a human, answering the following questions as best you can.  You have access to the following tools:"""
        suffix = """The final Answer should be in Chinese! Begin!"

        {chat_history}
        Question: {input}
        {agent_scratchpad}"""

        prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=prefix,
            suffix=suffix,
            input_variables=["input", "chat_history", "agent_scratchpad"]
        )

        llm_chain = LLMChain(llm=llm(), prompt=prompt)
        agent = ZeroShotAgent(llm_chain=llm_chain, tools=tools, verbose=True)
        faq_agent = AgentExecutor.from_agent_and_tools(
            agent=agent, tools=tools, verbose=True, memory=memory)
        return faq_agent


if __name__ == "__main__":
    # Build and persist the FAISS index for the Notion export.
    embedding = CustomEmbedding()
    embedding.calculateEmbedding()