update glm stream
app.py
CHANGED
@@ -93,75 +93,6 @@ def predict_chatgpt(inputs, top_p_chatgpt, temperature_chatgpt, openai_api_key,
         yield chat, history, chat_counter_chatgpt  # this resembles {chatbot: chat, state: history}


-#Predict function for OPENCHATKIT
-def predict_together(model: str,
-                     inputs: str,
-                     top_p: float,
-                     temperature: float,
-                     top_k: int,
-                     repetition_penalty: float,
-                     watermark: bool,
-                     chatbot,
-                     history,):
-
-    client = Client(os.getenv("API_URL_TGTHR"))  #get_client(model)
-    # debug
-    #print(f"^^client is - {client}")
-    user_name, assistant_name = "<human>: ", "<bot>: "
-    preprompt = openchat_preprompt
-    sep = '\n'
-
-    history.append(inputs)
-
-    past = []
-    for data in chatbot:
-        user_data, model_data = data
-
-        if not user_data.startswith(user_name):
-            user_data = user_name + user_data
-        if not model_data.startswith("\n" + assistant_name):
-            model_data = "\n" + assistant_name + model_data
-
-        past.append(user_data + model_data.rstrip() + "\n")
-
-    if not inputs.startswith(user_name):
-        inputs = user_name + inputs
-
-    total_inputs = preprompt + "".join(past) + inputs + "\n" + assistant_name.rstrip()
-    # truncate total_inputs
-    #total_inputs = total_inputs[-1000:]
-
-    partial_words = ""
-
-    for i, response in enumerate(client.generate_stream(
-        total_inputs,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        watermark=watermark,
-        temperature=temperature,
-        max_new_tokens=500,
-        stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
-    )):
-        if response.token.special:
-            continue
-
-        partial_words = partial_words + response.token.text
-        if partial_words.endswith(user_name.rstrip()):
-            partial_words = partial_words.rstrip(user_name.rstrip())
-        if partial_words.endswith(assistant_name.rstrip()):
-            partial_words = partial_words.rstrip(assistant_name.rstrip())
-
-        if i == 0:
-            history.append(" " + partial_words)
-        else:
-            history[-1] = partial_words
-
-        chat = [
-            (history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)
-        ]
-        yield chat, history
-
 # Define function to generate model predictions and update the history
 def predict_glm(input, history=[]):
     response, history = model_glm.chat(tokenizer_glm, input, history)
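This first hunk deletes the OpenChatKit/Together path wholesale. For readers tracing the old behavior: the removed function streamed tokens from a hosted endpoint, and its `response.token.special` / `response.token.text` fields match the `text_generation` client. A minimal hedged sketch of that usage (the import and the prompt are assumptions; the env var comes from the removed code):

```python
import os
# Assumption: Client is text_generation.Client, whose streaming
# responses carry .token.special and .token.text as used above.
from text_generation import Client

client = Client(os.getenv("API_URL_TGTHR"))
prompt = "<human>: Hello\n<bot>:"  # hypothetical prompt in the removed format
for response in client.generate_stream(prompt, max_new_tokens=50,
                                        stop_sequences=["<human>:", "<bot>:"]):
    if response.token.special:
        continue  # skip control tokens, as the removed loop did
    print(response.token.text, end="", flush=True)
```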
@@ -177,6 +108,21 @@ def translate_Chinese_English(chinese_text):
     trans_eng_text = tokenizer_chtoen.batch_decode(generated_tokens, skip_special_tokens=True)
     return trans_eng_text[0]

+# Define function to generate model predictions and update the history
+def predict_glm_stream(input, history=[]):  #, top_p, temperature):
+    response, history = model_glm.chat(tokenizer_glm, input, history)
+    print(f"outside for loop resonse is ^^- {response}")
+    print(f"outside for loop history is ^^- {history}")
+    top_p, temperature = 1.0, 1.0
+    for response, history in model.stream_chat(tokenizer_glm, input, history, top_p=top_p, temperature=temperature):  #max_length=max_length,
+        print(f"In for loop resonse is ^^- {response}")
+        print(f"In for loop history is ^^- {history}")
+        # translate Chinese to English
+        history = [(query, translate_Chinese_English(response)) for query, response in history]
+        print(f"In for loop translated history is ^^- {history}")
+        yield history, history  #[history] + updates
+
+
 """
 def predict(input, max_length, top_p, temperature, history=None):
     if history is None:
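Two things stand out in the added `predict_glm_stream`. The loop calls `model.stream_chat(...)` while every other call site in app.py uses `model_glm`; unless a separate `model` is defined elsewhere, that raises a `NameError` on the first request. And the blocking `model_glm.chat(...)` before the loop generates a complete reply (mutating `history`) before streaming even begins, duplicating the work. A hedged corrected sketch (not the committed code), assuming `model_glm`/`tokenizer_glm` are the ChatGLM-6B model and tokenizer loaded earlier and that `stream_chat` yields `(response, history)` pairs as in ChatGLM-6B's published API:

```python
def predict_glm_stream(input, history=[], top_p=1.0, temperature=1.0):
    # stream partial replies; ChatGLM-6B's stream_chat yields the growing
    # response plus the updated history on every step
    for response, history in model_glm.stream_chat(
            tokenizer_glm, input, history, top_p=top_p, temperature=temperature):
        # translate only the displayed copy, so the model keeps its own
        # Chinese-language history between turns
        display = [(q, translate_Chinese_English(r)) for q, r in history]
        yield display, history
```

Yielding the untranslated `history` as the Gradio state also avoids feeding machine-translated English back into the next `stream_chat` call, which the committed version does (it also fixes the "resonse" typo by dropping the debug prints).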
@@ -185,7 +131,7 @@ def predict(input, max_length, top_p, temperature, history=None):
                      temperature=temperature):
         updates = []
         for query, response in history:
-            updates.append(gr.update(visible=True, value="用户:" + query))
+            updates.append(gr.update(visible=True, value="user:" + query))  #用户
             updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
         if len(updates) < MAX_BOXES:
             updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
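Here the visible label 用户 (Chinese for "user") becomes English, with the original kept as a trailing comment; the edit lands inside the triple-quoted, disabled legacy `predict`, so it is cosmetic. The surrounding lines show the fixed-pool pattern that predates the `gr.Chatbot` UI: a sketch of that pattern under the same assumptions (`MAX_BOXES` and the Textbox pool come from the legacy code; `render_history` is a hypothetical name):

```python
import gradio as gr

MAX_BOXES = 20  # size of the pre-built Textbox pool (from the legacy code)

def render_history(history):
    # fill two boxes per (query, response) turn, then hide the leftovers so
    # the return value always matches the MAX_BOXES output components
    updates = []
    for query, response in history:
        updates.append(gr.update(visible=True, value="user:" + query))
        updates.append(gr.update(visible=True, value="ChatGLM-6B:" + response))
    updates += [gr.update(visible=False)] * (MAX_BOXES - len(updates))
    return updates
```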
@@ -265,21 +211,21 @@ with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-r
     inputs.submit( predict_chatgpt,
                 [inputs, top_p_chatgpt, temperature_chatgpt, openai_api_key, chat_counter_chatgpt, chatbot_chatgpt, state_chatgpt],
                 [chatbot_chatgpt, state_chatgpt, chat_counter_chatgpt],)
-
-
-
-    inputs.submit( predict_glm,
+    #inputs.submit( predict_glm,
+    #            [inputs, state_glm, ],
+    #            [chatbot_glm, state_glm],)
+    #b1.click( predict_glm,
+    #            [inputs, state_glm, ],
+    #            [chatbot_glm, state_glm],)
+    inputs.submit( predict_glm_stream,
+                [inputs, state_glm, ],
+                [chatbot_glm, state_glm],)
+    b1.click( predict_glm_stream,
                 [inputs, state_glm, ],
                 [chatbot_glm, state_glm],)
     b1.click( predict_chatgpt,
                 [inputs, top_p_chatgpt, temperature_chatgpt, openai_api_key, chat_counter_chatgpt, chatbot_chatgpt, state_chatgpt],
                 [chatbot_chatgpt, state_chatgpt, chat_counter_chatgpt],)
-    #b1.click( predict_together,
-    #         [temp_textbox_together, inputs, top_p, temperature, top_k, repetition_penalty, watermark, chatbot_together, state_together, ],
-    #         [chatbot_together, state_together],)
-    b1.click( predict_glm,
-                [inputs, state_glm, ],
-                [chatbot_glm, state_glm],)

     b2.click(reset_chat, [chatbot_chatgpt, state_chatgpt], [chatbot_chatgpt, state_chatgpt])
     #b2.click(reset_chat, [chatbot_together, state_together], [chatbot_together, state_together])
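The rewiring works because Gradio executes generator handlers incrementally: every `yield` from `predict_glm_stream` pushes a fresh `(chatbot, state)` pair to `[chatbot_glm, state_glm]`, which is what makes the reply appear token by token. A minimal self-contained illustration (hypothetical demo, not part of app.py; streaming output needs the queue enabled):

```python
import gradio as gr

def echo_stream(message, history):
    # reveal the reply one character at a time; each yield updates the UI
    partial = ""
    for ch in message:
        partial += ch
        yield history + [(message, partial)], history + [(message, partial)]

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    state = gr.State([])
    box = gr.Textbox(label="Say something")
    box.submit(echo_stream, [box, state], [chatbot, state])

demo.queue().launch()
```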