use ConversationChain + ConversationSummaryBufferMemory
- Makefile +6 -0
- app_modules/llm_chat_chain.py +14 -11
- app_modules/llm_inference.py +12 -4
- app_modules/llm_loader.py +1 -0
- server.py +6 -5
- unit_test.py +1 -1
Makefile
CHANGED
@@ -12,9 +12,15 @@ endif
 test:
 	python test.py
 
+test2:
+	python server.py
+
 chat:
 	python test.py chat
 
+chat2:
+	python unit_test.py chat
+
 unittest:
 	python unit_test.py $(TEST)
 
app_modules/llm_chat_chain.py
CHANGED
@@ -1,9 +1,9 @@
 import os
+from typing import List, Optional
 
-from langchain import
-from langchain.chains import ConversationalRetrievalChain
+from langchain import ConversationChain, PromptTemplate
 from langchain.chains.base import Chain
-from langchain.memory import
+from langchain.memory import ConversationSummaryBufferMemory
 
 from app_modules.llm_inference import LLMInference
 
@@ -12,7 +12,7 @@ def get_llama_2_prompt_template():
     B_INST, E_INST = "[INST]", "[/INST]"
     B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
-    instruction = "Chat History:\n\n{
+    instruction = "Chat History:\n\n{history} \n\nUser: {input}"
     system_prompt = "You are a helpful assistant, you always only answer for the assistant then you stop. Read the chat history to get context"
     # system_prompt = """\
     # You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information. \n\nDo not output any emotional expression. Read the chat history to get context.\
@@ -32,20 +32,20 @@ class ChatChain(LLMInference):
             get_llama_2_prompt_template()
             if os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
             else """You are a chatbot having a conversation with a human.
-{
-Human: {
+{history}
+Human: {input}
 Chatbot:"""
         )
 
         print(f"template: {template}")
 
-        prompt = PromptTemplate(
-            input_variables=["chat_history", "question"], template=template
-        )
+        prompt = PromptTemplate(input_variables=["history", "input"], template=template)
 
-        memory =
+        memory = ConversationSummaryBufferMemory(
+            llm=self.llm_loader.llm, max_token_limit=1024, return_messages=True
+        )
 
-        llm_chain =
+        llm_chain = ConversationChain(
             llm=self.llm_loader.llm,
             prompt=prompt,
             verbose=True,
@@ -53,3 +53,6 @@ Chatbot:"""
         )
 
         return llm_chain
+
+    def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
+        return chain({"input": inputs["question"]}, callbacks)
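For reference, the chain this file now builds can be exercised on its own. Below is a minimal sketch assuming a legacy (pre-0.1) LangChain install: FakeListLLM is a stand-in for self.llm_loader.llm so the example is self-contained (the commit also passes return_messages=True, omitted here to keep the prompt purely string-based), and running it end to end additionally needs the transformers package, which this LangChain version uses for token counting in the summary buffer.

# Hedged sketch: FakeListLLM replaces the real model; the prompt, memory and
# chain wiring mirror the diff above.
from langchain import ConversationChain, PromptTemplate
from langchain.llms.fake import FakeListLLM
from langchain.memory import ConversationSummaryBufferMemory

llm = FakeListLLM(responses=["Deep learning is a family of neural-network methods."])

template = """You are a chatbot having a conversation with a human.
{history}
Human: {input}
Chatbot:"""
prompt = PromptTemplate(input_variables=["history", "input"], template=template)

# Keeps recent turns verbatim and asks the LLM to summarize older turns once the
# buffer exceeds max_token_limit tokens; its default memory_key is "history".
memory = ConversationSummaryBufferMemory(llm=llm, max_token_limit=1024)

chain = ConversationChain(llm=llm, prompt=prompt, memory=memory, verbose=True)
result = chain({"input": "what's deep learning?"})
print(result["response"])  # ConversationChain's default output key is "response"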
app_modules/llm_inference.py
CHANGED
@@ -4,6 +4,7 @@ import time
 import urllib
 from queue import Queue
 from threading import Thread
+from typing import List, Optional
 
 from langchain.chains.base import Chain
 
@@ -29,6 +30,9 @@ class LLMInference(metaclass=abc.ABCMeta):
 
         return self.chain
 
+    def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
+        return chain(inputs, callbacks)
+
     def call_chain(
         self,
         inputs,
@@ -45,9 +49,11 @@ class LLMInference(metaclass=abc.ABCMeta):
 
             chain = self.get_chain()
             result = (
-                self.
+                self._run_chain_with_streaming_handler(
+                    chain, inputs, streaming_handler, testing
+                )
                 if streaming_handler is not None
-                else chain
+                else self.run_chain(chain, inputs)
             )
 
             if "answer" in result:
@@ -67,9 +73,11 @@ class LLMInference(metaclass=abc.ABCMeta):
            self.llm_loader.lock.release()
 
    def _execute_chain(self, chain, inputs, q, sh):
-        q.put(chain
+        q.put(self.run_chain(chain, inputs, callbacks=[sh]))
 
-    def
+    def _run_chain_with_streaming_handler(
+        self, chain, inputs, streaming_handler, testing
+    ):
        que = Queue()
 
        t = Thread(
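The new run_chain hook is now the single place a chain gets invoked: the base implementation forwards the caller's inputs unchanged, while ChatChain overrides it (see llm_chat_chain.py above) to remap the UI-facing "question" key to the "input" key that ConversationChain expects. Both the non-streaming branch of call_chain and the background _execute_chain worker route through the hook, so subclasses only need to override one method. A stripped-down sketch of the dispatch, with no threading or locking and a made-up EchoChain standing in for a real LangChain chain:

# Hedged sketch of the override pattern only; EchoChain is hypothetical.
from typing import List, Optional


class EchoChain:
    def __call__(self, inputs, callbacks=None):
        return {"response": f"echo: {inputs['input']}"}


class LLMInference:
    def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
        # default: pass the caller's inputs straight through to the chain
        return chain(inputs, callbacks)


class ChatChain(LLMInference):
    def run_chain(self, chain, inputs, callbacks: Optional[List] = []):
        # ConversationChain expects "input"; callers still pass "question"
        return chain({"input": inputs["question"]}, callbacks)


print(ChatChain().run_chain(EchoChain(), {"question": "what's deep learning?"}))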
app_modules/llm_loader.py
CHANGED
@@ -188,6 +188,7 @@ class LLMLoader:
             )
         elif self.llm_model_type == "hftgi":
             HFTGI_SERVER_URL = os.environ.get("HFTGI_SERVER_URL")
+            self.max_tokens_limit = 4096
             self.llm = HuggingFaceTextGenInference(
                 inference_server_url=HFTGI_SERVER_URL,
                 max_new_tokens=self.max_tokens_limit / 2,
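Hard-coding max_tokens_limit to 4096 for the hftgi backend means the text-generation-inference server is asked for at most 4096 / 2 = 2048 new tokens per call. A hedged sketch of the resulting construction in isolation (assumes HFTGI_SERVER_URL points at a running TGI server and the text_generation client package is installed):

# Sketch only: mirrors the two touched lines above, outside the LLMLoader class.
import os

from langchain.llms import HuggingFaceTextGenInference

max_tokens_limit = 4096
llm = HuggingFaceTextGenInference(
    inference_server_url=os.environ.get("HFTGI_SERVER_URL"),
    max_new_tokens=max_tokens_limit / 2,  # 2048.0, coerced to int by pydantic
)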
server.py
CHANGED
@@ -78,17 +78,18 @@ def chat_sync(
 ) -> str:
     print("question@chat_sync:", question)
     result = do_chat(question, history, chat_id, None)
-    return result["
+    return result["response"]
 
 
 if __name__ == "__main__":
     # print_llm_response(json.loads(chat("What's deep learning?", [])))
     chat_start = timer()
-    chat_sync("
+    chat_sync("what's deep learning?", chat_id="test_user")
     chat_sync("more on finance", chat_id="test_user")
-
-
-
+    chat_sync("more on Sentiment analysis", chat_id="test_user")
+    chat_sync("Write the game 'snake' in python", chat_id="test_user")
+    chat_sync("给我讲一个年轻人奋斗创业最终取得成功的故事。", chat_id="test_user")
+    chat_sync("给这个故事起一个标题", chat_id="test_user")
     chat_end = timer()
     total_time = chat_end - chat_start
     print(f"Total time used: {total_time:.3f} s")
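The switch to reading result["response"] here (and in unit_test.py below) follows from the chain swap: ConversationChain returns its answer under the "response" output key by default. A quick, self-contained check, again using FakeListLLM as a stand-in for the real model:

# Hedged sketch: only inspects ConversationChain's default output key.
from langchain import ConversationChain
from langchain.llms.fake import FakeListLLM

chain = ConversationChain(llm=FakeListLLM(responses=["hi"]))
print(chain.output_key)   # 'response'
print(chain.output_keys)  # ['response']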
unit_test.py
CHANGED
@@ -170,7 +170,7 @@ def chat():
         end = timer()
         print(f"Completed in {end - start:.3f}s")
 
-        chat_history.append((query, result["
+        chat_history.append((query, result["response"]))
 
     chat_end = timer()
     print(f"Total time used: {chat_end - chat_start:.3f}s")