yym68686 commited on
Commit
1de140d
·
1 Parent(s): 9874f60

🐛 Bug: Fix the bug where the official Claude API does not correctly pass the token count.

Browse files
Files changed (1) hide show
  1. response.py +19 -9
response.py CHANGED
@@ -5,7 +5,7 @@ from datetime import datetime
5
  from log_config import logger
6
 
7
 
8
- async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, tokens_use=None, total_tokens=None):
9
  sample_data = {
10
  "id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
11
  "object": "chat.completion.chunk",
@@ -29,6 +29,10 @@ async def generate_sse_response(timestamp, model, content=None, tools_id=None, f
29
  # sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
30
  if role:
31
  sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
 
 
 
 
32
  json_data = json.dumps(sample_data, ensure_ascii=False)
33
 
34
  # 构建SSE响应
@@ -68,7 +72,7 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
68
  json_data = json.loads( "{" + line + "}")
69
  content = json_data.get('text', '')
70
  content = "\n".join(content.split("\\n"))
71
- sse_string = await generate_sse_response(timestamp, model, content)
72
  yield sse_string
73
  except json.JSONDecodeError:
74
  logger.error(f"无法解析JSON: {line}")
@@ -114,7 +118,7 @@ async def fetch_vertex_claude_response_stream(client, url, headers, payload, mod
114
  json_data = json.loads( "{" + line + "}")
115
  content = json_data.get('text', '')
116
  content = "\n".join(content.split("\\n"))
117
- sse_string = await generate_sse_response(timestamp, model, content)
118
  yield sse_string
119
  except json.JSONDecodeError:
120
  logger.error(f"无法解析JSON: {line}")
@@ -163,6 +167,7 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
163
  yield error_message
164
  return
165
  buffer = ""
 
166
  async for chunk in response.aiter_text():
167
  # logger.info(f"chunk: {repr(chunk)}")
168
  buffer += chunk
@@ -171,20 +176,25 @@ async def fetch_claude_response_stream(client, url, headers, payload, model):
171
  # logger.info(line)
172
 
173
  if line.startswith("data:"):
174
- line = line[5:]
175
- if line.startswith(" "):
176
- line = line[1:]
177
  resp: dict = json.loads(line)
178
  message = resp.get("message")
179
  if message:
180
- tokens_use = resp.get("usage")
181
  role = message.get("role")
182
  if role:
183
  sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
184
  yield sse_string
 
185
  if tokens_use:
186
- total_tokens = tokens_use["input_tokens"] + tokens_use["output_tokens"]
187
- # print("\n\rtotal_tokens", total_tokens)
 
 
 
 
 
 
 
188
  tool_use = resp.get("content_block")
189
  tools_id = None
190
  function_call_name = None
 
5
  from log_config import logger
6
 
7
 
8
+ async def generate_sse_response(timestamp, model, content=None, tools_id=None, function_call_name=None, function_call_content=None, role=None, total_tokens=0, prompt_tokens=0, completion_tokens=0):
9
  sample_data = {
10
  "id": "chatcmpl-9ijPeRHa0wtyA2G8wq5z8FC3wGMzc",
11
  "object": "chat.completion.chunk",
 
29
  # sample_data["choices"][0]["delta"] = {"tool_calls":[{"index":0,"function":{"id": tools_id, "name": function_call_name}}]}
30
  if role:
31
  sample_data["choices"][0]["delta"] = {"role": role, "content": ""}
32
+ if total_tokens:
33
+ total_tokens = prompt_tokens + completion_tokens
34
+ sample_data["usage"] = {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens,"total_tokens": total_tokens}
35
+ sample_data["choices"] = []
36
  json_data = json.dumps(sample_data, ensure_ascii=False)
37
 
38
  # 构建SSE响应
 
72
  json_data = json.loads( "{" + line + "}")
73
  content = json_data.get('text', '')
74
  content = "\n".join(content.split("\\n"))
75
+ sse_string = await generate_sse_response(timestamp, model, content=content)
76
  yield sse_string
77
  except json.JSONDecodeError:
78
  logger.error(f"无法解析JSON: {line}")
 
118
  json_data = json.loads( "{" + line + "}")
119
  content = json_data.get('text', '')
120
  content = "\n".join(content.split("\\n"))
121
+ sse_string = await generate_sse_response(timestamp, model, content=content)
122
  yield sse_string
123
  except json.JSONDecodeError:
124
  logger.error(f"无法解析JSON: {line}")
 
167
  yield error_message
168
  return
169
  buffer = ""
170
+ input_tokens = 0
171
  async for chunk in response.aiter_text():
172
  # logger.info(f"chunk: {repr(chunk)}")
173
  buffer += chunk
 
176
  # logger.info(line)
177
 
178
  if line.startswith("data:"):
179
+ line = line.lstrip("data: ")
 
 
180
  resp: dict = json.loads(line)
181
  message = resp.get("message")
182
  if message:
 
183
  role = message.get("role")
184
  if role:
185
  sse_string = await generate_sse_response(timestamp, model, None, None, None, None, role)
186
  yield sse_string
187
+ tokens_use = message.get("usage")
188
  if tokens_use:
189
+ input_tokens = tokens_use.get("input_tokens", 0)
190
+ usage = resp.get("usage")
191
+ if usage:
192
+ output_tokens = usage.get("output_tokens", 0)
193
+ total_tokens = input_tokens + output_tokens
194
+ sse_string = await generate_sse_response(timestamp, model, None, None, None, None, None, total_tokens, input_tokens, output_tokens)
195
+ yield sse_string
196
+ # print("\n\rtotal_tokens", total_tokens)
197
+
198
  tool_use = resp.get("content_block")
199
  tools_id = None
200
  function_call_name = None