Adrian Cowham committed
Commit 4d55d8d
1 Parent(s): e7172f3
modified to accept chat history from client
Files changed: src/app.py  +48 -2

src/app.py CHANGED
@@ -1,4 +1,4 @@
-
+import json
 import os
 from threading import Lock
 from typing import Any, Dict, Optional, Tuple
@@ -26,8 +26,10 @@ system_template = """
 The context below contains excerpts from 'How to Win Friends & Influence People,' by Dail Carnegie. You must only use the information in the context below to formulate your response. If there is not enough information to formulate a response, you must respond with
 "I'm sorry, but I can't find the answer to your question in, the book How to Win Friends & Influence People."
 
-
+Begin context:
 {context}
+End context.
+
 {chat_history}
 """
 
@@ -60,6 +62,44 @@ def getretriever():
 
 retriever = getretriever()
 
+def predict(message):
+    print(message)
+    msgJson = json.loads(message)
+    print(msgJson)
+    messages = [
+        SystemMessagePromptTemplate.from_template(system_template),
+        HumanMessagePromptTemplate.from_template("{question}")
+    ]
+    qa_prompt = ChatPromptTemplate.from_messages(messages)
+
+    llm = ChatOpenAI(
+        openai_api_key=API_KEY,
+        model_name=MODEL,
+        verbose=True)
+    memory = AnswerConversationBufferMemory(memory_key="chat_history", return_messages=True)
+    for msg in msgJson["history"]:
+        memory.save_context({"input": msg[0]}, {"answer": msg[1]})
+
+    chain = ConversationalRetrievalChain.from_llm(
+        llm,
+        retriever=retriever,
+        return_source_documents=USE_VERBOSE,
+        memory=memory,
+        verbose=USE_VERBOSE,
+        combine_docs_chain_kwargs={"prompt": qa_prompt})
+    chain.rephrase_question = False
+    lock = Lock()
+    lock.acquire()
+    try:
+        output = chain({"question": msgJson["question"]})
+        output = output["answer"]
+    except Exception as e:
+        print(e)
+        raise e
+    finally:
+        lock.release()
+    return output
+
 def getanswer(chain, question, history):
     if hasattr(chain, "value"):
         chain = chain.value
@@ -129,4 +169,10 @@ with gr.Blocks() as block:
     ex4 = gr.Button(value="Why should I try to get along with people better?", variant="primary")
     ex4.click(getanswer, inputs=[chain_state, ex4, state], outputs=[chatbot, state, message])
 
+    ex5 = gr.Button(value="How do I cite a Reddit thread?", variant="primary")
+    ex5.click(getanswer, inputs=[chain_state, ex5, state], outputs=[chatbot, state, message])
+
+    predictBtn = gr.Button(value="Predict", visible=False)
+    predictBtn.click(predict, inputs=[message], outputs=[message])
+
 block.launch(debug=True)
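
For reference, the new predict() entry point takes one serialized JSON string rather than separate Gradio inputs: a "question" field with the new user turn and a "history" field holding prior [user, assistant] pairs, which are replayed into the conversation memory before the chain runs. The sketch below is a hypothetical client payload inferred from the json.loads handling in the diff; the question and history text are made up for illustration, and it assumes it runs in the same module where predict() is defined (src/app.py).

import json

# Hypothetical payload a client might send; the "history" pairs seed
# AnswerConversationBufferMemory via save_context(), and "question" becomes
# the new turn passed to the ConversationalRetrievalChain.
payload = {
    "question": "How do I get people to like me?",
    "history": [
        ["What is this book about?",
         "It covers Dale Carnegie's principles for dealing with people."],
    ],
}

answer = predict(json.dumps(payload))  # returns the chain's "answer" string
print(answer)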