ysharma HF staff committed on
Commit
a58bd0b
1 Parent(s): 890f63d

update glm stream
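
This commit removes the dormant OpenChatKit (predict_together) path and switches the ChatGLM tab from the blocking predict_glm handler to a new generator, predict_glm_stream, so partial responses stream into the chatbot as they are produced. The streaming pattern the new handler relies on, as a minimal self-contained sketch (it assumes the THUDM/chatglm-6b checkpoint and invents the demo/msg/stream_reply names for illustration; the sampling values are placeholders, not part of this commit):

import gradio as gr
from transformers import AutoModel, AutoTokenizer

# Assumed to mirror app.py's setup: ChatGLM-6B loaded through transformers.
tokenizer_glm = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model_glm = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

def stream_reply(user_input, history=[]):
    # stream_chat yields (partial_response, updated_history) pairs as tokens arrive
    for response, history in model_glm.stream_chat(tokenizer_glm, user_input, history,
                                                   top_p=1.0, temperature=1.0):
        yield history, history  # first output feeds the Chatbot, second the State

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    state = gr.State([])
    msg = gr.Textbox()
    # A generator handler makes the event streaming: every yield repaints the outputs
    msg.submit(stream_reply, [msg, state], [chatbot, state])

demo.queue().launch()  # queuing must be enabled for generator (streaming) handlers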

Files changed (1)
  1. app.py +26 -80
app.py CHANGED
@@ -93,75 +93,6 @@ def predict_chatgpt(inputs, top_p_chatgpt, temperature_chatgpt, openai_api_key,
          yield chat, history, chat_counter_chatgpt # this resembles {chatbot: chat, state: history}


- #Predict function for OPENCHATKIT
- def predict_together(model: str,
-                      inputs: str,
-                      top_p: float,
-                      temperature: float,
-                      top_k: int,
-                      repetition_penalty: float,
-                      watermark: bool,
-                      chatbot,
-                      history,):
-
-     client = Client(os.getenv("API_URL_TGTHR")) #get_client(model)
-     # debug
-     #print(f"^^client is - {client}")
-     user_name, assistant_name = "<human>: ", "<bot>: "
-     preprompt = openchat_preprompt
-     sep = '\n'
-
-     history.append(inputs)
-
-     past = []
-     for data in chatbot:
-         user_data, model_data = data
-
-         if not user_data.startswith(user_name):
-             user_data = user_name + user_data
-         if not model_data.startswith("\n" + assistant_name):
-             model_data = "\n" + assistant_name + model_data
-
-         past.append(user_data + model_data.rstrip() + "\n")
-
-     if not inputs.startswith(user_name):
-         inputs = user_name + inputs
-
-     total_inputs = preprompt + "".join(past) + inputs + "\n" + assistant_name.rstrip()
-     # truncate total_inputs
-     #total_inputs = total_inputs[-1000:]
-
-     partial_words = ""
-
-     for i, response in enumerate(client.generate_stream(
-         total_inputs,
-         top_p=top_p,
-         top_k=top_k,
-         repetition_penalty=repetition_penalty,
-         watermark=watermark,
-         temperature=temperature,
-         max_new_tokens=500,
-         stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
-     )):
-         if response.token.special:
-             continue
-
-         partial_words = partial_words + response.token.text
-         if partial_words.endswith(user_name.rstrip()):
-             partial_words = partial_words.rstrip(user_name.rstrip())
-         if partial_words.endswith(assistant_name.rstrip()):
-             partial_words = partial_words.rstrip(assistant_name.rstrip())
-
-         if i == 0:
-             history.append(" " + partial_words)
-         else:
-             history[-1] = partial_words
-
-         chat = [
-             (history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)
-         ]
-         yield chat, history
-
  # Define function to generate model predictions and update the history
  def predict_glm(input, history=[]):
      response, history = model_glm.chat(tokenizer_glm, input, history)
@@ -177,6 +108,21 @@ def translate_Chinese_English(chinese_text):
      trans_eng_text = tokenizer_chtoen.batch_decode(generated_tokens, skip_special_tokens=True)
      return trans_eng_text[0]

+ # Define function to generate model predictions and update the history
+ def predict_glm_stream(input, history=[]): #, top_p, temperature):
+     response, history = model_glm.chat(tokenizer_glm, input, history)
+     print(f"outside for loop response is ^^- {response}")
+     print(f"outside for loop history is ^^- {history}")
+     top_p, temperature = 1.0, 1.0
+     for response, history in model_glm.stream_chat(tokenizer_glm, input, history, top_p=top_p, temperature=temperature): #max_length=max_length,
+         print(f"In for loop response is ^^- {response}")
+         print(f"In for loop history is ^^- {history}")
+         # translate Chinese to English
+         history = [(query, translate_Chinese_English(response)) for query, response in history]
+         print(f"In for loop translated history is ^^- {history}")
+         yield history, history #[history] + updates
+
+
  """
  def predict(input, max_length, top_p, temperature, history=None):
      if history is None:
@@ -185,7 +131,7 @@ def predict(input, max_length, top_p, temperature, history=None):
                        temperature=temperature):
          updates = []
          for query, response in history:
-             updates.append(gr.update(visible=True, value="用户:" + query))
+             updates.append(gr.update(visible=True, value="user:" + query)) #用户
          updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
          if len(updates) < MAX_BOXES:
              updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
@@ -265,21 +211,21 @@ with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-r
      inputs.submit( predict_chatgpt,
                    [inputs, top_p_chatgpt, temperature_chatgpt, openai_api_key, chat_counter_chatgpt, chatbot_chatgpt, state_chatgpt],
                    [chatbot_chatgpt, state_chatgpt, chat_counter_chatgpt],)
-     #inputs.submit( predict_together,
-     #              [temp_textbox_together, inputs, top_p, temperature, top_k, repetition_penalty, watermark, chatbot_together, state_together, ],
-     #              [chatbot_together, state_together],)
-     inputs.submit( predict_glm,
+     #inputs.submit( predict_glm,
+     #               [inputs, state_glm, ],
+     #               [chatbot_glm, state_glm],)
+     #b1.click( predict_glm,
+     #          [inputs, state_glm, ],
+     #          [chatbot_glm, state_glm],)
+     inputs.submit( predict_glm_stream,
+                   [inputs, state_glm, ],
+                   [chatbot_glm, state_glm],)
+     b1.click( predict_glm_stream,
                    [inputs, state_glm, ],
                    [chatbot_glm, state_glm],)
      b1.click( predict_chatgpt,
               [inputs, top_p_chatgpt, temperature_chatgpt, openai_api_key, chat_counter_chatgpt, chatbot_chatgpt, state_chatgpt],
               [chatbot_chatgpt, state_chatgpt, chat_counter_chatgpt],)
-     #b1.click( predict_together,
-     #         [temp_textbox_together, inputs, top_p, temperature, top_k, repetition_penalty, watermark, chatbot_together, state_together, ],
-     #         [chatbot_together, state_together],)
-     b1.click( predict_glm,
-              [inputs, state_glm, ],
-              [chatbot_glm, state_glm],)

      b2.click(reset_chat, [chatbot_chatgpt, state_chatgpt], [chatbot_chatgpt, state_chatgpt])
      #b2.click(reset_chat, [chatbot_together, state_together], [chatbot_together, state_together])
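
For context, the translate_Chinese_English helper that predict_glm_stream calls on every yielded history is only partially visible above (its last two lines appear as diff context). A rough reconstruction under stated assumptions (the actual checkpoint loaded by app.py is not shown in this diff; a MarianMT zh-to-en model such as Helsinki-NLP/opus-mt-zh-en is assumed here):

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Assumption: a seq2seq zh-to-en checkpoint; app.py's real model may differ.
tokenizer_chtoen = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-zh-en")
model_chtoen = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-zh-en")

def translate_Chinese_English(chinese_text):
    inputs = tokenizer_chtoen(chinese_text, return_tensors="pt")
    generated_tokens = model_chtoen.generate(**inputs, max_new_tokens=512)
    trans_eng_text = tokenizer_chtoen.batch_decode(generated_tokens, skip_special_tokens=True)
    return trans_eng_text[0]

Because predict_glm_stream rebuilds and re-translates the whole history on each streamed token, this helper sits on the hot path of the stream.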