kcarnold committed
Commit 898f051 · 1 Parent(s): fe6cd36

Make it actually a predictive-text interface

Files changed (1)
  1. app.py +33 -17
app.py CHANGED
@@ -213,6 +213,7 @@ def type_assistant_response():
         st.session_state['messages'] = []
     messages = st.session_state.messages
 
+    # All but the last message happens normally.
     for message in st.session_state.messages[:-1]:
         with st.chat_message(message["role"]):
             st.markdown(message["content"])
@@ -228,37 +229,52 @@ def type_assistant_response():
     if len(messages) == 0:
         st.stop()
 
-    response = requests.post(
-        f"{API_SERVER}/continue_messages",
-        json={
-            "messages": messages,
-            "n_branch_tokens": 5,
-            "n_future_tokens": 2
-        }
-    )
-    if response.status_code != 200:
-        st.error("Error fetching response")
-        st.write(response.text)
-        st.stop()
-    response.raise_for_status()
-    response = response.json()
-
     # Display assistant response in chat message container
     with st.chat_message("assistant"):
-        st.write(messages[-1]['content'])
+        #st.write(messages[-1]['content'])
+        msg_in_progress = st.text_area("Assistant response", value=messages[-1]['content'], placeholder="Clicking the buttons below will update this field. You can also edit it directly; press Ctrl+Enter to apply changes.", height=300)
+        # strip spaces (but not newlines) to avoid a tokenization issue
+        msg_in_progress = msg_in_progress.rstrip(' ')
+
     def append_token(word):
         messages[-1]['content'] = (
-            messages[-1]['content'] + word
+            msg_in_progress + word
         )
 
     allow_multi_word = st.checkbox("Allow multi-word predictions", value=False)
 
+    response = requests.post(
+        f"{API_SERVER}/continue_messages",
+        json={
+            "messages": messages[:-1] + [
+                {"role": "assistant", "content": msg_in_progress},
+            ],
+            "n_branch_tokens": 5,
+            "n_future_tokens": 2
+        }
+    )
+    if response.status_code != 200:
+        st.error("Error fetching response")
+        st.write(response.text)
+        st.stop()
+    response.raise_for_status()
+    response = response.json()
+
     continuations = response['continuations']
     for i, (col, continuation) in enumerate(zip(st.columns(len(continuations)), continuations)):
         token = continuation['doc_text']
         with col:
             if not allow_multi_word and ' ' in token[1:]:
                 token = token[0] + token[1:].split(' ', 1)[0]
+
+            # if not allow_multi_word:
+            #     import re
+            #     split_result = re.split(r'(\s+)', token, maxsplit=1)
+            #     assert len(split_result) == 3
+            #     before_ws, token, after_ws = split_result
+            #     print(repr(split_result))
+            #     if before_ws != '':
+            #         token = before_ws
             token_display = show_token(token)
             st.button(token_display, on_click=append_token, args=(token,), key=i, use_container_width=True)
 
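
For reference, the new button loop only relies on the `/continue_messages` reply exposing a `continuations` list whose items carry a `doc_text` string (five of them, per `n_branch_tokens: 5`). The sketch below is an assumption about that payload shape, not the server's actual schema, and mirrors what `append_token` plus the single-word truncation do with each prediction:

```python
# Hypothetical /continue_messages reply -- only the fields app.py actually
# reads ("continuations" and "doc_text") are taken from the diff above.
example_response = {
    "continuations": [
        {"doc_text": " you write"},
        {"doc_text": " the next"},
        {"doc_text": " a draft"},
        {"doc_text": " it"},
        {"doc_text": " writing"},
    ]
}

msg_in_progress = "Predictive text helps"   # illustrative draft text
allow_multi_word = False                    # matches the checkbox default

for i, continuation in enumerate(example_response["continuations"]):
    token = continuation["doc_text"]
    # Same truncation as in the diff: when multi-word predictions are off,
    # keep the leading character (usually a space) plus the first word.
    if not allow_multi_word and ' ' in token[1:]:
        token = token[0] + token[1:].split(' ', 1)[0]
    # Clicking button i would run append_token(token), extending the draft:
    print(i, repr(msg_in_progress + token))
```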