Rauhan committed on
Commit 9a054bf
1 Parent(s): 5e597cb

UPDATE: chat history retention

Files changed (1):
1. functions.py +40 -3
functions.py CHANGED
@@ -5,6 +5,9 @@ from langchain_qdrant import QdrantVectorStore
  from langchain_core.prompts.chat import ChatPromptTemplate
  from langchain_core.output_parsers import StrOutputParser
  from langchain.retrievers import ParentDocumentRetriever
+ from langchain_core.runnables.history import RunnableWithMessageHistory
+ from langchain.memory import ChatMessageHistory
+ from langchain_core.chat_history import BaseChatMessageHistory
  from langchain.storage import InMemoryStore
  from langchain.docstore.document import Document
  from langchain_huggingface import HuggingFaceEmbeddings
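
One packaging note on the new imports: in recent LangChain releases ChatMessageHistory lives in langchain_community.chat_message_histories, and the langchain.memory path used here is an older re-export. A small compatibility sketch, assuming langchain-community is installed:

    # Hedged import sketch: prefer the langchain_community location,
    # falling back to the older langchain.memory re-export.
    try:
        from langchain_community.chat_message_histories import ChatMessageHistory
    except ImportError:
        from langchain.memory import ChatMessageHistory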
@@ -35,16 +38,24 @@ prompt = """
  3. **Exclusive Reliance on Training Data**: Answer user queries exclusively based on the provided training data. If a query is not covered by the training data, use the fallback response.
  4. **Restrictive Role Focus**: Do not answer questions or perform tasks unrelated to your role and training data.
  DO NOT ADD ANYTHING BY YOURSELF OR ANSWER ON YOUR OWN!
- Based on the context answer the following question.
+ Based on the context answer the following question. Remember that you need to frame a meaningful answer in under 512 words.
  Context:
  =====================================
  {context}
  =====================================
+ Question:
+ =====================================
  {question}
- NOTE: generate responses WITHOUT prepending phrases like "Response:", "Output:", or "Answer:", etc
+
+ Also, below I am providing you the previous question you were asked and the output you generated. It's just for your reference so that you know the topic you have been talking about and nothing else:
+ =====================================
+ {chatHistory}
+ =====================================
+ NOTE: generate responses WITHOUT prepending phrases like "Response:", "Output:", or "Answer:", etc. Also do not let the user know that you are answering from any extracted context or something.
  """
  prompt = ChatPromptTemplate.from_template(prompt)
  store = InMemoryStore()
+ chatHistoryStore = dict()


  def createUser(username: str, password: str) -> None:
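
The template now carries a {chatHistory} slot alongside {context} and {question}; when the chain is later wrapped in RunnableWithMessageHistory with history_messages_key = "chatHistory", that slot receives the session's stored messages (a list of message objects, which a single-string template renders in stringified form). A minimal standalone sketch of how the placeholders fill in, with made-up sample values:

    # Rendering sketch; the sample context/question are hypothetical.
    from langchain_core.prompts.chat import ChatPromptTemplate

    template = ChatPromptTemplate.from_template(
        "Context:\n{context}\nQuestion:\n{question}\nHistory:\n{chatHistory}"
    )
    rendered = template.invoke({
        "context": "Orders ship within 3-5 business days.",
        "question": "How long does shipping take?",
        "chatHistory": [],  # filled in by RunnableWithMessageHistory at run time
    })
    print(rendered.to_messages()[0].content)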
@@ -146,6 +157,25 @@ def format_docs(docs: str):
      else: pass
      return context

+
+ def get_session_history(session_id: str) -> BaseChatMessageHistory:
+     if session_id not in store:
+         store[session_id] = ChatMessageHistory()
+     return store[session_id]
+
+
+ def trimMessages(chain_input):
+     for storeName in chatHistoryStore:
+         messages = chatHistoryStore[storeName].messages
+         if len(messages) <= 2:
+             pass
+         else:
+             chatHistoryStore[storeName].clear()
+             for message in messages[-2: ]:
+                 chatHistoryStore[storeName].add_message(message)
+     return True
+
+
  def answerQuery(query: str, vectorstore: str, llmModel: str = "llama3-70b-8192") -> str:
      global prompt
      global client
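
Two things are worth flagging in these helpers as committed. First, get_session_history keys into store, the module-level InMemoryStore backing the parent-document retriever; that class exposes mget/mset rather than "in" checks and item assignment, and the new chatHistoryStore dict is what trimMessages actually reads, so the lookup was presumably meant to hit the dict. A hedged sketch of that presumed intent:

    # Sketch of get_session_history keyed to the chatHistoryStore dict added in
    # this commit, rather than the retriever's InMemoryStore.
    from langchain_core.chat_history import BaseChatMessageHistory
    from langchain_community.chat_message_histories import ChatMessageHistory

    chatHistoryStore: dict = {}

    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        if session_id not in chatHistoryStore:
            chatHistoryStore[session_id] = ChatMessageHistory()
        return chatHistoryStore[session_id]

Second, trimMessages caps every session at its last two messages, i.e. one question/answer pair, which matches the prompt's framing of "the previous question you were asked and the output you generated".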
@@ -167,11 +197,18 @@ def answerQuery(query: str, vectorstore: str, llmModel: str = "llama3-70b-8192")
          base_compressor=compressor, base_retriever=retriever
      )
      chain = (
-         {"context": retriever | RunnableLambda(format_docs), "question": RunnablePassthrough()}
+         {"context": retriever | RunnableLambda(format_docs), "question": RunnablePassthrough(), "chatHistory": RunnablePassthrough()}
          | prompt
          | ChatGroq(model = llmModel, temperature = 0.75, max_tokens = 512)
          | StrOutputParser()
      )
+     chain = RunnableWithMessageHistory(
+         chain,
+         get_session_history,
+         input_messages_key = "question",
+         history_messages_key = "chatHistory"
+     )
+     chain = RunnablePassthrough.assign(messages_trimmed = trimMessages) | chain
      return {
          "output": chain.invoke(query)
      }
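
With the wrap in place, RunnableWithMessageHistory expects the user input under input_messages_key and a session id in the call config, while the committed code still calls chain.invoke(query) with a bare string and no config. A hedged sketch of the invocation such a chain conventionally needs, assuming the chain built above (the session id is made up):

    # Usage sketch: input keyed by "question", session routed via config.
    result = chain.invoke(
        {"question": "What are your shipping times?"},
        config={"configurable": {"session_id": "user-42"}},
    )

Note also that each RunnablePassthrough() in the dict at the head of the chain forwards the entire input rather than a single key; pulling out individual fields typically uses operator.itemgetter, while the history_messages_key mechanism is what injects the stored messages before the prompt sees them.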
 