yym68686 committed on
Commit
7d44776
·
1 Parent(s): 3b159d8

Add feature: support vertex claude API using tool use functionality.

Browse files
Files changed (5) hide show
  1. main.py +49 -3
  2. request.py +58 -85
  3. response.py +51 -2
  4. test/provider_test.py +1 -1
  5. utils.py +1 -0
main.py CHANGED
@@ -5,7 +5,7 @@ import secrets
5
  from contextlib import asynccontextmanager
6
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
- from fastapi import FastAPI, HTTPException, Depends
9
  from fastapi.responses import StreamingResponse, JSONResponse
10
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
 
@@ -40,6 +40,37 @@ async def lifespan(app: FastAPI):
40
 
41
  app = FastAPI(lifespan=lifespan)
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # 配置 CORS 中间件
44
  app.add_middleware(
45
  CORSMiddleware,
@@ -219,9 +250,24 @@ def generate_api_key():
219
  api_key = "sk-" + secrets.token_urlsafe(32)
220
  return JSONResponse(content={"api_key": api_key})
221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
  # async def on_fetch(request, env):
223
  # import asgi
224
-
225
  # return await asgi.fetch(app, request, env)
226
 
227
  if __name__ == '__main__':
@@ -232,5 +278,5 @@ if __name__ == '__main__':
232
  port=8000,
233
  reload=True,
234
  ws="none",
235
- log_level="warning"
236
  )
 
5
  from contextlib import asynccontextmanager
6
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi import FastAPI, HTTPException, Depends, Request
9
  from fastapi.responses import StreamingResponse, JSONResponse
10
  from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
11
 
 
40
 
41
  app = FastAPI(lifespan=lifespan)
42
 
43
+ # from time import time
44
+ # from collections import defaultdict
45
+ # import asyncio
46
+
47
+ # class StatsMiddleware:
48
+ # def __init__(self):
49
+ # self.request_counts = defaultdict(int)
50
+ # self.request_times = defaultdict(float)
51
+ # self.ip_counts = defaultdict(lambda: defaultdict(int))
52
+ # self.lock = asyncio.Lock()
53
+
54
+ # async def __call__(self, request: Request, call_next):
55
+ # start_time = time()
56
+ # response = await call_next(request)
57
+ # process_time = time() - start_time
58
+
59
+ # endpoint = f"{request.method} {request.url.path}"
60
+ # client_ip = request.client.host
61
+
62
+ # async with self.lock:
63
+ # self.request_counts[endpoint] += 1
64
+ # self.request_times[endpoint] += process_time
65
+ # self.ip_counts[endpoint][client_ip] += 1
66
+
67
+ # return response
68
+ # # 创建 StatsMiddleware 实例
69
+ # stats_middleware = StatsMiddleware()
70
+
71
+ # # 添加 StatsMiddleware
72
+ # app.add_middleware(StatsMiddleware)
73
+
74
  # 配置 CORS 中间件
75
  app.add_middleware(
76
  CORSMiddleware,
 
250
  api_key = "sk-" + secrets.token_urlsafe(32)
251
  return JSONResponse(content={"api_key": api_key})
252
 
253
+ # @app.get("/stats")
254
+ # async def get_stats(token: str = Depends(verify_api_key)):
255
+ # async with stats_middleware.lock:
256
+ # return {
257
+ # "request_counts": dict(stats_middleware.request_counts),
258
+ # "average_request_times": {
259
+ # endpoint: total_time / count
260
+ # for endpoint, total_time in stats_middleware.request_times.items()
261
+ # for count in [stats_middleware.request_counts[endpoint]]
262
+ # },
263
+ # "ip_counts": {
264
+ # endpoint: dict(ips)
265
+ # for endpoint, ips in stats_middleware.ip_counts.items()
266
+ # }
267
+ # }
268
+
269
  # async def on_fetch(request, env):
270
  # import asgi
 
271
  # return await asgi.fetch(app, request, env)
272
 
273
  if __name__ == '__main__':
 
278
  port=8000,
279
  reload=True,
280
  ws="none",
281
+ # log_level="warning"
282
  )
