lchakkei committed
Commit 5eff215
1 Parent(s): 5e5f0bf

Update handler.py

Files changed (1)
  1. handler.py +109 -27
handler.py CHANGED
@@ -16,20 +16,58 @@ from langchain.document_loaders import WebBaseLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from llm_for_langchain import LLM
 from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.messages import HumanMessage
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.runnables import RunnableLambda, RunnableBranch, RunnablePassthrough
+from operator import itemgetter
+from langchain.schema import format_document
+from langchain.memory import ConversationBufferMemory
+from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
+
 
 class EndpointHandler():
     def __init__(self, path=""):
 
-        self.llm = LLM(model_name_or_path=path, bit4=False)
+        # Config LangChain
+        # os.environ["LANGCHAIN_TRACING_V2"] = "true"
+        # os.environ["LANGCHAIN_API_KEY"] =
+
+        # Create LLM
+
+        # load the tokenizer and the quantized mistral model
+        model = AutoModelForCausalLM.from_pretrained(
+            path,
+            device_map="auto")
 
-        # Load Vector db
+        tokenizer = AutoTokenizer.from_pretrained(path)
 
-        self.embedding_function = HuggingFaceBgeEmbeddings(
-            model_name="BAAI/bge-large-zh",
+        # using HuggingFace's pipeline
+        pipeline = pipeline(
+            "text-generation",
+            model=model,
+            tokenizer=tokenizer,
+            use_cache=True,
+            device_map="auto",
+            max_new_tokens=5000,
+            do_sample=True,
+            top_k=1,
+            temperature = 0.01,
+            num_return_sequences=1,
+            eos_token_id=tokenizer.eos_token_id,
+            pad_token_id=tokenizer.eos_token_id,
+        )
+        chat = HuggingFacePipeline(pipeline=pipeline)
+
+        # Create Text-Embedding Model
+        embedding_function = HuggingFaceBgeEmbeddings(
+            model_name="DMetaSoul/Dmeta-embedding",
             model_kwargs={'device': 'cuda'},
             encode_kwargs={'normalize_embeddings': True}
         )
-
+
+        # Load Vector db
         urls = [
             "https://hk.on.cc/hk/bkn/cnt/news/20221019/bkn-20221019040039334-1019_00822_001.html",
             "https://www.hk01.com/%E7%A4%BE%E6%9C%83%E6%96%B0%E8%81%9E/822848/%E5%89%B5%E7%A7%91%E7%B2%BE%E8%8B%B1-%E5%87%BA%E6%88%B02022%E4%B8%96%E7%95%8C%E6%8A%80%E8%83%BD%E5%A4%A7%E8%B3%BD%E7%89%B9%E5%88%A5%E8%B3%BD",
@@ -43,40 +81,84 @@ class EndpointHandler():
         text_splitter = RecursiveCharacterTextSplitter(chunk_size = 1000, chunk_overlap = 16)
         all_splits = text_splitter.split_documents(data)
 
-        vectorstore = Chroma.from_documents(documents=all_splits, embedding=self.embedding_function)
+        vectorstore = Chroma.from_documents(documents=all_splits, embedding=embedding_function)
+        retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
 
-        # vectorstore = Chroma(persist_directory="db", embedding_function=embedding_function)
+        compressor = LLMChainExtractor.from_llm(chat)
+        retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
 
-        compressor = LLMChainExtractor.from_llm(self.llm)
-        self.retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=vectorstore.as_retriever(search_kwargs={"k": 4}))
-
-        prompt_template = """<s>[INST] <<SYS>> You are a helpful assistant.
-        Use the following context to Answer the question below briefly: <<SYS>>
+        _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
+        Chat History:
+        {chat_history}
+        Follow Up Input: {question}
+        Standalone question:"""
 
-        {history}
+        CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
 
+        template = """Answer the question based only on the following context:
         {context}
 
-        {question} [/INST] </s>
+        Question: {question}
         """
+        ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
 
-        prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template)
-        memory = ConversationBufferMemory(input_key='question', memory_key='history', return_messages=True)
-
-        self.qa_chain = RetrievalQA.from_chain_type(
-            self.llm,
-            chain_type="stuff",
-            retriever=self.retriever,
-            chain_type_kwargs={"prompt": prompt, "memory": memory}
+        self.memory = ConversationBufferMemory(
+            return_messages=True, output_key="answer", input_key="question"
         )
+
+        # First we add a step to load memory
+        # This adds a "memory" key to the input object
+        loaded_memory = RunnablePassthrough.assign(
+            chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
+        )
+        # Now we calculate the standalone question
+        standalone_question = {
+            "standalone_question": {
+                "question": lambda x: x["question"],
+                "chat_history": lambda x: get_buffer_string(x["chat_history"]),
+            }
+            | CONDENSE_QUESTION_PROMPT
+            | chat
+            | StrOutputParser(),
+        }
+
+        DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+
+        def _combine_documents(
+            docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
+        ):
+            doc_strings = [format_document(doc, document_prompt) for doc in docs]
+            return document_separator.join(doc_strings)
+
+        # Now we retrieve the documents
+        retrieved_documents = {
+            "docs": itemgetter("standalone_question") | retriever,
+            "question": lambda x: x["standalone_question"],
+        }
+        # Now we construct the inputs for the final prompt
+        final_inputs = {
+            "context": lambda x: _combine_documents(x["docs"]),
+            "question": itemgetter("question"),
+        }
+        # And finally, we do the part that returns the answers
+        answer = {
+            "answer": final_inputs | ANSWER_PROMPT | chat,
+            "docs": itemgetter("docs"),
+        }
+        # And now we put it all together!
+        self.final_chain = loaded_memory | standalone_question | retrieved_documents | answer
 
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         # pseudo
         # self.model(input)
         inputs = data.pop("inputs", data)
-        output = self.qa_chain(inputs)
-        print(output)
-
-        return output
-
+        result = self.final_chain.invoke(inputs)
+        print(result['answer'])
 
+        # Note that the memory does not save automatically
+        # This will be improved in the future
+        # For now you need to save it yourself
+        self.memory.save_context(inputs, {"answer": result["answer"]})
+        self.memory.load_memory_variables({})
+
+        return result
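
For reference, a minimal local smoke test of the updated handler could look like the sketch below. It is not part of this commit: it assumes the top of handler.py (lines 1-15, outside this diff) already imports AutoModelForCausalLM, AutoTokenizer and pipeline from transformers, plus HuggingFacePipeline, HuggingFaceBgeEmbeddings, Chroma, PromptTemplate, LLMChainExtractor and ContextualCompressionRetriever from LangChain; the model path and questions are placeholders.

# hypothetical smoke test for the updated EndpointHandler (path and questions are placeholders)
from handler import EndpointHandler

# Building the handler loads the model from `path` and indexes the hard-coded URLs.
handler = EndpointHandler(path="path/to/model")

# The endpoint payload carries the chain input under "inputs";
# the chain itself expects a "question" key (matching the memory's input_key).
first = handler({"inputs": {"question": "What is the article about?"}})
print(first["answer"])  # generated answer text
print(first["docs"])    # compressed source documents used for the answer

# A follow-up question reuses the chat history that __call__ stored
# in ConversationBufferMemory via save_context.
second = handler({"inputs": {"question": "When did it take place?"}})
print(second["answer"])

Because the memory lives on the handler instance, history accumulates across calls within one process; a fresh deployment starts with an empty buffer.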