dh-mc committed
Commit 8282222
1 Parent(s): b2d49c7

added USE_LLAMA_2_PROMPT_TEMPLATE

Files changed (2):
  1. .env.example +1 -0
  2. app_modules/llm_chat_chain.py +22 -1
.env.example CHANGED

@@ -19,6 +19,7 @@ HF_PIPELINE_DEVICE_TYPE=
 # LOAD_QUANTIZED_MODEL=4bit
 # LOAD_QUANTIZED_MODEL=8bit
 
+USE_LLAMA_2_PROMPT_TEMPLATE=true
 DISABLE_MODEL_PRELOADING=true
 CHAT_HISTORY_ENABLED=true
 SHOW_PARAM_SETTINGS=false
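The flag is consumed as a plain string comparison in `llm_chat_chain.py` (see the diff below), so only the exact value `true` enables the Llama-2 template. A minimal sketch of the check, mirroring the code in this commit:

```python
import os

# Only the literal string "true" enables the Llama-2 template;
# "True", "1", or an unset variable all fall back to the default prompt.
use_llama_2 = os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
```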
app_modules/llm_chat_chain.py CHANGED

@@ -1,3 +1,5 @@
+import os
+
 from langchain import LLMChain, PromptTemplate
 from langchain.chains import ConversationalRetrievalChain
 from langchain.chains.base import Chain
@@ -6,19 +8,38 @@ from langchain.memory import ConversationBufferMemory
 from app_modules.llm_inference import LLMInference
 
 
+def get_llama_2_prompt_template():
+    B_INST, E_INST = "[INST]", "[/INST]"
+    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+    instruction = "Chat History:\n\n{chat_history} \n\nUser: {question}"
+    system_prompt = "You are a helpful assistant, you always only answer for the assistant then you stop. read the chat history to get context"
+
+    SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
+    prompt_template = B_INST + SYSTEM_PROMPT + instruction + E_INST
+    return prompt_template
+
+
 class ChatChain(LLMInference):
     def __init__(self, llm_loader):
         super().__init__(llm_loader)
 
     def create_chain(self) -> Chain:
-        template = """You are a chatbot having a conversation with a human.
+        template = (
+            get_llama_2_prompt_template()
+            if os.environ.get("USE_LLAMA_2_PROMPT_TEMPLATE") == "true"
+            else """You are a chatbot having a conversation with a human.
 {chat_history}
 Human: {question}
 Chatbot:"""
+        )
+
+        print(f"template: {template}")
 
         prompt = PromptTemplate(
             input_variables=["chat_history", "question"], template=template
         )
+
         memory = ConversationBufferMemory(memory_key="chat_history")
 
         llm_chain = LLMChain(
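For reference, a short sketch of what the Llama-2 branch renders to at runtime. The chat history and question below are made-up sample values; the import path follows the file changed in this commit:

```python
from app_modules.llm_chat_chain import get_llama_2_prompt_template

# The function returns a plain string with {chat_history} and {question}
# placeholders -- the same names PromptTemplate declares as input_variables.
template = get_llama_2_prompt_template()

print(template.format(
    chat_history="Human: Hi there!\nAI: Hello!",
    question="How are you?",
))
# [INST]<<SYS>>
# You are a helpful assistant, you always only answer for the assistant
# then you stop. read the chat history to get context
# <</SYS>>
#
# Chat History:
#
# Human: Hi there!
# AI: Hello!
#
# User: How are you?[/INST]
```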