from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import openai
from typing import List, Optional, Union
import logging
from itertools import cycle
import asyncio
import uvicorn
from app import config
import requests
from datetime import datetime, timezone
import json
import httpx
import uuid
import time
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
app = FastAPI()
# Allow cross-origin requests (CORS)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# API key configuration
API_KEYS = config.settings.API_KEYS
# Round-robin iterator over the configured keys
key_cycle = cycle(API_KEYS)
# Two independent locks: one guards key rotation, the other guards failure counts
key_cycle_lock = asyncio.Lock()
failure_count_lock = asyncio.Lock()
# Per-key failure counters
key_failure_counts = {key: 0 for key in API_KEYS}
MAX_FAILURES = 10  # Failures allowed before a key is skipped
MAX_RETRIES = 3  # Retry attempts per incoming request
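# Rotation at a glance: get_next_key() hands keys out round-robin;
# get_next_working_key() skips keys with MAX_FAILURES or more recorded
# failures, and resets every counter once the whole pool has been exhausted.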
async def get_next_key():
"""仅获取下一个key,不检查失败次数"""
async with key_cycle_lock:
return next(key_cycle)
async def is_key_valid(key):
"""检查key是否有效"""
async with failure_count_lock:
return key_failure_counts[key] < MAX_FAILURES
async def reset_failure_counts():
"""重置所有key的失败计数"""
async with failure_count_lock:
for key in key_failure_counts:
key_failure_counts[key] = 0
async def get_next_working_key():
"""获取下一个可用的API key"""
initial_key = await get_next_key()
current_key = initial_key
while True:
if await is_key_valid(current_key):
return current_key
current_key = await get_next_key()
        if current_key == initial_key:  # We have cycled through the whole pool
await reset_failure_counts()
return current_key
async def handle_api_failure(api_key):
    """Record a failure for api_key and return a key to retry with."""
    async with failure_count_lock:
        key_failure_counts[api_key] += 1
        if key_failure_counts[api_key] >= MAX_FAILURES:
            logger.warning(
                f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key"
            )
    # Acquire the replacement outside the lock: get_next_working_key() re-enters
    # failure_count_lock via is_key_valid(), so holding it here would deadlock.
    return await get_next_working_key()
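# Illustrative call pattern used by the endpoints below: on any upstream
# failure, rotate to a fresh key and retry:
#   api_key = await handle_api_failure(api_key)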
class ChatRequest(BaseModel):
messages: List[dict]
model: str = "gemini-1.5-flash-002"
temperature: Optional[float] = 0.7
stream: Optional[bool] = False
tools: Optional[List[dict]] = []
tool_choice: Optional[str] = "auto"
class EmbeddingRequest(BaseModel):
input: Union[str, List[str]]
model: str = "text-embedding-004"
encoding_format: Optional[str] = "float"
async def verify_authorization(authorization: str = Header(None)):
if not authorization:
logger.error("Missing Authorization header")
raise HTTPException(status_code=401, detail="Missing Authorization header")
if not authorization.startswith("Bearer "):
logger.error("Invalid Authorization header format")
raise HTTPException(
status_code=401, detail="Invalid Authorization header format"
)
token = authorization.replace("Bearer ", "")
if token not in config.settings.ALLOWED_TOKENS:
logger.error("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")
return token
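# Expected header shape (illustrative):
#   Authorization: Bearer <token listed in config.settings.ALLOWED_TOKENS>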
def get_gemini_models(api_key):
    base_url = "https://generativelanguage.googleapis.com/v1beta"
    url = f"{base_url}/models?key={api_key}"
    try:
        response = requests.get(url)
        if response.status_code == 200:
            gemini_models = response.json()
            return convert_to_openai_models_format(gemini_models)
        # Log through the module logger rather than print()
        logger.error(f"Error fetching Gemini models: {response.status_code} {response.text}")
        return None
    except requests.RequestException as e:
        logger.error(f"Request failed: {e}")
        return None
def convert_to_openai_models_format(gemini_models):
    openai_format = {"object": "list", "data": []}
    for model in gemini_models.get("models", []):
        openai_model = {
            "id": model["name"].split("/")[-1],  # Use the last path segment as the ID
            "object": "model",
            "created": int(datetime.now(timezone.utc).timestamp()),  # Current timestamp
            "owned_by": "google",  # Assume all Gemini models are owned by Google
            "permission": [],  # Gemini has no direct equivalent of OpenAI permissions
            "root": model["name"],
            "parent": None,  # Gemini has no direct parent-model concept
        }
        openai_format["data"].append(openai_model)
    return openai_format
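# Illustrative output for a Gemini model named "models/gemini-1.5-flash-002":
#   {"object": "list", "data": [{"id": "gemini-1.5-flash-002", "object": "model",
#    "owned_by": "google", "root": "models/gemini-1.5-flash-002", ...}]}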
def convert_messages_to_gemini_format(messages):
"""Convert OpenAI message format to Gemini format"""
gemini_messages = []
    for message in messages:
        gemini_message = {
            # Gemini only defines "user" and "model" roles; assistant and
            # system messages are both mapped to "model" here
            "role": "user" if message["role"] == "user" else "model",
            "parts": [{"text": message["content"]}],
        }
gemini_messages.append(gemini_message)
return gemini_messages
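# Illustrative mapping:
#   [{"role": "user", "content": "Hi"}]  ->  [{"role": "user", "parts": [{"text": "Hi"}]}]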
def convert_gemini_response_to_openai(response, model, stream=False):
"""Convert Gemini response to OpenAI format"""
    if stream:
        # Streaming: convert a single Gemini SSE chunk into an OpenAI delta chunk
        chunk = response
        if not chunk.get("candidates"):  # .get() so a malformed chunk yields None, not a KeyError
            return None
return {
"id": "chatcmpl-" + str(uuid.uuid4()),
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": chunk["candidates"][0]["content"]["parts"][0]["text"]
},
"finish_reason": None,
}
],
}
    else:
        # Non-streaming: convert the complete Gemini response
return {
"id": "chatcmpl-" + str(uuid.uuid4()),
"object": "chat.completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response["candidates"][0]["content"]["parts"][0][
"text"
],
},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
}
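# Illustrative streaming line as emitted to clients (OpenAI SSE format):
#   data: {"id": "chatcmpl-<uuid>", "object": "chat.completion.chunk",
#          "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": null}]}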
@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models(authorization: str = Header(None)):
await verify_authorization(authorization)
api_key = await get_next_working_key()
logger.info(f"Using API key: {api_key}")
try:
response = get_gemini_models(api_key)
logger.info("Successfully retrieved models list")
return response
except Exception as e:
logger.error(f"Error listing models: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/v1/chat/completions")
@app.post("/hf/v1/chat/completions")
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
await verify_authorization(authorization)
api_key = await get_next_working_key()
logger.info(f"Chat completion request - Model: {request.model}")
retries = 0
while retries < MAX_RETRIES:
try:
logger.info(f"Attempt {retries + 1} with API key: {api_key}")
if request.model in config.settings.MODEL_SEARCH:
                # Gemini path: models listed in MODEL_SEARCH go straight to the Gemini API
gemini_messages = convert_messages_to_gemini_format(request.messages)
                # Build the Gemini request payload
payload = {
"contents": gemini_messages,
"generationConfig": {
"temperature": request.temperature,
},
"tools": [{"googleSearch": {}}],
}
if request.stream:
logger.info("Streaming response enabled")
async def generate():
nonlocal api_key, retries
while retries < MAX_RETRIES:
try:
async with httpx.AsyncClient() as client:
stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
async with client.stream("POST", stream_url, json=payload) as response:
if response.status_code == 429:
logger.warning(f"Rate limit reached for key: {api_key}")
api_key = await handle_api_failure(api_key)
logger.info(f"Retrying with new API key: {api_key}")
retries += 1
if retries >= MAX_RETRIES:
yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
break
continue
if response.status_code != 200:
logger.error(f"Error in streaming response: {response.status_code}")
yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
break
async for line in response.aiter_lines():
if line.startswith("data: "):
try:
chunk = json.loads(line[6:])
openai_chunk = convert_gemini_response_to_openai(
chunk, request.model, stream=True
)
if openai_chunk:
yield f"data: {json.dumps(openai_chunk)}\n\n"
except json.JSONDecodeError:
continue
yield "data: [DONE]\n\n"
return
except Exception as e:
logger.error(f"Stream error: {str(e)}")
api_key = await handle_api_failure(api_key)
retries += 1
if retries >= MAX_RETRIES:
yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
break
continue
return StreamingResponse(content=generate(), media_type="text/event-stream")
                else:
                    # Non-streaming Gemini response
                    async with httpx.AsyncClient() as client:
                        non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
                        response = await client.post(non_stream_url, json=payload)
                        response.raise_for_status()  # Raise on HTTP errors so the retry loop below rotates keys
                        gemini_response = response.json()
logger.info("Chat completion successful")
return convert_gemini_response_to_openai(gemini_response, request.model)
            # OpenAI-compatible path for all other models
client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
response = client.chat.completions.create(
model=request.model,
messages=request.messages,
temperature=request.temperature,
                stream=request.stream,
)
            if request.stream:
logger.info("Streaming response enabled")
async def generate():
for chunk in response:
yield f"data: {chunk.model_dump_json()}\n\n"
logger.info("Chat completion successful")
return StreamingResponse(content=generate(), media_type="text/event-stream")
logger.info("Chat completion successful")
return response
except Exception as e:
logger.error(f"Error in chat completion: {str(e)}")
api_key = await handle_api_failure(api_key)
retries += 1
if retries >= MAX_RETRIES:
logger.error("Max retries reached, giving up")
raise HTTPException(status_code=500, detail="Max retries reached with all available API keys")
logger.info(f"Retrying with new API key: {api_key}")
continue
raise HTTPException(status_code=500, detail="Unexpected error in chat completion")
@app.post("/v1/embeddings")
@app.post("/hf/v1/embeddings")
async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
await verify_authorization(authorization)
api_key = await get_next_working_key()
logger.info(f"Using API key: {api_key}")
try:
client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
response = client.embeddings.create(input=request.input, model=request.model)
logger.info("Embedding successful")
return response
except Exception as e:
logger.error(f"Error in embedding: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
@app.get("/")
async def health_check():
logger.info("Health check endpoint called")
return {"status": "healthy"}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
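# Illustrative request against a local instance (host/port taken from the
# uvicorn call above; the token must appear in config.settings.ALLOWED_TOKENS):
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer <token>" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "gemini-1.5-flash-002", "messages": [{"role": "user", "content": "Hello"}]}'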