Add feature: support the Vertex Claude API with tool use functionality.

Files changed:
- main.py +49 -3
- request.py +58 -85
- response.py +51 -2
- test/provider_test.py +1 -1
- utils.py +1 -0
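
In OpenAI-compatible terms, the feature means a /v1/chat/completions request carrying a tools array can now be routed to a Claude model hosted on Vertex AI. A minimal client sketch (the model name, API key, and weather tool are invented examples; the endpoint path and port come from main.py and the test below):

    import httpx

    # Hypothetical request; "claude-3-5-sonnet" stands in for whatever name
    # the provider config maps to a Vertex Claude model.
    request_body = {
        "model": "claude-3-5-sonnet",
        "messages": [{"role": "user", "content": "What is the weather in Boston?"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }],
    }

    response = httpx.post(
        "http://localhost:8000/v1/chat/completions",
        headers={"Authorization": "Bearer sk-..."},  # key format from generate_api_key()
        json=request_body,
    )
    print(response.json())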
main.py
CHANGED
@@ -5,7 +5,7 @@ import secrets
 from contextlib import asynccontextmanager
 
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi import FastAPI, HTTPException, Depends
+from fastapi import FastAPI, HTTPException, Depends, Request
 from fastapi.responses import StreamingResponse, JSONResponse
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 
@@ -40,6 +40,37 @@ async def lifespan(app: FastAPI):
 
 app = FastAPI(lifespan=lifespan)
 
+# from time import time
+# from collections import defaultdict
+# import asyncio
+
+# class StatsMiddleware:
+#     def __init__(self):
+#         self.request_counts = defaultdict(int)
+#         self.request_times = defaultdict(float)
+#         self.ip_counts = defaultdict(lambda: defaultdict(int))
+#         self.lock = asyncio.Lock()
+
+#     async def __call__(self, request: Request, call_next):
+#         start_time = time()
+#         response = await call_next(request)
+#         process_time = time() - start_time
+
+#         endpoint = f"{request.method} {request.url.path}"
+#         client_ip = request.client.host
+
+#         async with self.lock:
+#             self.request_counts[endpoint] += 1
+#             self.request_times[endpoint] += process_time
+#             self.ip_counts[endpoint][client_ip] += 1
+
+#         return response
+# # Create the StatsMiddleware instance
+# stats_middleware = StatsMiddleware()
+
+# # Register the StatsMiddleware
+# app.add_middleware(StatsMiddleware)
+
 # Configure the CORS middleware
 app.add_middleware(
     CORSMiddleware,
@@ -219,9 +250,24 @@ def generate_api_key():
     api_key = "sk-" + secrets.token_urlsafe(32)
     return JSONResponse(content={"api_key": api_key})
 
+# @app.get("/stats")
+# async def get_stats(token: str = Depends(verify_api_key)):
+#     async with stats_middleware.lock:
+#         return {
+#             "request_counts": dict(stats_middleware.request_counts),
+#             "average_request_times": {
+#                 endpoint: total_time / count
+#                 for endpoint, total_time in stats_middleware.request_times.items()
+#                 for count in [stats_middleware.request_counts[endpoint]]
+#             },
+#             "ip_counts": {
+#                 endpoint: dict(ips)
+#                 for endpoint, ips in stats_middleware.ip_counts.items()
+#             }
+#         }
+
 # async def on_fetch(request, env):
 #     import asgi
-
 #     return await asgi.fetch(app, request, env)
 
 if __name__ == '__main__':
@@ -232,5 +278,5 @@ if __name__ == '__main__':
         port=8000,
         reload=True,
         ws="none",
-        log_level="warning"
+        # log_level="warning"
     )
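
A note on the commented-out StatsMiddleware above: Starlette's app.add_middleware(cls) instantiates the class with the wrapped ASGI app as its first constructor argument, which this __init__ does not accept, so enabling the block as written would likely fail at startup. If it is ever revived, one option is FastAPI's decorator-based HTTP middleware, which expects exactly the (request, call_next) signature used here. A minimal sketch, assuming the instance is created as in the commented code:

    from fastapi import FastAPI, Request

    app = FastAPI()
    stats_middleware = StatsMiddleware()  # the callable class from the block above

    @app.middleware("http")
    async def collect_stats(request: Request, call_next):
        # Delegate to the instance, which times the call and updates its counters.
        return await stats_middleware(request, call_next)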
request.py
CHANGED
@@ -363,7 +363,7 @@ async def get_vertex_gemini_payload(request, engine, provider):
 
 async def get_vertex_claude_payload(request, engine, provider):
     headers = {
-        'Content-Type': 'application/json'
+        'Content-Type': 'application/json',
     }
     if provider.get("client_email") and provider.get("private_key"):
         access_token = get_access_token(provider['client_email'], provider['private_key'])
@@ -386,12 +386,10 @@ async def get_vertex_claude_payload(request, engine, provider):
     url = "https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{LOCATION}/publishers/anthropic/models/{MODEL}:{stream}".format(LOCATION=location.next(), PROJECT_ID=project_id, MODEL=model, stream=claude_stream)
 
     messages = []
-
-    function_arguments = None
+    system_prompt = None
     for msg in request.messages:
-        if msg.role == "assistant":
-            msg.role = "model"
         tool_calls = None
+        tool_call_id = None
         if isinstance(msg.content, list):
             content = []
             for item in msg.content:
@@ -402,109 +400,84 @@ async def get_vertex_claude_payload(request, engine, provider):
                 image_message = await get_image_message(item.image_url.url, engine)
                 content.append(image_message)
         else:
-            content = [{"text": msg.content}]
+            content = msg.content
             tool_calls = msg.tool_calls
+            tool_call_id = msg.tool_call_id
 
         if tool_calls:
-            tool_call = tool_calls[0]
-            function_arguments = {
-                "functionCall": {
-                    "name": tool_call.function.name,
-                    "args": json.loads(tool_call.function.arguments)
-                }
-            }
-            messages.append(
-                {
-                    "role": "model",
-                    "parts": [function_arguments]
-                }
-            )
-        elif msg.role == "tool":
-            function_call_name = function_arguments["functionCall"]["name"]
-            messages.append(
-                {
-                    "role": "function",
-                    "parts": [{
-                        "functionResponse": {
-                            "name": function_call_name,
-                            "response": {
-                                "name": function_call_name,
-                                "content": {
-                                    "result": msg.content,
-                                }
-                            }
-                        }
-                    }]
-                }
-            )
+            tool_calls_list = []
+            for tool_call in tool_calls:
+                tool_calls_list.append({
+                    "type": "tool_use",
+                    "id": tool_call.id,
+                    "name": tool_call.function.name,
+                    "input": json.loads(tool_call.function.arguments),
+                })
+            messages.append({"role": msg.role, "content": tool_calls_list})
+        elif tool_call_id:
+            messages.append({"role": "user", "content": [{
+                "type": "tool_result",
+                "tool_use_id": tool_call_id,
+                "content": content
+            }]})
         elif msg.role != "system":
-            messages.append({"role": msg.role, "parts": content})
+            messages.append({"role": msg.role, "content": content})
         elif msg.role == "system":
-            systemInstruction = {"parts": content}
+            system_prompt = content
 
+    conversation_len = len(messages) - 1
+    message_index = 0
+    while message_index < conversation_len:
+        if messages[message_index]["role"] == messages[message_index + 1]["role"]:
+            if messages[message_index].get("content"):
+                if isinstance(messages[message_index]["content"], list):
+                    messages[message_index]["content"].extend(messages[message_index + 1]["content"])
+                elif isinstance(messages[message_index]["content"], str) and isinstance(messages[message_index + 1]["content"], list):
+                    content_list = [{"type": "text", "text": messages[message_index]["content"]}]
+                    content_list.extend(messages[message_index + 1]["content"])
+                    messages[message_index]["content"] = content_list
+                else:
+                    messages[message_index]["content"] += messages[message_index + 1]["content"]
+            messages.pop(message_index + 1)
+            conversation_len = conversation_len - 1
+        else:
+            message_index = message_index + 1
 
+    model = provider['model'][request.model]
     payload = {
-        "contents": messages,
-        # "safetySettings": [
-        #     {
-        #         "category": "HARM_CATEGORY_HARASSMENT",
-        #         "threshold": "BLOCK_NONE"
-        #     },
-        #     {
-        #         "category": "HARM_CATEGORY_HATE_SPEECH",
-        #         "threshold": "BLOCK_NONE"
-        #     },
-        #     {
-        #         "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-        #         "threshold": "BLOCK_NONE"
-        #     },
-        #     {
-        #         "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-        #         "threshold": "BLOCK_NONE"
-        #     }
-        # ]
-        "generationConfig": {
-            "temperature": 0.5,
-            "max_output_tokens": 8192,
-            "top_k": 40,
-            "top_p": 0.95
-        },
+        "anthropic_version": "vertex-2023-10-16",
+        "messages": messages,
+        "system": system_prompt or "You are Claude, a large language model trained by Anthropic.",
     }
-    if systemInstruction:
-        payload["system_instruction"] = systemInstruction
 
     miss_fields = [
         'model',
         'messages',
-        'stream',
-        'tool_choice',
-        'temperature',
-        'top_p',
-        'max_tokens',
         'presence_penalty',
         'frequency_penalty',
         'n',
         'user',
         'include_usage',
-        'logprobs',
-        'top_logprobs'
     ]
 
     for field, value in request.model_dump(exclude_unset=True).items():
         if field not in miss_fields and value is not None:
+            payload[field] = value
+
+    if request.tools and provider.get("tools"):
+        tools = []
+        for tool in request.tools:
+            json_tool = await gpt2claude_tools_json(tool.dict()["function"])
+            tools.append(json_tool)
+        payload["tools"] = tools
+        if "tool_choice" in payload:
+            payload["tool_choice"] = {
+                "type": "auto"
+            }
+
+    if provider.get("tools") == False:
+        payload.pop("tools", None)
+        payload.pop("tool_choice", None)
 
     return url, headers, payload
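
Taken together, get_vertex_claude_payload maps an OpenAI-style tool-calling conversation onto Anthropic's Messages format: assistant tool_calls become tool_use content blocks, each role "tool" message becomes a user turn carrying a tool_result block, and the while loop then merges consecutive same-role turns because the Claude Messages API expects user and assistant roles to alternate. A sketch of the resulting payload (ids, values, and the tool schema are invented, and the tools entry assumes gpt2claude_tools_json emits Anthropic's name/input_schema shape):

    payload = {
        "anthropic_version": "vertex-2023-10-16",
        "system": "You are Claude, a large language model trained by Anthropic.",
        "messages": [
            {"role": "user", "content": "What is the weather in Boston?"},
            {"role": "assistant", "content": [{
                "type": "tool_use",
                "id": "toolu_01A09q90",  # example id
                "name": "get_weather",
                "input": {"city": "Boston"},
            }]},
            {"role": "user", "content": [{
                "type": "tool_result",
                "tool_use_id": "toolu_01A09q90",
                "content": "72F and sunny",
            }]},
        ],
        "max_tokens": 4096,  # example pass-through field from the OpenAI request
        "tools": [{
            "name": "get_weather",
            "input_schema": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }],
        "tool_choice": {"type": "auto"},
    }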
response.py
CHANGED
@@ -84,6 +84,55 @@ async def fetch_gemini_response_stream(client, url, headers, payload, model):
                 sse_string = await generate_sse_response(timestamp, model, content=None, tools_id="chatcmpl-9inWv0yEtgn873CxMBzHeCeiHctTV", function_call_name=None, function_call_content=function_full_response)
                 yield sse_string
 
+async def fetch_vertex_claude_response_stream(client, url, headers, payload, model):
+    timestamp = datetime.timestamp(datetime.now())
+    async with client.stream('POST', url, headers=headers, json=payload) as response:
+        if response.status_code != 200:
+            error_message = await response.aread()
+            error_str = error_message.decode('utf-8', errors='replace')
+            try:
+                error_json = json.loads(error_str)
+            except json.JSONDecodeError:
+                error_json = error_str
+            yield {"error": f"fetch_vertex_claude_response_stream HTTP Error {response.status_code}", "details": error_json}
+        buffer = ""
+        revicing_function_call = False
+        function_full_response = "{"
+        need_function_call = False
+        async for chunk in response.aiter_text():
+            buffer += chunk
+            while "\n" in buffer:
+                line, buffer = buffer.split("\n", 1)
+                logger.info(f"{line}")
+                if line and '"text": "' in line:
+                    try:
+                        json_data = json.loads("{" + line + "}")
+                        content = json_data.get('text', '')
+                        content = "\n".join(content.split("\\n"))
+                        sse_string = await generate_sse_response(timestamp, model, content)
+                        yield sse_string
+                    except json.JSONDecodeError:
+                        logger.error(f"Failed to parse JSON: {line}")
+
+                if line and ('"type": "tool_use"' in line or revicing_function_call):
+                    revicing_function_call = True
+                    need_function_call = True
+                    if ']' in line:
+                        revicing_function_call = False
+                        continue
+
+                    function_full_response += line
+
+        if need_function_call:
+            function_call = json.loads(function_full_response)
+            function_call_name = function_call["name"]
+            function_call_id = function_call["id"]
+            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=function_call_name)
+            yield sse_string
+            function_full_response = json.dumps(function_call["input"])
+            sse_string = await generate_sse_response(timestamp, model, content=None, tools_id=function_call_id, function_call_name=None, function_call_content=function_full_response)
+            yield sse_string
+
 async def fetch_gpt_response_stream(client, url, headers, payload, max_redirects=5):
     redirect_count = 0
     while redirect_count < max_redirects:
@@ -202,10 +251,10 @@ async def fetch_response(client, url, headers, payload):
 
 async def fetch_response_stream(client, url, headers, payload, engine, model):
     try:
-        if engine == "gemini" or engine == "vertex":
+        if engine == "gemini" or (engine == "vertex" and "gemini" in model):
            async for chunk in fetch_gemini_response_stream(client, url, headers, payload, model):
                yield chunk
-        elif engine == "claude":
+        elif engine == "claude" or (engine == "vertex" and "claude" in model):
            async for chunk in fetch_claude_response_stream(client, url, headers, payload, model):
                yield chunk
        elif engine == "gpt":
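
The new fetch_vertex_claude_response_stream scans the raw stream line by line: plain text deltas (lines containing a "text" field) are re-emitted immediately as OpenAI-style chunks, while tool_use JSON is buffered into function_full_response until the block closes, at which point the tool call's name and arguments are emitted as two further chunks. A rough consumption sketch (url, headers, and payload are assumed to come from get_vertex_claude_payload; the model name is a stand-in):

    import asyncio
    import httpx

    async def main():
        # Assumed in scope: url, headers, payload from get_vertex_claude_payload(...)
        async with httpx.AsyncClient() as client:
            async for sse_chunk in fetch_vertex_claude_response_stream(
                client, url, headers, payload, model="claude-3-5-sonnet"
            ):
                # OpenAI-format "data: {...}" strings built by generate_sse_response
                print(sse_chunk)

    asyncio.run(main())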
test/provider_test.py
CHANGED
@@ -80,7 +80,7 @@ def test_request_model(test_client, api_key, get_model):
 
     response = test_client.post("/v1/chat/completions", json=request_data, headers=headers)
     for line in response.iter_lines():
-        print(line)
+        print(line.lstrip("data: "))
     assert response.status_code == 200
 
 if __name__ == "__main__":
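
One caveat on the print tweak: str.lstrip takes a set of characters, not a prefix, so lstrip("data: ") also strips any leading run of the characters d, a, t, colon, and space from the payload itself. It happens to work for JSON chunks, which start with "{", but on Python 3.9+ removeprefix is the exact operation:

    line = 'data: {"choices": []}'
    line.lstrip("data: ")                # '{"choices": []}', the "{" stops the strip
    "data: tada".lstrip("data: ")        # '', over-strips when the payload starts with those characters
    "data: tada".removeprefix("data: ")  # 'tada', removes exactly the prefix (Python 3.9+)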
utils.py
CHANGED
@@ -80,6 +80,7 @@ async def error_handling_wrapper(generator, status_code=200):
     try:
         first_item = await generator.__anext__()
         first_item_str = first_item
+        # logger.info("first_item_str: %s", first_item_str)
         if isinstance(first_item_str, (bytes, bytearray)):
             first_item_str = first_item_str.decode("utf-8")
         if isinstance(first_item_str, str):
|