JunchuanYu committed on
Commit
0f78229
1 Parent(s): 393f574

Delete chat_func.py

Files changed (1)
  1. chat_func.py +0 -296
chat_func.py DELETED
@@ -1,296 +0,0 @@
- # -*- coding:utf-8 -*-
- from __future__ import annotations
- from typing import TYPE_CHECKING, List
-
- import logging
- import json
- import os
- import requests
-
- from tqdm import tqdm
-
- from utils import *
-
-
- if TYPE_CHECKING:
-     from typing import TypedDict
-
-     class DataframeData(TypedDict):
-         headers: List[str]
-         data: List[List[str | int | bool]]
-
-
- initial_prompt = "You are a helpful assistant."
- API_URL = "https://api.openai.com/v1/chat/completions"
-
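- # Builds an (optionally streaming) request to the OpenAI chat completions
- # endpoint, routing through an HTTP(S) proxy from the environment if one is
- # set, and returns the raw requests.Response.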
- def get_response(
-     openai_api_key, system_prompt, history, stream, selected_model
- ):
-     headers = {
-         "Content-Type": "application/json",
-         "Authorization": f"Bearer {openai_api_key}",
-     }
-
-     history = [construct_system(system_prompt), *history]
-
-     payload = {
-         "model": selected_model,
-         "messages": history,  # [{"role": "user", "content": f"{inputs}"}]
-         "temperature": 1.0,
-         "top_p": 1.0,
-         "n": 1,
-         "stream": stream,
-         "presence_penalty": 0,
-         "frequency_penalty": 0,
-     }
-     if stream:
-         timeout = timeout_streaming
-     else:
-         timeout = timeout_all
-
-     # Get any proxy settings from the environment variables
-     http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
-     https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
-
-     # If proxy settings exist, use them
-     proxies = {}
-     if http_proxy:
-         logging.info(f"Using HTTP proxy: {http_proxy}")
-         proxies["http"] = http_proxy
-     if https_proxy:
-         logging.info(f"Using HTTPS proxy: {https_proxy}")
-         proxies["https"] = https_proxy
-
-     # Send the request through the proxy if one is set, otherwise with the default settings
-     if proxies:
-         response = requests.post(
-             API_URL,
-             headers=headers,
-             json=payload,
-             stream=True,
-             timeout=timeout,
-             proxies=proxies,
-         )
-     else:
-         response = requests.post(
-             API_URL,
-             headers=headers,
-             json=payload,
-             stream=True,
-             timeout=timeout,
-         )
-     return response
-
-
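- # Generator for streaming mode: consumes the response as server-sent events
- # and yields an updated (chatbot, history, status_text, all_token_counts)
- # tuple after every delta so the UI can refresh incrementally.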
- def stream_predict(
-     openai_api_key,
-     system_prompt,
-     history,
-     inputs,
-     chatbot,
-     all_token_counts,
-     selected_model,
-     fake_input=None,
-     display_append=""
- ):
-     def get_return_value():
-         return chatbot, history, status_text, all_token_counts
-     # logging.info("Streaming answer mode")
-     partial_words = ""
-     counter = 0
-     status_text = "answering……"
-     history.append(construct_user(inputs))
-     history.append(construct_assistant(""))
-     if fake_input:
-         chatbot.append((fake_input, ""))
-     else:
-         chatbot.append((inputs, ""))
-     user_token_count = 0
-     if len(all_token_counts) == 0:
-         system_prompt_token_count = count_token(construct_system(system_prompt))
-         user_token_count = (
-             count_token(construct_user(inputs)) + system_prompt_token_count
-         )
-     else:
-         user_token_count = count_token(construct_user(inputs))
-     all_token_counts.append(user_token_count)
-     logging.info(f"input token count: {user_token_count}")
-     yield get_return_value()
-     try:
-         response = get_response(
-             openai_api_key,
-             system_prompt,
-             history,
-             True,
-             selected_model,
-         )
-     except requests.exceptions.ConnectTimeout:
-         status_text = (
-             standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
-         )
-         yield get_return_value()
-         return
-     except requests.exceptions.ReadTimeout:
-         status_text = standard_error_msg + read_timeout_prompt + error_retrieve_prompt
-         yield get_return_value()
-         return
-
-     yield get_return_value()
-     error_json_str = ""
-
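-     # Each non-empty line of the stream is a server-sent event of the form
-     # "data: {json}"; chunk[6:] below strips the "data: " prefix before decoding.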
-     for chunk in tqdm(response.iter_lines()):
-         if counter == 0:
-             counter += 1
-             continue
-         counter += 1
-         # check whether each line is non-empty
-         if chunk:
-             # decode each line, as the response data arrives in bytes
-             chunk = chunk.decode()
-             chunklength = len(chunk)
-             try:
-                 chunk = json.loads(chunk[6:])
-             except json.JSONDecodeError:
-                 logging.info(chunk)
-                 error_json_str += chunk
-                 status_text = f"JSON parsing error. Please reset the conversation. Received content: {error_json_str}"
-                 yield get_return_value()
-                 continue
-             if chunklength > 6 and "delta" in chunk["choices"][0]:
-                 finish_reason = chunk["choices"][0]["finish_reason"]
-                 status_text = construct_token_message(
-                     sum(all_token_counts), stream=True
-                 )
-                 if finish_reason == "stop":
-                     yield get_return_value()
-                     break
-                 try:
-                     partial_words = (
-                         partial_words + chunk["choices"][0]["delta"]["content"]
-                     )
-                 except KeyError:
-                     status_text = (
-                         standard_error_msg
-                         + "Token count has reached the max token limit. Please reset the conversation. Current token count: "
-                         + str(sum(all_token_counts))
-                     )
-                     yield get_return_value()
-                     break
-                 history[-1] = construct_assistant(partial_words)
-                 chatbot[-1] = (chatbot[-1][0], partial_words + display_append)
-                 all_token_counts[-1] += 1
-                 yield get_return_value()
-
-
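- # One-shot (non-streaming) variant: sends the whole conversation, waits for
- # the complete reply, and returns the final state in a single tuple.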
- def predict_all(
-     openai_api_key,
-     system_prompt,
-     history,
-     inputs,
-     chatbot,
-     all_token_counts,
-     selected_model,
-     fake_input=None,
-     display_append=""
- ):
-     # logging.info("One-shot answer mode")
-     history.append(construct_user(inputs))
-     history.append(construct_assistant(""))
-     if fake_input:
-         chatbot.append((fake_input, ""))
-     else:
-         chatbot.append((inputs, ""))
-     all_token_counts.append(count_token(construct_user(inputs)))
-     try:
-         response = get_response(
-             openai_api_key,
-             system_prompt,
-             history,
-             False,
-             selected_model,
-         )
-     except requests.exceptions.ConnectTimeout:
-         status_text = (
-             standard_error_msg + connection_timeout_prompt + error_retrieve_prompt
-         )
-         return chatbot, history, status_text, all_token_counts
-     except requests.exceptions.ProxyError:
-         status_text = standard_error_msg + proxy_error_prompt + error_retrieve_prompt
-         return chatbot, history, status_text, all_token_counts
-     except requests.exceptions.SSLError:
-         status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
-         return chatbot, history, status_text, all_token_counts
-     response = json.loads(response.text)
-     content = response["choices"][0]["message"]["content"]
-     history[-1] = construct_assistant(content)
-     chatbot[-1] = (chatbot[-1][0], content + display_append)
-     total_token_count = response["usage"]["total_tokens"]
-     all_token_counts[-1] = total_token_count - sum(all_token_counts)
-     status_text = construct_token_message(total_token_count)
-     return chatbot, history, status_text, all_token_counts
-
-
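- # Entry point called by the UI: validates the API key, then dispatches to
- # stream_predict or predict_all depending on the stream flag.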
- def predict(
-     openai_api_key,
-     system_prompt,
-     history,
-     inputs,
-     chatbot,
-     all_token_counts,
-     stream=True,
-     selected_model=MODELS[0],
-     use_websearch=False,
-     files=None,
-     should_check_token_count=True,
- ):  # repetition_penalty, top_k
-
-     old_inputs = ""
-     link_references = ""
-
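-     # Crude validity check: classic OpenAI secret keys ("sk-" plus 48
-     # characters) are exactly 51 characters long.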
-     if len(openai_api_key) != 51:
-         status_text = standard_error_msg + no_apikey_msg
-         logging.info(status_text)
-         chatbot.append((inputs, ""))
-         if len(history) == 0:
-             history.append(construct_user(inputs))
-             history.append("")
-             all_token_counts.append(0)
-         else:
-             history[-2] = construct_user(inputs)
-         yield chatbot, history, status_text, all_token_counts
-         return
-
-     yield chatbot, history, "answering……", all_token_counts
-
-     if stream:
-         # logging.info("Using streaming")
-         iter = stream_predict(
-             openai_api_key,
-             system_prompt,
-             history,
-             inputs,
-             chatbot,
-             all_token_counts,
-             selected_model,
-             fake_input=old_inputs,
-             display_append=link_references
-         )
-         for chatbot, history, status_text, all_token_counts in iter:
-             yield chatbot, history, status_text, all_token_counts
-     else:
-         # logging.info("Not using streaming")
-         chatbot, history, status_text, all_token_counts = predict_all(
-             openai_api_key,
-             system_prompt,
-             history,
-             inputs,
-             chatbot,
-             all_token_counts,
-             selected_model,
-             fake_input=old_inputs,
-             display_append=link_references
-         )
-         yield chatbot, history, status_text, all_token_counts
-
-     logging.info(f"The current token count is {all_token_counts}")
-