"""Flask proxy exposing an Anthropic-style messages endpoint backed by Vertex AI.

Incoming POST requests to ``/ai/v1/messages`` are authenticated with an
``x-api-key`` header, rewritten for Vertex AI (``anthropic_version`` set,
``model`` removed from the body) and forwarded using a Google OAuth access
token obtained from a refresh token.
"""

import hmac
import os
import threading
import time
from urllib.parse import quote

import requests
from flask import Flask, Response, jsonify, request

app = Flask(__name__)

# Configuration read from environment variables.
PROJECT_ID = os.getenv('PROJECT_ID')
CLIENT_ID = os.getenv('CLIENT_ID')
CLIENT_SECRET = os.getenv('CLIENT_SECRET')
REFRESH_TOKEN = os.getenv('REFRESH_TOKEN')
API_KEY = os.getenv('API_KEY')

TOKEN_URL = 'https://www.googleapis.com/oauth2/v4/token'

# Model used when the client sends no ``model`` header.
DEFAULT_MODEL = 'claude-3-5-sonnet@20240620'
# Client-facing model names that are rewritten to their Vertex AI identifiers.
MODEL_ALIASES = {'claude-3-5-sonnet-20240620': DEFAULT_MODEL}

# A cached token is refreshed this many seconds before it actually expires.
TOKEN_REFRESH_MARGIN = 120
# Timeouts (seconds) so a stalled connection cannot block a worker forever.
TOKEN_TIMEOUT = 10
UPSTREAM_TIMEOUT = (10, 600)  # (connect, read)

# Origin policy announced by the preflight handler; real responses must carry
# it as well or browsers discard them after a successful preflight.
CORS_ALLOW_ORIGIN = {'Access-Control-Allow-Origin': '*'}

token_cache = {
    'access_token': '',
    'expiry': 0,
    # No longer used: refreshes are serialized by ``_token_lock``. The key is
    # kept so code that inspects the cache still finds the same shape.
    'refresh_promise': None,
}
_token_lock = threading.Lock()


def _fresh_cached_token():
    """Return the cached access token if it is still comfortably valid, else None."""
    token = token_cache['access_token']
    if token and time.time() < token_cache['expiry'] - TOKEN_REFRESH_MARGIN:
        return token
    return None


def get_access_token():
    """Return a Google OAuth access token, refreshing it when needed.

    The cached token is returned while it has more than
    ``TOKEN_REFRESH_MARGIN`` seconds left. Otherwise a single caller at a time
    exchanges the refresh token at ``TOKEN_URL``; concurrent callers wait on
    the lock and then reuse the result.

    Returns:
        The access token string. If a refresh fails, the failure is logged and
        the previously cached value is returned, which is the empty string
        when no token has ever been obtained. This function does not raise on
        refresh failures.
    """
    token = _fresh_cached_token()
    if token:
        return token

    with _token_lock:
        # Another request may have completed the refresh while we waited.
        token = _fresh_cached_token()
        if token:
            return token

        # Measured before the request so the computed expiry errs on the early side.
        started = time.time()
        try:
            response = requests.post(
                TOKEN_URL,
                json={
                    'client_id': CLIENT_ID,
                    'client_secret': CLIENT_SECRET,
                    'refresh_token': REFRESH_TOKEN,
                    'grant_type': 'refresh_token',
                },
                timeout=TOKEN_TIMEOUT,
            )
            response.raise_for_status()
            data = response.json()
            new_token = data['access_token']
            new_expiry = started + data['expires_in']
        except (requests.RequestException, ValueError, KeyError, TypeError):
            app.logger.exception('Failed to refresh the Google OAuth access token')
        else:
            # Only update the cache once both values are known to be usable.
            token_cache['access_token'] = new_token
            token_cache['expiry'] = new_expiry

        return token_cache['access_token']


def get_location():
    """Pick the Vertex AI region from the current local-time second.

    Returns:
        ``'europe-west1'`` during seconds 0-29 of each minute, ``'us-east5'``
        otherwise.
    """
    if time.localtime().tm_sec < 30:
        return 'europe-west1'
    return 'us-east5'


