# Snapshot from a Spaces page (status: Running); commit 3fdf5e9, file size 3,732 bytes.
import hmac
import os
import threading
import time

import requests
from flask import Flask, request, Response, jsonify
app = Flask(__name__)

# Deployment configuration, read once from the environment at import time.
# Each name is None when the corresponding variable is unset.
#   PROJECT_ID                - GCP project used in the Vertex AI endpoint URL
#   CLIENT_ID / CLIENT_SECRET - OAuth client credentials for the token refresh
#   REFRESH_TOKEN             - OAuth refresh token exchanged for access tokens
#   API_KEY                   - shared secret clients must send as x-api-key
PROJECT_ID, CLIENT_ID, CLIENT_SECRET, REFRESH_TOKEN, API_KEY = map(
    os.getenv,
    ('PROJECT_ID', 'CLIENT_ID', 'CLIENT_SECRET', 'REFRESH_TOKEN', 'API_KEY'),
)
TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'

# Process-wide cache for the Google OAuth access token:
#   access_token    - last bearer token obtained ('' until the first refresh)
#   expiry          - epoch seconds (time.time() scale) when that token expires
#   refresh_promise - slot for an in-flight refresh thread (None when idle)
token_cache = dict(access_token='', expiry=0, refresh_promise=None)
# Serializes token refreshes so concurrent requests trigger at most one call
# to the OAuth endpoint; the others wait and then reuse the fresh token.
_token_lock = threading.Lock()

# Refresh this many seconds before the reported expiry so a token does not
# lapse while a proxied request is still in flight.
_TOKEN_EXPIRY_MARGIN = 120

# Timeout in seconds for the OAuth token request.
_TOKEN_REQUEST_TIMEOUT = 30


def _token_is_fresh(now):
    """Return True if a cached access token exists and is not close to expiry."""
    return bool(token_cache['access_token']) and now < token_cache['expiry'] - _TOKEN_EXPIRY_MARGIN


def get_access_token():
    """Return a Google OAuth access token, refreshing it when needed.

    The token is served from ``token_cache`` while it is more than
    ``_TOKEN_EXPIRY_MARGIN`` seconds from expiry. Otherwise the refresh token
    is exchanged at ``TOKEN_URL`` and the cache is updated. A lock ensures only
    one refresh runs at a time.

    Returns:
        The current access token string.

    Raises:
        requests.RequestException: if the token endpoint cannot be reached,
            times out, or (via ``raise_for_status``) returns an HTTP error.
        KeyError: if the token response lacks ``access_token``/``expires_in``.
    """
    # Fast path without taking the lock.
    if _token_is_fresh(time.time()):
        return token_cache['access_token']

    with _token_lock:
        # Re-check: another thread may have refreshed while we were waiting.
        now = time.time()
        if _token_is_fresh(now):
            return token_cache['access_token']

        response = requests.post(TOKEN_URL, json={
            'client_id': CLIENT_ID,
            'client_secret': CLIENT_SECRET,
            'refresh_token': REFRESH_TOKEN,
            'grant_type': 'refresh_token',
        }, timeout=_TOKEN_REQUEST_TIMEOUT)
        # Fail loudly instead of silently handing out an empty or stale token.
        response.raise_for_status()
        data = response.json()
        token_cache['access_token'] = data['access_token']
        # `now` was taken before the request, so the stored expiry errs early.
        token_cache['expiry'] = now + data['expires_in']
        return token_cache['access_token']
def get_location():
    """Choose the Vertex AI region for this request.

    Alternates between two regions by local wall-clock time: seconds 0-29 of
    each minute use 'europe-west1', seconds 30-59 use 'us-east5'.
    """
    in_first_half_of_minute = time.localtime().tm_sec < 30
    if in_first_half_of_minute:
        return 'europe-west1'
    return 'us-east5'
# Vertex AI model ID used when the client sends no `model` header.
DEFAULT_MODEL = 'claude-3-5-sonnet@20240620'

# Anthropic-style model names a client may send, mapped to the Vertex AI IDs
# (which separate the version with '@').
_MODEL_ALIASES = {
    'claude-3-5-sonnet-20240620': 'claude-3-5-sonnet@20240620',
}

# (connect, read) timeouts in seconds for the Vertex AI call. The read timeout
# bounds the wait between received bytes, so it must allow for slow
# non-streaming generations.
_UPSTREAM_TIMEOUT = (10, 600)

# Added to every response from handle_request: browsers need this header on
# the actual response, not only on the OPTIONS preflight.
_CORS_HEADERS = {'Access-Control-Allow-Origin': '*'}


