carolanderson committed
Commit 7081223 · Parent(s): f2f3156

adjust decoding controls

Files changed (1)
  1. app.py +72 -97
app.py CHANGED
@@ -1,86 +1,76 @@
  import logging
- import os
 
  from langchain.chains import LLMChain
  from langchain.chat_models import ChatOpenAI
  from langchain.llms import HuggingFaceHub
  from langchain.prompts.chat import (
-     PromptTemplate,
-     ChatPromptTemplate,
-     MessagesPlaceholder,
-     SystemMessagePromptTemplate,
-     HumanMessagePromptTemplate,
+     PromptTemplate,
+     ChatPromptTemplate,
+     MessagesPlaceholder,
+     SystemMessagePromptTemplate,
+     HumanMessagePromptTemplate,
  )
  from langchain.memory import ConversationBufferWindowMemory
  from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
- from langchain.schema import AIMessage, HumanMessage
  from openai.error import AuthenticationError
  import streamlit as st
 
- from langchain import verbose
- verbose = True
-
 
  def setup_memory():
      msgs = StreamlitChatMessageHistory(key="basic_chat_app")
-     memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
+     memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
                                              chat_memory=msgs,
                                              return_messages=True)
      logging.info("setting up new chat memory")
      return memory
-
 
- def use_existing_chain(model, provider, temp, max_tokens):
+
+ def use_existing_chain(model, provider, model_kwargs):
      # TODO: consider whether prompt needs to be checked here
      if "mistral" in model:
          return False
      if "current_chain" in st.session_state:
          current_chain = st.session_state.current_chain
          if (current_chain.model == model) \
-             and (current_chain.provider == provider) \
-             and (current_chain.temp == temp) \
-             and (current_chain.max_tokens == max_tokens):
+             and (current_chain.provider == provider) \
+             and (current_chain.model_kwargs == model_kwargs):
              return True
      return False
 
 
  class CurrentChain():
-     def __init__(self, model, provider, prompt, memory, temp, max_tokens=64):
+     def __init__(self, model, provider, prompt, memory, model_kwargs):
          self.model = model
          self.provider = provider
-         self.temp = temp
-         self.max_tokens=max_tokens
-
+         self.model_kwargs = model_kwargs
+
          logging.info(f"setting up new chain with params {model_name}, {provider}, {temp}")
-         if provider == "OpenAI":
-             llm = ChatOpenAI(model_name=model, temperature=temp)
+         if provider == "OpenAI":
+             llm = ChatOpenAI(model_name=model,
+                              temperature=model_kwargs['temperature']
+                              )
          elif provider == "HuggingFace":
-             # TODO: expose the controls below as widgets and clean up init
              llm = HuggingFaceHub(repo_id=model,
-                                  model_kwargs={"temperature": temp,
-                                                "max_new_tokens": 256,
-                                                "top_p" : 0.95,
-                                                "repetition_penalty" : 1.0,
-                                                "do_sample" : True,
-                                                "seed" : 42})
+                                  model_kwargs=model_kwargs
+                                  )
 
          self.conversation = LLMChain(
-             llm=llm,
-             prompt=prompt,
-             verbose=True,
-             memory=memory
-         )
+             llm=llm,
+             prompt=prompt,
+             verbose=True,
+             memory=memory
+         )
 
 
  def format_mistral_prompt(message, history):
-     prompt = "<s>"
-     for user_prompt, bot_response in history:
-         prompt += f"[INST] {user_prompt} [/INST]"
-         prompt += f" {bot_response}</s> "
-     prompt += f"[INST] {message} [/INST]"
-     return prompt
-
-
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+
  if __name__ == "__main__":
      logging.basicConfig(level=logging.INFO)
 
@@ -88,56 +78,53 @@ if __name__ == "__main__":
      st.write("On small screens, click the `>` at top left to choose options")
      with st.expander("How conversation history works"):
          st.write("To keep input lengths down and costs reasonable,"
-                  " this bot only 'remembers' the past three turns of conversation.")
-         st.write("To clear all memory and start fresh, click 'Clear history'" )
+                  " only the past three turns of conversation "
+                  " are used for OpenAI models. Otherwise the entire chat history is used.")
+         st.write("To clear all memory and start fresh, click 'Clear history'")
      st.sidebar.title("Choose options")
 
      #### USER INPUT ######
      model_name = st.sidebar.selectbox(
-         label = "Choose a model",
-         options = ["gpt-3.5-turbo (OpenAI)",
-                    # "bigscience/bloom (HuggingFace)", # runs
-                    # "microsoft/DialoGPT-medium (HuggingFace)", # throws error
-                    # "google/flan-t5-xxl (HuggingFace)", # runs
-                    "mistralai/Mistral-7B-Instruct-v0.1 (HuggingFace)"
-                    ],
+         label="Choose a model",
+         options=["gpt-3.5-turbo (OpenAI)",
+                  # "bigscience/bloom (HuggingFace)", # runs
+                  # "google/flan-t5-xxl (HuggingFace)", # runs
+                  "mistralai/Mistral-7B-Instruct-v0.1 (HuggingFace)"
+                  ],
          help="Which LLM to use",
          )
 
-     st.sidebar.write("Set the decoding temperature. Higher temperatures give "
-                      "more unpredictable outputs.")
-
      temp = st.sidebar.slider(
          label="Temperature",
          min_value=float(0),
-         max_value=1.0,
+         max_value=2.0,
          step=0.1,
          value=0.4,
-         help="Set the decoding temperature"
-         )
-
-     max_tokens = st.sidebar.slider(
-         label="Max tokens",
-         min_value=32,
-         max_value=2048,
-         step=1,
-         value=1028,
-         help="Set the maximum number of tokens to generate"
-         ) # TODO: edit this, not currently using
+         help="Set the decoding temperature. "
+              "Higher temps give more unpredictable outputs."
+         )
      ##########################
-
-     model = model_name.split("(")[0].rstrip() # remove name of model provider
+
+     model = model_name.split("(")[0].rstrip() # remove name of model provider
      provider = model_name.split("(")[-1].split(")")[0]
 
+     model_kwargs = {"temperature": temp,
+                     "max_new_tokens": 256,
+                     "repetition_penalty": 1.0,
+                     "top_p": 0.95,
+                     "do_sample": True,
+                     "seed": 42}
+     # TODO: maybe expose more of these to the user
+
      if "session_memory" not in st.session_state:
          st.session_state.session_memory = setup_memory() # for openai
 
      if "history" not in st.session_state:
-         st.session_state.history = [] # for mistral
+         st.session_state.history = [] # for mistral
 
      if "mistral" in model:
-         prompt = PromptTemplate(input_variables=["input"],
-                                 template="{input}")
+         prompt = PromptTemplate(input_variables=["input"],
+                                 template="{input}")
      else:
          prompt = ChatPromptTemplate(
              messages=[
@@ -147,33 +134,32 @@ if __name__ == "__main__":
              MessagesPlaceholder(variable_name="chat_history"),
              HumanMessagePromptTemplate.from_template("{input}")
              ],
-             verbose=True
-         )
-
-     if use_existing_chain(model, provider, temp, max_tokens):
+             verbose=True
+         )
+
+     if use_existing_chain(model, provider, model_kwargs):
          chain = st.session_state.current_chain
      else:
-         chain = CurrentChain(model,
-                              provider,
-                              prompt,
-                              st.session_state.session_memory,
-                              temp,
-                              max_tokens)
-         st.session_state.current_chain = chain
+         chain = CurrentChain(model,
+                              provider,
+                              prompt,
+                              st.session_state.session_memory,
+                              model_kwargs)
+         st.session_state.current_chain = chain
 
      conversation = chain.conversation
-
+
      if st.button("Clear history"):
-         conversation.memory.clear() # for openai
+         conversation.memory.clear() # for openai
          st.session_state.history = [] # for mistral
          logging.info("history cleared")
-
+
      for user_msg, asst_msg in st.session_state.history:
          with st.chat_message("user"):
              st.write(user_msg)
          with st.chat_message("assistant"):
              st.write(asst_msg)
-
+
      text = st.chat_input()
      if text:
          with st.chat_message("user"):
@@ -192,14 +178,3 @@ if __name__ == "__main__":
          st.write(result)
      except (AuthenticationError, ValueError):
          st.warning("Supply a valid API key", icon="⚠️")
-
-
-
-
-
-
-
-
-
-
-
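
The net effect of the commit is that all decoding controls now live in a single model_kwargs dict, which both configures the LLM and keys the chain cache in use_existing_chain. A minimal sketch of that caching rule, with a hypothetical stand-in Chain class (the real check is in use_existing_chain above):

# Sketch of the cache-reuse rule introduced in this commit: a chain is
# rebuilt whenever the model, provider, or any decoding control changes.
# `Chain` is a hypothetical stand-in for CurrentChain, for illustration only.

class Chain:
    def __init__(self, model, provider, model_kwargs):
        self.model = model
        self.provider = provider
        self.model_kwargs = model_kwargs

def can_reuse(chain, model, provider, model_kwargs):
    # Mirrors use_existing_chain(): dicts compare by value in Python, so
    # moving the temperature slider yields a new dict and forces a rebuild.
    return (chain.model == model
            and chain.provider == provider
            and chain.model_kwargs == model_kwargs)

cached = Chain("gpt-3.5-turbo", "OpenAI", {"temperature": 0.4, "top_p": 0.95})
print(can_reuse(cached, "gpt-3.5-turbo", "OpenAI",
                {"temperature": 0.4, "top_p": 0.95}))  # True: settings unchanged
print(can_reuse(cached, "gpt-3.5-turbo", "OpenAI",
                {"temperature": 0.7, "top_p": 0.95}))  # False: temperature moved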
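For reference, the format_mistral_prompt helper (re-added unchanged in substance by this commit) assembles the [INST] instruction format that Mistral-7B-Instruct expects. A quick sketch of its output; the example history and message are made up for illustration:

# Reproduces format_mistral_prompt from app.py to show the string it builds.

def format_mistral_prompt(message, history):
    prompt = "<s>"
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST]"
        prompt += f" {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("Hi there", "Hello! How can I help?")]
print(format_mistral_prompt("Tell me a joke", history))
# <s>[INST] Hi there [/INST] Hello! How can I help?</s> [INST] Tell me a joke [/INST]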