🐛 Bug: Fix the official Claude API stream not passing the token count (usage) through correctly.
Browse files- response.py +19 -9
response.py
CHANGED
@@ -5,7 +5,7 @@ from datetime import datetime
|
|
5 |
from log_config import logger
|
6 |
|
7 |
|
8 |
-
async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None,
|
9 |
sample_data = {
|
10 |
"id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
|
11 |
"object": "chat.completion.chunk",
|
@@ -29,6 +29,10 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
|
|
29 |
# sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
|
30 |
if role:
|
31 |
sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
|
|
|
|
|
|
|
|
|
32 |
json_data = json.dumps(sample_data, ensure_ascii=False)
|
33 |
|
34 |
# 构建SSE响应
|
@@ -68,7 +72,7 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
|
|
68 |
json_data = json.loads( "{" + line + "}")
|
69 |
content = json_data.get('text', '')
|
70 |
content = "\n".join(content.split("\\n"))
|
71 |
-
sse_string = await generate_sse_response(timestamp, model, content)
|
72 |
yield sse_string
|
73 |
except json.JSONDecodeError:
|
74 |
logger.error(f"无法解析JSON: {line}")
|
@@ -114,7 +118,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod
|
|
114 |
json_data = json.loads( "{" + line + "}")
|
115 |
content = json_data.get('text', '')
|
116 |
content = "\n".join(content.split("\\n"))
|
117 |
-
sse_string = await generate_sse_response(timestamp, model, content)
|
118 |
yield sse_string
|
119 |
except json.JSONDecodeError:
|
120 |
logger.error(f"无法解析JSON: {line}")
|
@@ -163,6 +167,7 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
|
|
163 |
yield error_message
|
164 |
return
|
165 |
buffer = ""
|
|
|
166 |
async for chunk in response.aiter_text():
|
167 |
# logger.info(f"chunk: {repr(chunk)}")
|
168 |
buffer += chunk
|
@@ -171,20 +176,25 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
|
|
171 |
# logger.info(line)
|
172 |
|
173 |
if line.startswith("data:"):
|
174 |
-
line = line
|
175 |
-
if line.startswith(" "):
|
176 |
-
line = line[1:]
|
177 |
resp: dict = json.loads(line)
|
178 |
message = resp.get("message")
|
179 |
if message:
|
180 |
-
tokens_use = resp.get("usage")
|
181 |
role = message.get("role")
|
182 |
if role:
|
183 |
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
|
184 |
yield sse_string
|
|
|
185 |
if tokens_use:
|
186 |
-
|
187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
tool_use = resp.get("content_block")
|
189 |
tools_id = None
|
190 |
function_call_name = None
|
|
|
5 |
from log_config import logger
|
6 |
|
7 |
|
8 |
+
async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
|
9 |
sample_data = {
|
10 |
"id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
|
11 |
"object": "chat.completion.chunk",
|
|
|
29 |
# sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
|
30 |
if role:
|
31 |
sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
|
32 |
+
if total_tokens:
|
33 |
+
total_tokens = prompt_tokens + completion_tokens
|
34 |
+
sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens,"total_tokens": total_tokens}
|
35 |
+
sample_data["choices"] = []
|
36 |
json_data = json.dumps(sample_data, ensure_ascii=False)
|
37 |
|
38 |
# 构建SSE响应
|
|
|
72 |
json_data = json.loads( "{" + line + "}")
|
73 |
content = json_data.get('text', '')
|
74 |
content = "\n".join(content.split("\\n"))
|
75 |
+
sse_string = await generate_sse_response(timestamp, model, content=content)
|
76 |
yield sse_string
|
77 |
except json.JSONDecodeError:
|
78 |
logger.error(f"无法解析JSON: {line}")
|
|
|
118 |
json_data = json.loads( "{" + line + "}")
|
119 |
content = json_data.get('text', '')
|
120 |
content = "\n".join(content.split("\\n"))
|
121 |
+
sse_string = await generate_sse_response(timestamp, model, content=content)
|
122 |
yield sse_string
|
123 |
except json.JSONDecodeError:
|
124 |
logger.error(f"无法解析JSON: {line}")
|
|
|
167 |
yield error_message
|
168 |
return
|
169 |
buffer = ""
|
170 |
+
input_tokens = 0
|
171 |
async for chunk in response.aiter_text():
|
172 |
# logger.info(f"chunk: {repr(chunk)}")
|
173 |
buffer += chunk
|
|
|
176 |
# logger.info(line)
|
177 |
|
178 |
if line.startswith("data:"):
|
179 |
+
line = line.lstrip("data: ")
|
|
|
|
|
180 |
resp: dict = json.loads(line)
|
181 |
message = resp.get("message")
|
182 |
if message:
|
|
|
183 |
role = message.get("role")
|
184 |
if role:
|
185 |
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
|
186 |
yield sse_string
|
187 |
+
tokens_use = message.get("usage")
|
188 |
if tokens_use:
|
189 |
+
input_tokens = tokens_use.get("input_tokens", 0)
|
190 |
+
usage = resp.get("usage")
|
191 |
+
if usage:
|
192 |
+
output_tokens = usage.get("output_tokens", 0)
|
193 |
+
total_tokens = input_tokens + output_tokens
|
194 |
+
sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
|
195 |
+
yield sse_string
|
196 |
+
# print("\n\rtotal_tokens", total_tokens)
|
197 |
+
|
198 |
tool_use = resp.get("content_block")
|
199 |
tools_id = None
|
200 |
function_call_name = None
|