lchakkei committed
Commit 5e5f0bf
1 Parent(s): c4df466

Update handler.py

Files changed (1):
  handler.py +26 -87
handler.py CHANGED
@@ -16,35 +16,21 @@ from langchain.document_loaders import WebBaseLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
  from llm_for_langchain import LLM
  from langchain.chains.qa_with_sources import load_qa_with_sources_chain
- from langchain.chains.combine_documents import create_stuff_documents_chain
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
- from langchain_core.messages import HumanMessage
- from langchain_core.output_parsers import StrOutputParser
- from langchain_core.runnables import RunnableLambda, RunnableBranch, RunnablePassthrough
- from operator import itemgetter
- from langchain.schema import format_document
  from langchain.memory import ConversationBufferMemory
- from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string
-

  class EndpointHandler():
      def __init__(self, path=""):

-         # Config LangChain
-         # os.environ["LANGCHAIN_TRACING_V2"] = "true"
-         # os.environ["LANGCHAIN_API_KEY"] =
-
-         # Create LLM
-         chat = LLM(model_name_or_path=path, bit4=False)
-
-         # Create Text-Embedding Model
-         embedding_function = HuggingFaceBgeEmbeddings(
-             model_name="DMetaSoul/Dmeta-embedding",
+         self.llm = LLM(model_name_or_path=path, bit4=False)
+
+         # Load Vector db
+
+         self.embedding_function = HuggingFaceBgeEmbeddings(
+             model_name="BAAI/bge-large-zh",
              model_kwargs={'device': 'cuda'},
              encode_kwargs={'normalize_embeddings': True}
          )
-
-         # Load Vector db
+
          urls = [
              "https://hk.on.cc/hk/bkn/cnt/news/20221019/bkn-20221019040039334-1019_00822_001.html",
              "https://www.hk01.com/%E7%A4%BE%E6%9C%83%E6%96%B0%E8%81%9E/822848/%E5%89%B5%E7%A7%91%E7%B2%BE%E8%8B%B1-%E5%87%BA%E6%88%B02022%E4%B8%96%E7%95%8C%E6%8A%80%E8%83%BD%E5%A4%A7%E8%B3%BD%E7%89%B9%E5%88%A5%E8%B3%BD",
@@ -58,87 +44,40 @@ class EndpointHandler():
          text_splitter = RecursiveCharacterTextSplitter(chunk_size = 1000, chunk_overlap = 16)
          all_splits = text_splitter.split_documents(data)

-         vectorstore = Chroma.from_documents(documents=all_splits, embedding=embedding_function)
-         retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
+         vectorstore = Chroma.from_documents(documents=all_splits, embedding=self.embedding_function)

-         compressor = LLMChainExtractor.from_llm(chat)
-         retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
+         # vectorstore = Chroma(persist_directory="db", embedding_function=embedding_function)

-         _template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
+         compressor = LLMChainExtractor.from_llm(self.llm)
+         self.retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=vectorstore.as_retriever(search_kwargs={"k": 4}))

-         Chat History:
-         {chat_history}
-         Follow Up Input: {question}
-         Standalone question:"""
+         prompt_template = """<s>[INST] <<SYS>> You are a helpful assistant.
+         Use the following context to answer the question below briefly: <</SYS>>

-         CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)
+         {history}

-         template = """Answer the question based only on the following context:
          {context}

-         Question: {question}
+         {question} [/INST] </s>
          """
-         ANSWER_PROMPT = ChatPromptTemplate.from_template(template)
-
-         self.memory = ConversationBufferMemory(
-             return_messages=True, output_key="answer", input_key="question"
-         )
-
-         # First we add a step to load memory
-         # This adds a "memory" key to the input object
-         loaded_memory = RunnablePassthrough.assign(
-             chat_history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history"),
-         )
-         # Now we calculate the standalone question
-         standalone_question = {
-             "standalone_question": {
-                 "question": lambda x: x["question"],
-                 "chat_history": lambda x: get_buffer_string(x["chat_history"]),
-             }
-             | CONDENSE_QUESTION_PROMPT
-             | chat(temperature=0)
-             | StrOutputParser(),
-         }

-         DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
+         prompt = PromptTemplate(input_variables=["history", "context", "question"], template=prompt_template)
+         memory = ConversationBufferMemory(input_key='question', memory_key='history', return_messages=True)

-         def _combine_documents(
-             docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, document_separator="\n\n"
-         ):
-             doc_strings = [format_document(doc, document_prompt) for doc in docs]
-             return document_separator.join(doc_strings)
-
-         # Now we retrieve the documents
-         retrieved_documents = {
-             "docs": itemgetter("standalone_question") | retriever,
-             "question": lambda x: x["standalone_question"],
-         }
-         # Now we construct the inputs for the final prompt
-         final_inputs = {
-             "context": lambda x: _combine_documents(x["docs"]),
-             "question": itemgetter("question"),
-         }
-         # And finally, we do the part that returns the answers
-         answer = {
-             "answer": final_inputs | ANSWER_PROMPT | chat,
-             "docs": itemgetter("docs"),
-         }
-         # And now we put it all together!
-         self.final_chain = loaded_memory | standalone_question | retrieved_documents | answer
+         self.qa_chain = RetrievalQA.from_chain_type(
+             self.llm,
+             chain_type="stuff",
+             retriever=self.retriever,
+             chain_type_kwargs={"prompt": prompt, "memory": memory}
+         )

      def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
          # pseudo
          # self.model(input)
          inputs = data.pop("inputs", data)
-         result = self.final_chain.invoke(inputs)
-         print(result['answer'])
-
-         # Note that the memory does not save automatically
-         # This will be improved in the future
-         # For now you need to save it yourself
-         self.memory.save_context(inputs, {"answer": result["answer"].content})
-         self.memory.load_memory_variables({})

-         return result
+         return output
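For reviewers who want to smoke-test the updated chain locally, a minimal sketch follows. It assumes the Hugging Face custom-handler calling convention (a payload dict with an "inputs" key) and the legacy LangChain RetrievalQA API, whose default input key is "query" and default answer key is "result"; the question, and a model path reachable through the custom LLM wrapper on a CUDA device, are hypothetical.

    # Local smoke test (sketch, not part of the commit).
    from handler import EndpointHandler

    handler = EndpointHandler(path="")   # builds the LLM, embeddings, retriever and qa_chain
    payload = {"inputs": {"query": "What did the Hong Kong team win at the 2022 WorldSkills Competition Special Edition?"}}

    output = handler(payload)            # __call__ runs self.qa_chain(inputs)
    print(output["result"])              # "result" is RetrievalQA's default answer key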
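The commented-out Chroma(persist_directory="db", ...) line hints at reloading a persisted index instead of re-scraping the URLs on every cold start. Below is a sketch of the one-off build step under that assumption, using the legacy langchain Chroma wrapper; the "db" directory name is taken from the commented-out line, and all_splits and embedding_function stand in for the objects built in __init__.

    # One-off index build (sketch): persist the scraped splits so that
    # __init__ can later reload them with Chroma(persist_directory="db",
    # embedding_function=...) instead of calling WebBaseLoader each start.
    from langchain.vectorstores import Chroma

    vectorstore = Chroma.from_documents(
        documents=all_splits,            # splits from RecursiveCharacterTextSplitter
        embedding=embedding_function,    # the HuggingFaceBgeEmbeddings instance
        persist_directory="db",          # directory named in the commented-out line
    )
    vectorstore.persist()                # flush the index to disk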