lchakkei committed on
Commit f707439
Parent: d379398

Create handler.py

Files changed (1)
  1. handler.py +115 -0
handler.py ADDED
@@ -0,0 +1,115 @@
import os
import getpass
from typing import Any, Dict
from langchain.retrievers.document_compressors import LLMChainExtractor
from langchain.retrievers import ContextualCompressionRetriever
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from llm_for_langchain import LLM
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableBranch, RunnablePassthrough


class EndpointHandler():
    def __init__(self, path=""):
        # Configure LangChain tracing; the API key is prompted for interactively
        os.environ["LANGCHAIN_TRACING_V2"] = "true"
        os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()

        # Create the LLM
        chat = LLM(model_name_or_path=path, bit4=False)

        # Create the text-embedding model
        embedding_function = HuggingFaceBgeEmbeddings(
            model_name="DMetaSoul/Dmeta-embedding",
            model_kwargs={'device': 'cuda'},
            encode_kwargs={'normalize_embeddings': True}
        )

        # Load the source articles and build the vector db
        urls = [
            "https://hk.on.cc/hk/bkn/cnt/news/20221019/bkn-20221019040039334-1019_00822_001.html",
            "https://www.hk01.com/%E7%A4%BE%E6%9C%83%E6%96%B0%E8%81%9E/822848/%E5%89%B5%E7%A7%91%E7%B2%BE%E8%8B%B1-%E5%87%BA%E6%88%B02022%E4%B8%96%E7%95%8C%E6%8A%80%E8%83%BD%E5%A4%A7%E8%B3%BD%E7%89%B9%E5%88%A5%E8%B3%BD",
            "https://www.wenweipo.com/epaper/view/newsDetail/1582436861224292352.html",
            "https://www.thinkhk.com/article/2023-03/24/59874.html"
        ]

        loader = WebBaseLoader(urls)
        data = loader.load()

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=16)
        all_splits = text_splitter.split_documents(data)

        vectorstore = Chroma.from_documents(documents=all_splits, embedding=embedding_function)
        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})

        # Compress retrieved documents down to the passages relevant to the query
        compressor = LLMChainExtractor.from_llm(chat)
        retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)

        SYSTEM_TEMPLATE = """
        Answer the user's questions based on the context below.
        If the context doesn't contain any information relevant to the question, don't make something up; just say "I don't know":

        <context>
        {context}
        </context>
        """

        question_answering_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    SYSTEM_TEMPLATE,
                ),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )

        # Prompt that condenses a multi-turn conversation into a standalone search query
        query_transform_prompt = ChatPromptTemplate.from_messages(
            [
                MessagesPlaceholder(variable_name="messages"),
                (
                    "user",
                    "Given the above conversation, generate a search query to look up "
                    "information relevant to the conversation. Only respond with the "
                    "query, nothing else.",
                ),
            ]
        )

        # Wrap the retriever
        query_transforming_retriever_chain = RunnableBranch(
            (
                lambda x: len(x.get("messages", [])) == 1,
                # If there is only one message, pass its content straight to the retriever
                (lambda x: x["messages"][-1].content) | retriever,
            ),
            # Otherwise, have the LLM rewrite the conversation into a query, then retrieve
            query_transform_prompt | chat | StrOutputParser() | retriever,
        ).with_config(run_name="chat_retriever_chain")

        document_chain = create_stuff_documents_chain(chat, question_answering_prompt)

        self.conversational_retrieval_chain = RunnablePassthrough.assign(
            context=query_transforming_retriever_chain,
        ).assign(
            answer=document_chain,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        inputs = data.pop("inputs", data)
        output = self.conversational_retrieval_chain.invoke(
            {
                "messages": [
                    HumanMessage(content=inputs)
                ],
            }
        )
        print(output['answer'])

        return output
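
For reference, a minimal local smoke test of the handler might look like the sketch below. The model path and the question are placeholder assumptions: in a deployed Inference Endpoint the runtime supplies the repository path and sends the request payload itself.

from handler import EndpointHandler

# Hypothetical local test: "/repository" stands in for wherever the model weights live
handler = EndpointHandler(path="/repository")

# Send a payload in the {"inputs": ...} shape the endpoint receives, e.g. a question
# about the WorldSkills news articles indexed above (example question is a placeholder)
result = handler({"inputs": "Which Hong Kong competitors entered the 2022 WorldSkills Competition?"})
print(result["answer"])

Because __call__ wraps the input in a single HumanMessage, every request takes the single-message branch of the RunnableBranch; the query-transform branch only matters if the caller supplies a longer message history.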