request.py CHANGED
@@ -363,7 +363,7 @@ async def get_vertex_gemini_payload(request, engine, provider):
363
 
364
  async def get_vertex_claude_payload(request, engine, provider):
365
  headers = {
366
- 'Content-Type': 'application/json'
367
  }
368
  if provider.get("client_email") and provider.get("private_key"):
369
  access_token = get_access_token(provider['client_email'], provider['private_key'])
@@ -386,12 +386,10 @@ async def get_vertex_claude_payload(request, engine, provider):
386
  url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
387
 
388
  messages = []
389
- systemInstruction = None
390
- function_arguments = None
391
  for msg in request.messages:
392
- if msg.role == "assistant":
393
- msg.role = "model"
394
  tool_calls = None
 
395
  if isinstance(msg.content, list):
396
  content = []
397
  for item in msg.content:
@@ -402,109 +400,84 @@ async def get_vertex_claude_payload(request, engine, provider):
402
  image_message = await get_image_message(item.image_url.url, engine)
403
  content.append(image_message)
404
  else:
405
- content = [{"text": msg.content}]
406
  tool_calls = msg.tool_calls
 
407
 
408
  if tool_calls:
409
- tool_call = tool_calls[0]
410
- function_arguments = {
411
- "functionCall": {
 
 
412
  "name": tool_call.function.name,
413
- "args": json.loads(tool_call.function.arguments)
414
- }
415
- }
416
- messages.append(
417
- {
418
- "role": "model",
419
- "parts": [function_arguments]
420
- }
421
- )
422
- elif msg.role == "tool":
423
- function_call_name = function_arguments["functionCall"]["name"]
424
- messages.append(
425
- {
426
- "role": "function",
427
- "parts": [{
428
- "functionResponse": {
429
- "name": function_call_name,
430
- "response": {
431
- "name": function_call_name,
432
- "content": {
433
- "result": msg.content,
434
- }
435
- }
436
- }
437
- }]
438
- }
439
- )
440
  elif msg.role != "system":
441
- messages.append({"role": msg.role, "parts": content})
442
  elif msg.role == "system":
443
- systemInstruction = {"parts": content}
444
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
445
 
 
446
  payload = {
447
- "contents": messages,
448
- # "safetySettings": [
449
- # {
450
- # "category": "HARM_CATEGORY_HARASSMENT",
451
- # "threshold": "BLOCK_NONE"
452
- # },
453
- # {
454
- # "category": "HARM_CATEGORY_HATE_SPEECH",
455
- # "threshold": "BLOCK_NONE"
456
- # },
457
- # {
458
- # "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
459
- # "threshold": "BLOCK_NONE"
460
- # },
461
- # {
462
- # "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
463
- # "threshold": "BLOCK_NONE"
464
- # }
465
- # ]
466
- "generationConfig": {
467
- "temperature": 0.5,
468
- "max_output_tokens": 8192,
469
- "top_k": 40,
470
- "top_p": 0.95
471
- },
472
  }
473
- if systemInstruction:
474
- payload["system_instruction"] = systemInstruction
475
 
476
  miss_fields = [
477
  'model',
478
  'messages',
479
- 'stream',
480
- 'tool_choice',
481
- 'temperature',
482
- 'top_p',
483
- 'max_tokens',
484
  'presence_penalty',
485
  'frequency_penalty',
486
  'n',
487
  'user',
488
  'include_usage',
489
- 'logprobs',
490
- 'top_logprobs'
491
  ]
492
 
493
  for field, value in request.model_dump(exclude_unset=True).items():
494
  if field not in miss_fields and value is not None:
495
- if field == "tools":
496
- payload.update({
497
- "tools": [{
498
- "function_declarations": [tool["function"] for tool in value]
499
- }],
500
- "tool_config": {
501
- "function_calling_config": {
502
- "mode": "AUTO"
503
- }
504
- }
505
- })
506
- else:
507
- payload[field] = value
 
 
 
508
 
509
  return url, headers, payload
510
 
 
363
 
