Spaces:
UPDATE: chat history retention
functions.py CHANGED (+40 -3)
@@ -5,6 +5,9 @@ from langchain_qdrant import QdrantVectorStore
 from langchain_core.prompts.chat import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
 from langchain.retrievers import ParentDocumentRetriever
+from langchain_core.runnables.history import RunnableWithMessageHistory
+from langchain.memory import ChatMessageHistory
+from langchain_core.chat_history import BaseChatMessageHistory
 from langchain.storage import InMemoryStore
 from langchain.docstore.document import Document
 from langchain_huggingface import HuggingFaceEmbeddings
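A note on the new imports (an aside, not part of the commit): in current LangChain releases, ChatMessageHistory under langchain.memory is a re-export, and langchain_core.chat_history ships an equivalent InMemoryChatMessageHistory. A hedged fallback sketch, in case the re-export is removed in a later version:

    # Hypothetical fallback, not in the commit: prefer the langchain.memory
    # re-export, fall back to the langchain_core class under an alias.
    try:
        from langchain.memory import ChatMessageHistory
    except ImportError:
        from langchain_core.chat_history import InMemoryChatMessageHistory as ChatMessageHistory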
@@ -35,16 +38,24 @@ prompt = """
 3. **Exclusive Reliance on Training Data**: Answer user queries exclusively based on the provided training data. If a query is not covered by the training data, use the fallback response.
 4. **Restrictive Role Focus**: Do not answer questions or perform tasks unrelated to your role and training data.
 DO NOT ADD ANYTHING BY YOURSELF OR ANSWER ON YOUR OWN!
-Based on the context answer the following question.
+Based on the context answer the following question. Remember that you need to frame a meaningful answer in under 512 words.
 Context:
 =====================================
 {context}
 =====================================
+Question:
+=====================================
 {question}
-
+
+Also, below I am providing you the previous question you were asked and the output you generated. It's just for your reference so that you know the topic you have been talking about and nothing else:
+=====================================
+{chatHistory}
+=====================================
+NOTE: generate responses WITHOUT prepending phrases like "Response:", "Output:", or "Answer:", etc. Also do not let the user know that you are answering from any extracted context or something.
 """
 prompt = ChatPromptTemplate.from_template(prompt)
 store = InMemoryStore()
+chatHistoryStore = dict()
 
 
 def createUser(username: str, password: str) -> None:
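For reference, ChatPromptTemplate.from_template parses the brace placeholders, so the updated template now expects three inputs: context, question, and chatHistory. A minimal check (hypothetical snippet, not from the repo):

    from langchain_core.prompts.chat import ChatPromptTemplate

    # A cut-down template with the same three placeholders as the commit's prompt.
    tmpl = ChatPromptTemplate.from_template(
        "Context:\n{context}\n\nQuestion:\n{question}\n\nHistory:\n{chatHistory}"
    )
    print(sorted(tmpl.input_variables))  # ['chatHistory', 'context', 'question']

Because from_template builds a single human message from a plain string, whatever RunnableWithMessageHistory injects under chatHistory is rendered via str() inside that string, i.e. as a stringified message list rather than as structured chat turns; a MessagesPlaceholder would keep the turns structured, at the cost of changing the prompt's shape.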
@@ -146,6 +157,25 @@ def format_docs(docs: str):
     else: pass
     return context
 
+
+def get_session_history(session_id: str) -> BaseChatMessageHistory:
+    if session_id not in store:
+        store[session_id] = ChatMessageHistory()
+    return store[session_id]
+
+
+def trimMessages(chain_input):
+    for storeName in chatHistoryStore:
+        messages = chatHistoryStore[storeName].messages
+        if len(messages) <= 2:
+            pass
+        else:
+            chatHistoryStore[storeName].clear()
+            for message in messages[-2: ]:
+                chatHistoryStore[storeName].add_message(message)
+    return True
+
+
 def answerQuery(query: str, vectorstore: str, llmModel: str = "llama3-70b-8192") -> str:
     global prompt
     global client
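One flag on this hunk: get_session_history reads and writes store, the module-level InMemoryStore() that serves as the ParentDocumentRetriever docstore, not the chatHistoryStore dict the commit adds, and InMemoryStore exposes mget/mset rather than "in" checks or subscript assignment, so this lookup would fail at runtime while trimMessages scans a dict that stays empty. A minimal sketch of what the helper presumably intends, assuming chatHistoryStore is the backing store:

    # Sketch of the presumably intended helper; imports and chatHistoryStore
    # as declared earlier in functions.py. Keeps per-session histories in the
    # plain dict, leaving the InMemoryStore docstore to the retriever.
    def get_session_history(session_id: str) -> BaseChatMessageHistory:
        if session_id not in chatHistoryStore:
            chatHistoryStore[session_id] = ChatMessageHistory()
        return chatHistoryStore[session_id]

With that change, trimMessages and the history lookup operate on the same dict, and the two-message cap it enforces matches the prompt's "previous question and output" framing.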
@@ -167,11 +197,18 @@ def answerQuery(query: str, vectorstore: str, llmModel: str = "llama3-70b-8192")
         base_compressor=compressor, base_retriever=retriever
     )
     chain = (
-        {"context": retriever | RunnableLambda(format_docs), "question": RunnablePassthrough()}
+        {"context": retriever | RunnableLambda(format_docs), "question": RunnablePassthrough(), "chatHistory": RunnablePassthrough()}
         | prompt
         | ChatGroq(model = llmModel, temperature = 0.75, max_tokens = 512)
         | StrOutputParser()
     )
+    chain = RunnableWithMessageHistory(
+        chain,
+        get_session_history,
+        input_messages_key = "question",
+        history_messages_key = "chatHistory"
+    )
+    chain = RunnablePassthrough.assign(messages_trimmed = trimMessages) | chain
     return {
         "output": chain.invoke(query)
     }
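Two caveats on the new wiring, going by the documented RunnableWithMessageHistory behavior: the session is resolved from the config argument at invoke time (a missing session_id raises an error), and with history_messages_key set the wrapper injects history by assigning into a dict-shaped input, whereas the unchanged chain.invoke(query) still passes a bare string. A hedged sketch of an invocation shape the wrapper accepts (the session id value is illustrative):

    result = chain.invoke(
        {"question": query},
        config = {"configurable": {"session_id": "default-session"}},
    )

For the inner map to keep working with a dict input, the retriever branch would also need to pull the string out first, e.g. itemgetter("question") | retriever using operator.itemgetter; RunnablePassthrough.assign in the trim step likewise assumes a dict input.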