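"""OpenAI-compatible proxy over the Gemini API.

Exposes /v1/models, /v1/chat/completions, and /v1/embeddings, rotating through
a pool of upstream API keys and retrying failed calls.
"""
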
from fastapi import FastAPI, HTTPException, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import openai
from typing import List, Optional, Union
import logging
from itertools import cycle
import asyncio

import uvicorn

from app import config
import requests
from datetime import datetime, timezone
import json
import httpx
import uuid
import time

# Logging configuration
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

app = FastAPI()

# Allow cross-origin requests
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Upstream API key pool
API_KEYS = config.settings.API_KEYS

# Round-robin iterator over the key pool
key_cycle = cycle(API_KEYS)

# Two separate locks: one guards the key cycle, the other the failure counters
key_cycle_lock = asyncio.Lock()
failure_count_lock = asyncio.Lock()

# Per-key failure counters
key_failure_counts = {key: 0 for key in API_KEYS}
MAX_FAILURES = 10  # failures before a key is treated as exhausted
MAX_RETRIES = 3  # maximum retry attempts per request


async def get_next_key():
    """仅获取下一个key,不检查失败次数"""
    async with key_cycle_lock:
        return next(key_cycle)

async def is_key_valid(key):
    """检查key是否有效"""
    async with failure_count_lock:
        return key_failure_counts[key] < MAX_FAILURES

async def reset_failure_counts():
    """重置所有key的失败计数"""
    async with failure_count_lock:
        for key in key_failure_counts:
            key_failure_counts[key] = 0

async def get_next_working_key():
    """获取下一个可用的API key"""
    initial_key = await get_next_key()
    current_key = initial_key
    
    while True:
        if await is_key_valid(current_key):
            return current_key
            
        current_key = await get_next_key()
        if current_key == initial_key:  # we have cycled through every key
            await reset_failure_counts()
            return current_key

async def handle_api_failure(api_key):
    """处理API调用失败"""
    async with failure_count_lock:
        key_failure_counts[api_key] += 1
        if key_failure_counts[api_key] >= MAX_FAILURES:
            logger.warning(f"API key {api_key} has failed {MAX_FAILURES} times, switching to next key")
    
    # Fetch the replacement key outside the lock to avoid holding both locks
    return await get_next_working_key()
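
# How the rotation helpers compose in the endpoints below (illustrative only):
#
#     api_key = await get_next_working_key()            # pick a healthy key
#     try:
#         ...                                           # call upstream with it
#     except Exception:
#         api_key = await handle_api_failure(api_key)   # record failure, rotate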


class ChatRequest(BaseModel):
    messages: List[dict]
    model: str = "gemini-1.5-flash-002"
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False
    tools: Optional[List[dict]] = []
    tool_choice: Optional[str] = "auto"


class EmbeddingRequest(BaseModel):
    input: Union[str, List[str]]
    model: str = "text-embedding-004"
    encoding_format: Optional[str] = "float"


async def verify_authorization(authorization: str = Header(None)):
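    """Validate the Bearer token in the Authorization header.

    The token must be one of config.settings.ALLOWED_TOKENS.
    """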
    if not authorization:
        logger.error("Missing Authorization header")
        raise HTTPException(status_code=401, detail="Missing Authorization header")
    if not authorization.startswith("Bearer "):
        logger.error("Invalid Authorization header format")
        raise HTTPException(
            status_code=401, detail="Invalid Authorization header format"
        )
    token = authorization.removeprefix("Bearer ")  # strip only the leading scheme prefix
    if token not in config.settings.ALLOWED_TOKENS:
        logger.error("Invalid token")
        raise HTTPException(status_code=401, detail="Invalid token")
    return token


def get_gemini_models(api_key):
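    """Fetch the Gemini model list and return it in OpenAI /v1/models format."""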
    base_url = "https://generativelanguage.googleapis.com/v1beta"
    url = f"{base_url}/models?key={api_key}"

    try:
        response = requests.get(url, timeout=30)  # avoid hanging on an unresponsive upstream
        if response.status_code == 200:
            gemini_models = response.json()
            return convert_to_openai_models_format(gemini_models)
        else:
            logger.error(f"Error fetching models: {response.status_code} {response.text}")
            return None

    except requests.RequestException as e:
        logger.error(f"Request failed: {e}")
        return None


def convert_to_openai_models_format(gemini_models):
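    """Map Gemini's models payload onto the OpenAI model-list schema."""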
    openai_format = {"object": "list", "data": []}

    for model in gemini_models.get("models", []):
        openai_model = {
            "id": model["name"].split("/")[-1],  # last path segment as the model ID
            "object": "model",
            "created": int(datetime.now(timezone.utc).timestamp()),  # current timestamp
            "owned_by": "google",  # all Gemini models are owned by Google
            "permission": [],  # Gemini has no direct equivalent of permission info
            "root": model["name"],
            "parent": None,  # Gemini has no parent-model concept
        }
        openai_format["data"].append(openai_model)

    return openai_format


def convert_messages_to_gemini_format(messages):
    """Convert OpenAI message format to Gemini format"""
    gemini_messages = []
    for message in messages:
        gemini_message = {
            "role": "user" if message["role"] == "user" else "model",
            "parts": [{"text": message["content"]}],
        }
        gemini_messages.append(gemini_message)
    return gemini_messages
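
# Example of the mapping performed above:
#   OpenAI: [{"role": "user", "content": "hi"}]
#   Gemini: [{"role": "user", "parts": [{"text": "hi"}]}]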


def convert_gemini_response_to_openai(response, model, stream=False):
    """Convert Gemini response to OpenAI format"""
    if stream:
        # Streaming: convert one Gemini chunk into a chat.completion.chunk
        chunk = response
        if not chunk["candidates"]:
            return None

        return {
            "id": "chatcmpl-" + str(uuid.uuid4()),
            "object": "chat.completion.chunk",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "content": chunk["candidates"][0]["content"]["parts"][0]["text"]
                    },
                    "finish_reason": None,
                }
            ],
        }
    else:
        # Non-streaming: convert the full Gemini response
        return {
            "id": "chatcmpl-" + str(uuid.uuid4()),
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response["candidates"][0]["content"]["parts"][0][
                            "text"
                        ],
                    },
                    "finish_reason": "stop",
                }
            ],
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }


@app.get("/v1/models")
@app.get("/hf/v1/models")
async def list_models(authorization: str = Header(None)):
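    """List available models, backed by the Gemini models endpoint."""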
    await verify_authorization(authorization)
    api_key = await get_next_working_key()
    logger.info(f"Using API key: {api_key}")
    try:
        response = get_gemini_models(api_key)
        logger.info("Successfully retrieved models list")
        return response
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/v1/chat/completions")
@app.post("/hf/v1/chat/completions")
async def chat_completion(request: ChatRequest, authorization: str = Header(None)):
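    """OpenAI-compatible chat completions.

    Models listed in config.settings.MODEL_SEARCH are served by Gemini with the
    googleSearch tool enabled; all other models are forwarded to the
    OpenAI-compatible upstream at config.settings.BASE_URL.
    """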
    await verify_authorization(authorization)
    api_key = await get_next_working_key()
    logger.info(f"Chat completion request - Model: {request.model}")
    retries = 0
    
    while retries < MAX_RETRIES:
        try:
            logger.info(f"Attempt {retries + 1} with API key: {api_key}")
            
            if request.model in config.settings.MODEL_SEARCH:
                # Gemini API path (models with Google Search grounding)
                gemini_messages = convert_messages_to_gemini_format(request.messages)
                # Build the request payload for Gemini
                payload = {
                    "contents": gemini_messages,
                    "generationConfig": {
                        "temperature": request.temperature,
                    },
                    "tools": [{"googleSearch": {}}],
                }
                
                if request.stream:
                    logger.info("Streaming response enabled")

                    async def generate():
                        nonlocal api_key, retries
                        while retries < MAX_RETRIES:
                            try:
                                async with httpx.AsyncClient() as client:
                                    stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:streamGenerateContent?alt=sse&key={api_key}"
                                    async with client.stream("POST", stream_url, json=payload) as response:
                                        if response.status_code == 429:
                                            logger.warning(f"Rate limit reached for key: {api_key}")
                                            api_key = await handle_api_failure(api_key)
                                            logger.info(f"Retrying with new API key: {api_key}")
                                            retries += 1
                                            if retries >= MAX_RETRIES:
                                                yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
                                                break
                                            continue
                                            
                                        if response.status_code != 200:
                                            logger.error(f"Error in streaming response: {response.status_code}")
                                            yield f"data: {json.dumps({'error': f'API error: {response.status_code}'})}\n\n"
                                            break
                                            
                                        async for line in response.aiter_lines():
                                            if line.startswith("data: "):
                                                try:
                                                    chunk = json.loads(line[6:])
                                                    openai_chunk = convert_gemini_response_to_openai(
                                                        chunk, request.model, stream=True
                                                    )
                                                    if openai_chunk:
                                                        yield f"data: {json.dumps(openai_chunk)}\n\n"
                                                except json.JSONDecodeError:
                                                    continue
                                        yield "data: [DONE]\n\n"
                                        return
                            except Exception as e:
                                logger.error(f"Stream error: {str(e)}")
                                api_key = await handle_api_failure(api_key)
                                retries += 1
                                if retries >= MAX_RETRIES:
                                    yield f"data: {json.dumps({'error': 'Max retries reached'})}\n\n"
                                    break
                                continue
                                
                    return StreamingResponse(content=generate(), media_type="text/event-stream")
                else:
                    # Non-streaming response
                    async with httpx.AsyncClient() as client:
                        non_stream_url = f"https://generativelanguage.googleapis.com/v1beta/models/{request.model}:generateContent?key={api_key}"
                        response = await client.post(non_stream_url, json=payload)
                        response.raise_for_status()  # surface upstream errors to the retry loop
                        gemini_response = response.json()
                        logger.info("Chat completion successful")
                        return convert_gemini_response_to_openai(gemini_response, request.model)
            
            # OpenAI-compatible upstream path
            client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
            response = client.chat.completions.create(
                model=request.model,
                messages=request.messages,
                temperature=request.temperature,
                stream=request.stream,  # ChatRequest always defines stream
            )

            if request.stream:
                logger.info("Streaming response enabled")

                async def generate():
                    for chunk in response:
                        yield f"data: {chunk.model_dump_json()}\n\n"
                logger.info("Chat completion successful")
                return StreamingResponse(content=generate(), media_type="text/event-stream")
            
            logger.info("Chat completion successful")
            return response

        except Exception as e:
            logger.error(f"Error in chat completion: {str(e)}")
            api_key = await handle_api_failure(api_key)
            retries += 1
            
            if retries >= MAX_RETRIES:
                logger.error("Max retries reached, giving up")
                raise HTTPException(status_code=500, detail="Max retries reached with all available API keys")
            
            logger.info(f"Retrying with new API key: {api_key}")
            continue

    raise HTTPException(status_code=500, detail="Unexpected error in chat completion")


@app.post("/v1/embeddings")
@app.post("/hf/v1/embeddings")
async def embedding(request: EmbeddingRequest, authorization: str = Header(None)):
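    """Forward embedding requests to the OpenAI-compatible upstream."""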
    await verify_authorization(authorization)
    api_key = await get_next_working_key()
    logger.info(f"Using API key: {api_key}")

    try:
        client = openai.OpenAI(api_key=api_key, base_url=config.settings.BASE_URL)
        response = client.embeddings.create(input=request.input, model=request.model)
        logger.info("Embedding successful")
        return response
    except Exception as e:
        logger.error(f"Error in embedding: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
@app.get("/")
async def health_check():
    logger.info("Health check endpoint called")
    return {"status": "healthy"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
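

# Quick smoke test once the server is running ("<token>" is a placeholder for a
# value in ALLOWED_TOKENS; adjust the model name to your configuration):
#
#   import httpx
#   r = httpx.post(
#       "http://localhost:8000/v1/chat/completions",
#       headers={"Authorization": "Bearer <token>"},
#       json={"model": "gemini-1.5-flash-002",
#             "messages": [{"role": "user", "content": "hi"}]},
#   )
#   print(r.json()["choices"][0]["message"]["content"])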