matthoffner commited on
Commit
ab83be1
1 Parent(s): 8a15493

Try to fix repeating on User

Browse files
Files changed (1) hide show
  1. demo.py +16 -9
demo.py CHANGED
@@ -98,18 +98,18 @@ def chat():
98
  if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
99
  yield chat_history
100
  return
101
-
102
  if message == RETRY_COMMAND and chat_history:
103
  prev_turn = chat_history.pop(-1)
104
  user_message, _ = prev_turn
105
  message = user_message
106
-
107
  prompt = format_chat_prompt(message, chat_history, instructions)
108
  chat_history = chat_history + [[message, ""]]
109
  stream = llm(
110
  prompt,
111
  max_new_tokens=1024,
112
- stop=[STOP_STR, "<|endoftext|>"],
113
  temperature=temperature,
114
  top_p=top_p,
115
  stream=True
@@ -117,20 +117,27 @@ def chat():
117
  acc_text = ""
118
  for idx, response in enumerate(stream):
119
  text_token = response
120
-
121
- if text_token in STOP_SUSPECT_LIST:
122
- acc_text += text_token
 
 
 
 
123
  continue
124
-
125
  if idx == 0 and text_token.startswith(" "):
126
  text_token = text_token[1:]
127
-
128
  acc_text += text_token
 
 
 
129
  last_turn = list(chat_history.pop(-1))
130
  last_turn[-1] += acc_text
131
  chat_history = chat_history + [last_turn]
132
  yield chat_history
133
- acc_text = ""
134
 
135
  def delete_last_turn(chat_history):
136
  if chat_history:
 
98
  if not message or (message == RETRY_COMMAND and len(chat_history) == 0):
99
  yield chat_history
100
  return
101
+
102
  if message == RETRY_COMMAND and chat_history:
103
  prev_turn = chat_history.pop(-1)
104
  user_message, _ = prev_turn
105
  message = user_message
106
+
107
  prompt = format_chat_prompt(message, chat_history, instructions)
108
  chat_history = chat_history + [[message, ""]]
109
  stream = llm(
110
  prompt,
111
  max_new_tokens=1024,
112
+ stop=[STOP_STR, ""],
113
  temperature=temperature,
114
  top_p=top_p,
115
  stream=True
 
117
  acc_text = ""
118
  for idx, response in enumerate(stream):
119
  text_token = response
120
+
121
+ if acc_text.endswith(STOP_STR):
122
+ last_turn = list(chat_history.pop(-1))
123
+ last_turn[-1] += acc_text[:-len(STOP_STR)]
124
+ chat_history = chat_history + [last_turn]
125
+ yield chat_history
126
+ acc_text = text_token
127
  continue
128
+
129
  if idx == 0 and text_token.startswith(" "):
130
  text_token = text_token[1:]
131
+
132
  acc_text += text_token
133
+
134
+ # If there's any remaining text after the loop
135
+ if acc_text:
136
  last_turn = list(chat_history.pop(-1))
137
  last_turn[-1] += acc_text
138
  chat_history = chat_history + [last_turn]
139
  yield chat_history
140
+
141
 
142
  def delete_last_turn(chat_history):
143
  if chat_history: