JohnSmith9982
commited on
Commit
•
0405ac0
1
Parent(s):
2c3bb3b
Update utils.py
Browse files
utils.py
CHANGED
@@ -49,7 +49,7 @@ def postprocess(
|
|
49 |
return y
|
50 |
|
51 |
def count_token(input_str):
|
52 |
-
encoding = tiktoken.
|
53 |
length = len(encoding.encode(input_str))
|
54 |
return length
|
55 |
|
@@ -144,14 +144,20 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
|
|
144 |
try:
|
145 |
response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
|
146 |
except requests.exceptions.ConnectTimeout:
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
yield get_return_value()
|
149 |
return
|
150 |
|
151 |
chatbot.append((parse_text(inputs), ""))
|
152 |
yield get_return_value()
|
153 |
|
154 |
-
for chunk in
|
155 |
if counter == 0:
|
156 |
counter += 1
|
157 |
continue
|
@@ -160,7 +166,12 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
|
|
160 |
if chunk:
|
161 |
chunk = chunk.decode()
|
162 |
chunklength = len(chunk)
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
164 |
# decode each line as response data is in bytes
|
165 |
if chunklength > 6 and "delta" in chunk['choices'][0]:
|
166 |
finish_reason = chunk['choices'][0]['finish_reason']
|
@@ -169,7 +180,12 @@ def stream_predict(openai_api_key, system_prompt, history, inputs, chatbot, prev
|
|
169 |
print("生成完毕")
|
170 |
yield get_return_value()
|
171 |
break
|
172 |
-
|
|
|
|
|
|
|
|
|
|
|
173 |
if token_counter == 0:
|
174 |
history.append(construct_assistant(" " + partial_words))
|
175 |
else:
|
|
|
49 |
return y
|
50 |
|
51 |
def count_token(input_str):
|
52 |
+
encoding = tiktoken.get_encoding("cl100k_base")
|
53 |
length = len(encoding.encode(input_str))
|
54 |
return length
|
55 |
|
|
|
144 |
try:
|
145 |
response = get_response(openai_api_key, system_prompt, history, temperature, top_p, True)
|
146 |
except requests.exceptions.ConnectTimeout:
|
147 |
+
history.pop()
|
148 |
+
status_text = standard_error_msg + "连接超时,无法获取对话。" + error_retrieve_prompt
|
149 |
+
yield get_return_value()
|
150 |
+
return
|
151 |
+
except requests.exceptions.ReadTimeout:
|
152 |
+
history.pop()
|
153 |
+
status_text = standard_error_msg + "读取超时,无法获取对话。" + error_retrieve_prompt
|
154 |
yield get_return_value()
|
155 |
return
|
156 |
|
157 |
chatbot.append((parse_text(inputs), ""))
|
158 |
yield get_return_value()
|
159 |
|
160 |
+
for chunk in response.iter_lines():
|
161 |
if counter == 0:
|
162 |
counter += 1
|
163 |
continue
|
|
|
166 |
if chunk:
|
167 |
chunk = chunk.decode()
|
168 |
chunklength = len(chunk)
|
169 |
+
try:
|
170 |
+
chunk = json.loads(chunk[6:])
|
171 |
+
except json.JSONDecodeError:
|
172 |
+
status_text = f"JSON解析错误。请重置对话。收到的内容: {chunk}"
|
173 |
+
yield get_return_value()
|
174 |
+
break
|
175 |
# decode each line as response data is in bytes
|
176 |
if chunklength > 6 and "delta" in chunk['choices'][0]:
|
177 |
finish_reason = chunk['choices'][0]['finish_reason']
|
|
|
180 |
print("生成完毕")
|
181 |
yield get_return_value()
|
182 |
break
|
183 |
+
try:
|
184 |
+
partial_words = partial_words + chunk['choices'][0]["delta"]["content"]
|
185 |
+
except KeyError:
|
186 |
+
status_text = standard_error_msg + "API回复中找不到内容。很可能是Token计数达到上限了。请重置对话。当前Token计数: " + str(sum(previous_token_count)+token_counter+user_token_count)
|
187 |
+
yield get_return_value()
|
188 |
+
break
|
189 |
if token_counter == 0:
|
190 |
history.append(construct_assistant(" " + partial_words))
|
191 |
else:
|