364
  async def get_vertex_claude_payload(request, engine, provider):
365
  headers = {
366
+ 'Content-Type': 'application/json',
367
  }
368
  if provider.get("client_email") and provider.get("private_key"):
369
  access_token = get_access_token(provider['client_email'], provider['private_key'])
 
386
  url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
387
 
388
  messages = []
389
+ system_prompt = None
 
390
  for msg in request.messages:
 
 
391
  tool_calls = None
392
+ tool_call_id = None
393
  if isinstance(msg.content, list):
394
  content = []
395
  for item in msg.content:
 
400
  image_message = await get_image_message(item.image_url.url, engine)
401
  content.append(image_message)
402
  else:
403
+ content = msg.content
404
  tool_calls = msg.tool_calls
405
+ tool_call_id = msg.tool_call_id
406
 
407
  if tool_calls:
408
+ tool_calls_list = []
409
+ for tool_call in tool_calls:
410
+ tool_calls_list.append({
411
+ "type": "tool_use",
412
+ "id": tool_call.id,
413
  "name": tool_call.function.name,
414
+ "input": json.loads(tool_call.function.arguments),
415
+ })
416
+ messages.append({"role": msg.role, "content": tool_calls_list})
417
+ elif tool_call_id:
418
+ messages.append({"role": "user", "content": [{
419
+ "type": "tool_result",
420
+ "tool_use_id": tool_call.id,
421
+ "content": content
422
+ }]})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423
  elif msg.role != "system":
424
+ messages.append({"role": msg.role, "content": content})
425
  elif msg.role == "system":
426
+ system_prompt = content
427
 
428
+ conversation_len = len(messages) - 1
429
+ message_index = 0
430
+ while message_index < conversation_len:
431
+ if messages[message_index]["role"] == messages[message_index + 1]["role"]:
432
+ if messages[message_index].get("content"):
433
+ if isinstance(messages[message_index]["content"], list):
434
+ messages[message_index]["content"].extend(messages[message_index + 1]["content"])
435
+ elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
436
+ content_list = [{"type": "text", "text": messages[message_index]["content"]}]
437
+ content_list.extend(messages[message_index + 1]["content"])
438
+ messages[message_index]["content"] = content_list
439
+ else:
440
+ messages[message_index]["content"] += messages[message_index + 1]["content"]
441
+ messages.pop(message_index + 1)
442
+ conversation_len = conversation_len - 1
443
+ else:
444
+ message_index = message_index + 1
445
 
446
+ model = provider['model'][request.model]
447
  payload = {
448
+ "anthropic_version": "vertex-2023-10-16",
449
+ "messages": messages,
450
+ "system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
451
  }
 
 
452
 
453
  miss_fields = [
454
  'model',
455
  'messages',
 
 
 
 
 
456
  'presence_penalty',
457
  'frequency_penalty',
458
  'n',
459
  'user',
460
  'include_usage',
 
 
461
  ]
462
 
463
  for field, value in request.model_dump(exclude_unset=True).items():
464
  if field not in miss_fields and value is not None:
465
+ payload[field] = value
466
+
467
+ if request.tools and provider.get("tools"):
468
+ tools = []
469
+ for tool in request.tools:
470
+ json_tool = await gpt2claude_tools_json(tool.dict()["function"])
471
+ tools.append(json_tool)
472
+ payload["tools"] = tools
473
+ if "tool_choice" in payload:
474
+ payload["tool_choice"] = {
475
+ "type": "auto"
476
+ }
477
+
478
+ if provider.get("tools") == False:
479
+ payload.pop("tools", None)
480
+ payload.pop("tool_choice", None)
481
 
482
  return url, headers, payload
483
 
response.py CHANGED
@@ -84,6 +84,55 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
84
  sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
85
  yield sse_string
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects=5):
88
  redirect_count = 0
89
  while redirect_count < max_redirects:
@@ -202,10 +251,10 @@ async def fetch_response(client, url, headers, payload):
202
 
203
  async def fetch_response_stream(client, url, headers, payload, engine, model):
