Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,12 +1,10 @@
|
|
1 |
-
from flask import Flask, request, Response, jsonify
|
2 |
import os
|
3 |
-
import requests
|
4 |
import time
|
5 |
-
import
|
|
|
|
|
|
|
6 |
|
7 |
-
app = Flask(__name__)
|
8 |
-
|
9 |
-
# 环境变量
|
10 |
PROJECT_ID = os.getenv('PROJECT_ID')
|
11 |
CLIENT_ID = os.getenv('CLIENT_ID')
|
12 |
CLIENT_SECRET = os.getenv('CLIENT_SECRET')
|
@@ -21,36 +19,31 @@ token_cache = {
|
|
21 |
'refresh_promise': None
|
22 |
}
|
23 |
|
24 |
-
def get_access_token():
|
25 |
now = time.time()
|
26 |
|
27 |
-
# 如果 token 仍然有效,直接返回
|
28 |
if token_cache['access_token'] and now < token_cache['expiry'] - 120:
|
29 |
return token_cache['access_token']
|
30 |
|
31 |
-
# 如果已经有一个刷新操作在进行中,等待它完成
|
32 |
if token_cache['refresh_promise']:
|
33 |
-
token_cache['refresh_promise']
|
34 |
return token_cache['access_token']
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
response = requests.post(TOKEN_URL, json={
|
40 |
'client_id': CLIENT_ID,
|
41 |
'client_secret': CLIENT_SECRET,
|
42 |
'refresh_token': REFRESH_TOKEN,
|
43 |
'grant_type': 'refresh_token'
|
44 |
-
})
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
token_cache['refresh_promise'] =
|
52 |
-
token_cache['refresh_promise'].start()
|
53 |
-
token_cache['refresh_promise'].join()
|
54 |
return token_cache['access_token']
|
55 |
|
56 |
def get_location():
|
@@ -60,48 +53,62 @@ def get_location():
|
|
60 |
def construct_api_url(location):
|
61 |
return f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{location}/publishers/anthropic/models/{MODEL}:streamRawPredict'
|
62 |
|
63 |
-
|
64 |
-
def handle_request():
|
65 |
if request.method == 'OPTIONS':
|
66 |
return handle_options()
|
67 |
|
68 |
-
# 检查 x-api-key
|
69 |
api_key = request.headers.get('x-api-key')
|
70 |
if api_key != API_KEY:
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
'
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
location = get_location()
|
81 |
-
model = request.headers.get('model', 'claude-3-5-sonnet@20240620')
|
82 |
-
if model == 'claude-3-5-sonnet-20240620':
|
83 |
-
model = 'claude-3-5-sonnet@20240620'
|
84 |
-
|
85 |
api_url = construct_api_url(location)
|
86 |
|
87 |
-
request_body = request.json
|
|
|
88 |
if 'anthropic_version' in request_body:
|
89 |
del request_body['anthropic_version']
|
90 |
if 'model' in request_body:
|
91 |
del request_body['model']
|
92 |
-
|
|
|
93 |
|
94 |
headers = {
|
95 |
'Authorization': f'Bearer {access_token}',
|
96 |
'Content-Type': 'application/json; charset=utf-8'
|
97 |
}
|
98 |
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
-
|
103 |
-
def index():
|
104 |
-
return "Vertex Claude API Proxy", 200
|
105 |
|
106 |
def handle_options():
|
107 |
headers = {
|
@@ -109,7 +116,10 @@ def handle_options():
|
|
109 |
'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
|
110 |
'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model'
|
111 |
}
|
112 |
-
return
|
|
|
|
|
|
|
113 |
|
114 |
if __name__ == '__main__':
|
115 |
-
|
|
|
|
|
1 |
import os
|
|
|
2 |
import time
|
3 |
+
import json
|
4 |
+
import asyncio
|
5 |
+
import aiohttp
|
6 |
+
from aiohttp import web
|
7 |
|
|
|
|
|
|
|
8 |
PROJECT_ID = os.getenv('PROJECT_ID')
|
9 |
CLIENT_ID = os.getenv('CLIENT_ID')
|
10 |
CLIENT_SECRET = os.getenv('CLIENT_SECRET')
|
|
|
19 |
'refresh_promise': None
|
20 |
}
|
21 |
|
22 |
+
async def get_access_token():
|
23 |
now = time.time()
|
24 |
|
|
|
25 |
if token_cache['access_token'] and now < token_cache['expiry'] - 120:
|
26 |
return token_cache['access_token']
|
27 |
|
|
|
28 |
if token_cache['refresh_promise']:
|
29 |
+
await token_cache['refresh_promise']
|
30 |
return token_cache['access_token']
|
31 |
|
32 |
+
async def refresh_token():
|
33 |
+
async with aiohttp.ClientSession() as session:
|
34 |
+
async with session.post(TOKEN_URL, json={
|
|
|
35 |
'client_id': CLIENT_ID,
|
36 |
'client_secret': CLIENT_SECRET,
|
37 |
'refresh_token': REFRESH_TOKEN,
|
38 |
'grant_type': 'refresh_token'
|
39 |
+
}) as response:
|
40 |
+
data = await response.json()
|
41 |
+
token_cache['access_token'] = data['access_token']
|
42 |
+
token_cache['expiry'] = now + data['expires_in']
|
43 |
+
|
44 |
+
token_cache['refresh_promise'] = refresh_token()
|
45 |
+
await token_cache['refresh_promise']
|
46 |
+
token_cache['refresh_promise'] = None
|
|
|
|
|
47 |
return token_cache['access_token']
|
48 |
|
49 |
def get_location():
|
|
|
53 |
def construct_api_url(location):
|
54 |
return f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}/locations/{location}/publishers/anthropic/models/{MODEL}:streamRawPredict'
|
55 |
|
56 |
+
async def handle_request(request):
|
|
|
57 |
if request.method == 'OPTIONS':
|
58 |
return handle_options()
|
59 |
|
|
|
60 |
api_key = request.headers.get('x-api-key')
|
61 |
if api_key != API_KEY:
|
62 |
+
error_response = web.Response(
|
63 |
+
text=json.dumps({
|
64 |
+
'type': 'error',
|
65 |
+
'error': {
|
66 |
+
'type': 'permission_error',
|
67 |
+
'message': 'Your API key does not have permission to use the specified resource.'
|
68 |
+
}
|
69 |
+
}),
|
70 |
+
status=403,
|
71 |
+
content_type='application/json'
|
72 |
+
)
|
73 |
+
error_response.headers['Access-Control-Allow-Origin'] = '*'
|
74 |
+
error_response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS, DELETE, HEAD'
|
75 |
+
error_response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
|
76 |
+
return error_response
|
77 |
+
|
78 |
+
access_token = await get_access_token()
|
79 |
location = get_location()
|
|
|
|
|
|
|
|
|
80 |
api_url = construct_api_url(location)
|
81 |
|
82 |
+
request_body = await request.json()
|
83 |
+
|
84 |
if 'anthropic_version' in request_body:
|
85 |
del request_body['anthropic_version']
|
86 |
if 'model' in request_body:
|
87 |
del request_body['model']
|
88 |
+
|
89 |
+
request_body['anthropic_version'] = 'vertex-2023-10-16'
|
90 |
|
91 |
headers = {
|
92 |
'Authorization': f'Bearer {access_token}',
|
93 |
'Content-Type': 'application/json; charset=utf-8'
|
94 |
}
|
95 |
|
96 |
+
async with aiohttp.ClientSession() as session:
|
97 |
+
async with session.post(api_url, json=request_body, headers=headers) as response:
|
98 |
+
response_body = await response.read()
|
99 |
+
response_headers = response.headers
|
100 |
+
response_status = response.status
|
101 |
+
|
102 |
+
modified_response = web.Response(
|
103 |
+
body=response_body,
|
104 |
+
status=response_status,
|
105 |
+
headers=response_headers
|
106 |
+
)
|
107 |
+
modified_response.headers['Access-Control-Allow-Origin'] = '*'
|
108 |
+
modified_response.headers['Access-Control-Allow-Methods'] = 'POST, GET, OPTIONS'
|
109 |
+
modified_response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, x-api-key, anthropic-version, model'
|
110 |
|
111 |
+
return modified_response
|
|
|
|
|
112 |
|
113 |
def handle_options():
|
114 |
headers = {
|
|
|
116 |
'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
|
117 |
'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model'
|
118 |
}
|
119 |
+
return web.Response(status=204, headers=headers)
|
120 |
+
|
121 |
+
app = web.Application()
|
122 |
+
app.router.add_route('*', '/', handle_request)
|
123 |
|
124 |
if __name__ == '__main__':
|
125 |
+
web.run_app(app, port=8080)
|