def construct_api_url(location, model=DEFAULT_MODEL):
    """Build the Vertex AI ``streamRawPredict`` URL for a region and model.

    Args:
        location: Vertex AI region, used for both the host name and the path.
        model: Vertex AI model identifier. It is percent-encoded (keeping
            ``@``) so a caller-supplied value cannot alter the URL structure.

    Returns:
        The full endpoint URL as a string.
    """
    model_id = quote(model, safe='@')
    return (
        f'https://{location}-aiplatform.googleapis.com/v1/projects/{PROJECT_ID}'
        f'/locations/{location}/publishers/anthropic/models/{model_id}:streamRawPredict'
    )


def _is_authorized(provided_key):
    """Return True only when a configured API key matches the provided one.

    Fails closed: with no ``API_KEY`` configured, or no key supplied, nothing
    is authorized. The comparison is constant-time.
    """
    if not API_KEY or provided_key is None:
        return False
    return hmac.compare_digest(provided_key.encode('utf-8'), API_KEY.encode('utf-8'))


def _error_response(status, error_type, message):
    """Build an Anthropic-style JSON error response with the given HTTP status."""
    payload = jsonify({
        'type': 'error',
        'error': {
            'type': error_type,
            'message': message
        }
    })
    return payload, status, dict(CORS_ALLOW_ORIGIN)


@app.route('/ai/v1/messages', methods=['POST', 'OPTIONS'])
def handle_request():
    """Authenticate a messages request and forward it to Vertex AI.

    ``OPTIONS`` requests get the CORS preflight response. ``POST`` requests
    must carry a valid ``x-api-key`` header and a JSON object body. The model
    is taken from the ``model`` request header (not from the body).

    Returns:
        The upstream response body, status code and content type, or a JSON
        error: 403 for a bad API key, 400 for a body that is not a JSON
        object, 500 when no access token is available, 502 when the upstream
        request fails at the network level.
    """
    if request.method == 'OPTIONS':
        return handle_options()

    # Check x-api-key.
    if not _is_authorized(request.headers.get('x-api-key')):
        return _error_response(
            403,
            'permission_error',
            'Your API key does not have permission to use the specified resource.',
        )

    request_body = request.get_json(silent=True)
    if not isinstance(request_body, dict):
        return _error_response(
            400, 'invalid_request_error', 'Request body must be a JSON object.'
        )

    access_token = get_access_token()
    if not access_token:
        return _error_response(
            500, 'api_error', 'Failed to obtain an upstream access token.'
        )

    model = request.headers.get('model', DEFAULT_MODEL)
    model = MODEL_ALIASES.get(model, model)
    api_url = construct_api_url(get_location(), model)

    # Vertex AI takes the model from the URL and needs its own version marker.
    request_body.pop('model', None)
    request_body['anthropic_version'] = "vertex-2023-10-16"

    headers = {
        'Authorization': f'Bearer {access_token}',
        'Content-Type': 'application/json; charset=utf-8'
    }

    try:
        upstream = requests.post(
            api_url, headers=headers, json=request_body, timeout=UPSTREAM_TIMEOUT
        )
    except requests.RequestException:
        app.logger.exception('Request to the upstream Vertex AI endpoint failed')
        return _error_response(
            502, 'api_error', 'Failed to reach the upstream Vertex AI endpoint.'
        )

    # NOTE: ``upstream.content`` reads the whole upstream body before replying.
    return Response(
        upstream.content,
        status=upstream.status_code,
        headers=dict(CORS_ALLOW_ORIGIN),
        content_type=upstream.headers.get('Content-Type', 'application/json'),
    )


@app.route('/', methods=['GET'])
def index():
    """Return a plain-text banner identifying the service."""
    return "Vertex Claude API Proxy", 200


def handle_options():
    """Return the empty 204 response answering a CORS preflight request."""
    headers = {
        **CORS_ALLOW_ORIGIN,
        'Access-Control-Allow-Methods': 'POST, GET, OPTIONS',
        'Access-Control-Allow-Headers': 'Content-Type, Authorization, x-api-key, anthropic-version, model'
    }
    return '', 204, headers


if __name__ == '__main__':
    app.run(port=8080)