import os
from typing import Any, Dict

from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableBranch, RunnablePassthrough
from llm_for_langchain import LLM


class EndpointHandler:
    def __init__(self, path=""):

        # Configure LangChain tracing; LANGCHAIN_API_KEY must already be set
        # in the environment, since a deployed endpoint cannot prompt for it
        # interactively
        os.environ["LANGCHAIN_TRACING_V2"] = "true"

        # Create LLM
        chat = LLM(model_name_or_path=path, bit4=False)

        # Create Text-Embedding Model
        embedding_function = HuggingFaceBgeEmbeddings(
            model_name="DMetaSoul/Dmeta-embedding",
            model_kwargs={'device': 'cuda'},
            encode_kwargs={'normalize_embeddings': True}
        )
        
        # Load source articles and build the vector database
        urls = [
            "https://hk.on.cc/hk/bkn/cnt/news/20221019/bkn-20221019040039334-1019_00822_001.html",
            "https://www.hk01.com/%E7%A4%BE%E6%9C%83%E6%96%B0%E8%81%9E/822848/%E5%89%B5%E7%A7%91%E7%B2%BE%E8%8B%B1-%E5%87%BA%E6%88%B02022%E4%B8%96%E7%95%8C%E6%8A%80%E8%83%BD%E5%A4%A7%E8%B3%BD%E7%89%B9%E5%88%A5%E8%B3%BD",
            "https://www.wenweipo.com/epaper/view/newsDetail/1582436861224292352.html",
            "https://www.thinkhk.com/article/2023-03/24/59874.html"
        ]

        loader = WebBaseLoader(urls)
        data = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=16)
        all_splits = text_splitter.split_documents(data)
        
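        # Index the chunks in an in-memory Chroma store and expose it as a
        # top-4 retriever (a sketch; a persisted store may suit production better)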
        vectorstore = Chroma.from_documents(documents=all_splits, embedding=embedding_function)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
        
        # Compress retrieved chunks with the LLM so only query-relevant text reaches the prompt
        compressor = LLMChainExtractor.from_llm(chat)
        retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
        
        SYSTEM_TEMPLATE = """
        Answer the user's questions based on the context below.
        If the context doesn't contain any information relevant to the question, don't make something up; just say "I don't know".

        <context>
        {context}
        </context>
        """
        
        question_answering_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    SYSTEM_TEMPLATE,
                ),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )

        # Rewrite follow-up questions into standalone search queries; the QA
        # prompt above requires a {context} variable, so it can't be reused here
        query_transform_prompt = ChatPromptTemplate.from_messages(
            [
                MessagesPlaceholder(variable_name="messages"),
                ("user", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation. Only respond with the query, nothing else."),
            ]
        )

        # Wrap the retriever: a single message goes straight to the retriever,
        # longer histories are first condensed into a search query by the LLM
        query_transforming_retriever_chain = RunnableBranch(
            (
                lambda x: len(x.get("messages", [])) == 1,
                (lambda x: x["messages"][-1].content) | retriever,
            ),
            query_transform_prompt | chat | StrOutputParser() | retriever,
        ).with_config(run_name="chat_retriever_chain")
        
        document_chain = create_stuff_documents_chain(chat, question_answering_prompt)
        
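        # Assemble the conversational chain: the retriever branch fills
        # "context", then the stuff-documents chain produces "answer"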
        self.conversational_retrieval_chain = RunnablePassthrough.assign(
            context=query_transforming_retriever_chain,
        ).assign(
            answer=document_chain,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Extract the user's question from the request payload
        inputs = data.pop("inputs", data)
        output = self.conversational_retrieval_chain.invoke(
                    {
                        "messages": [
                            HumanMessage(content=inputs)
                        ],
                    }
                )
        print(output['answer'])
        
        return output
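

# A minimal local smoke test, assuming the model weights sit next to this
# file (path="") and that llm_for_langchain's LLM can load them on this
# machine; a deployed Hugging Face Inference Endpoint constructs and calls
# EndpointHandler itself, so this block never runs there.
if __name__ == "__main__":
    handler = EndpointHandler(path="")
    result = handler({"inputs": "Which Hong Kong competitors took part in the WorldSkills Competition 2022 Special Edition?"})
    print(result["answer"])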