204
  try:
205
- if engine == "gemini" or engine == "vertex":
206
  async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
207
  yield chunk
208
- elif engine == "claude":
209
  async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
210
  yield chunk
211
  elif engine == "gpt":
 
84
  sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
85
  yield sse_string
86
 
87
+ async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
88
+ timestamp = datetime.timestamp(datetime.now())
89
+ async with client.stream('POST', url, headers=headers, json=payload) as response:
90
+ if response.status_code != 200:
91
+ error_message = await response.aread()
92
+ error_str = error_message.decode('utf-8', errors='replace')
93
+ try:
94
+ error_json = json.loads(error_str)
95
+ except json.JSONDecodeError:
96
+ error_json = error_str
97
+ yield {"error": f"fetch_gpt_response_stream HTTP Error {response.status_code}", "details": error_json}
98
+ buffer = ""
99
+ revicing_function_call = False
100
+ function_full_response = "{"
101
+ need_function_call = False
102
+ async for chunk in response.aiter_text():
103
+ buffer += chunk
104
+ while "\n" in buffer:
105
+ line, buffer = buffer.split("\n", 1)
106
+ logger.info(f"{line}")
107
+ if line and '\"text\": \"' in line:
108
+ try:
109
+ json_data = json.loads( "{" + line + "}")
110
+ content = json_data.get('text', '')
111
+ content = "\n".join(content.split("\\n"))
112
+ sse_string = await generate_sse_response(timestamp, model, content)
113
+ yield sse_string
114
+ except json.JSONDecodeError:
115
+ logger.error(f"无法解析JSON: {line}")
116
+
117
+ if line and ('\"type\": \"tool_use\"' in line or revicing_function_call):
118
+ revicing_function_call = True
119
+ need_function_call = True
120
+ if ']' in line:
121
+ revicing_function_call = False
122
+ continue
123
+
124
+ function_full_response += line
125
+
126
+ if need_function_call:
127
+ function_call = json.loads(function_full_response)
128
+ function_call_name = function_call["name"]
129
+ function_call_id = function_call["id"]
130
+ sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
131
+ yield sse_string
132
+ function_full_response = json.dumps(function_call["input"])
133
+ sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
134
+ yield sse_string
135
+
136
  async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects=5):
137
  redirect_count = 0
138
  while redirect_count < max_redirects:
 
251
 
252
  async def fetch_response_stream(client, url, headers, payload, engine, model):
253
  try:
254
+ if engine == "gemini" or (engine == "vertex" and "gemini" in model):
255
  async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
256
  yield chunk
257
+ elif engine == "claude" or (engine == "vertex" and "claude" in model):
258
  async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
259
  yield chunk
260
  elif engine == "gpt":
test/provider_test.py CHANGED
@@ -80,7 +80,7 @@ def test_request_model(test_client, api_key, get_model):
80
 
81
  response = test_client.post("/v1/chat/completions", json=request_data, headers=headers)
82
  for line in response.iter_lines():
83
- print(line)
84
  assert response.status_code == 200
85
 
86
  if __name__ == "__main__":
 
80
 
81
  response = test_client.post("/v1/chat/completions", json=request_data, headers=headers)
82
  for line in response.iter_lines():
83
+ print(line.lstrip("data: "))
84
  assert response.status_code == 200
85
 
86
  if __name__ == "__main__":
utils.py CHANGED
@@ -80,6 +80,7 @@ async def error_handling_wrapper(generator, status_code=200):
80
  try:
81
  first_item = await generator.__anext__()
82
  first_item_str = first_item
 
83
  if isinstance(first_item_str, (bytes, bytearray)):
84
  first_item_str = first_item_str.decode("utf-8")
85
  if isinstance(first_item_str, str):
 
80
  try:
81
  first_item = await generator.__anext__()
82
  first_item_str = first_item
83
+ # logger.info("first_item_str: %s", first_item_str)
84
  if isinstance(first_item_str, (bytes, bytearray)):
85
  first_item_str = first_item_str.decode("utf-8")
86
  if isinstance(first_item_str, str):