def construct_api_url(location, model=DEFAULT_MODEL):
    """Return the Vertex AI ``streamRawPredict`` endpoint for an Anthropic model.

    Args:
        location: Vertex AI region; used in both the hostname and the path.
        model: Vertex AI model ID. Defaults to ``DEFAULT_MODEL`` so that
            one-argument callers keep working.
    """
    return (
        f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}'
        f'/locations/{location}/publishers/anthropic/models/{model}:streamRawPredict'
    )


def _error_response(status, error_type, message):
    """Build an Anthropic-style JSON error as a Flask (body, status, headers) tuple."""
    body = jsonify({
        'type': 'error',
        'error': {
            'type': error_type,
            'message': message,
        },
    })
    return body, status, _CORS_HEADERS


def _api_key_is_valid(api_key):
    """Return True only if ``API_KEY`` is configured and *api_key* equals it."""
    if not API_KEY or not api_key:
        # Fail closed: with API_KEY unset, a plain `!=` check would admit
        # requests that omit the x-api-key header (None == None).
        return False
    # Constant-time comparison so response timing does not leak the key.
    return hmac.compare_digest(api_key.encode('utf-8'), API_KEY.encode('utf-8'))


@app.route('/ai/v1/messages', methods=['POST', 'OPTIONS'])
def handle_request():
    """Proxy an Anthropic Messages API request to Claude on Vertex AI.

    OPTIONS requests are answered by ``handle_options``. For POST:

    - the ``x-api-key`` header must match ``API_KEY``;
    - the target model comes from the ``model`` header (default
      ``DEFAULT_MODEL``, with aliases from ``_MODEL_ALIASES``);
    - the body's ``model`` field is dropped and ``anthropic_version`` is set
      to ``'vertex-2023-10-16'``;
    - the request is forwarded with a Google OAuth bearer token.

    The upstream body is streamed back unchanged, with the upstream status
    code and Content-Type.

    Errors use the Anthropic error shape:

    - 403 for a missing or wrong API key;
    - 400 when the body is not a JSON object;
    - 502 when the token refresh or the Vertex AI call fails at the
      HTTP/transport level.
    """
    if request.method == 'OPTIONS':
        return handle_options()

    if not _api_key_is_valid(request.headers.get('x-api-key')):
        return _error_response(
            403,
            'permission_error',
            'Your API key does not have permission to use the specified resource.',
        )

    # silent=True: a missing or malformed JSON body yields None instead of raising.
    request_body = request.get_json(silent=True)
    if not isinstance(request_body, dict):
        return _error_response(400, 'invalid_request_error', 'Request body must be a JSON object.')

    requested_model = request.headers.get('model', DEFAULT_MODEL)
    model = _MODEL_ALIASES.get(requested_model, requested_model)

    # The model is carried in the URL, and the Vertex anthropic_version value
    # replaces whatever the client sent.
    request_body.pop('model', None)
    request_body['anthropic_version'] = 'vertex-2023-10-16'

    try:
        access_token = get_access_token()
        upstream = requests.post(
            construct_api_url(get_location(), model),
            headers={
                'Authorization': f'Bearer {access_token}',
                'Content-Type': 'application/json; charset=utf-8',
            },
            json=request_body,
            stream=True,  # relay streamed (SSE) chunks as they arrive instead of buffering
            timeout=_UPSTREAM_TIMEOUT,
        )
    except requests.RequestException:
        app.logger.exception('Upstream request to Vertex AI failed')
        return _error_response(502, 'api_error', 'Upstream request to Vertex AI failed.')

    proxied = Response(
        upstream.iter_content(chunk_size=None),
        status=upstream.status_code,
        content_type=upstream.headers.get('Content-Type', 'application/json'),
        headers=_CORS_HEADERS,
    )
    # Release the upstream connection even if the client disconnects early.
    proxied.call_on_close(upstream.close)
    return proxied
@app.route('/', methods=['GET'])
def index():
    """Landing route: a plain-text service name with HTTP 200."""
    banner = "Vertex Claude API Proxy"
    status_code = 200
    return banner, status_code
def handle_options():
    """Reply to a CORS preflight request.

    Returns a Flask ``(body, status, headers)`` tuple: an empty 204 response
    allowing any origin to use POST/GET/OPTIONS with the listed request headers.
    """
    preflight_headers = dict([
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Methods', 'POST, GET, OPTIONS'),
        ('Access-Control-Allow-Headers', 'Content-Type, Authorization, x-api-key, anthropic-version, model'),
    ])
    empty_body, no_content = '', 204
    return empty_body, no_content, preflight_headers
if __name__ == '__main__':
    # Flask development server on Flask's default host. The port can be set
    # with the PORT environment variable and falls back to 8080 as before.
    app.run(port=int(os.getenv('PORT', '8080')))