# GeminiBalance / app/service/openai_compatiable/openai_compatiable_service.py
# Uploaded by CatPtain ("Upload 77 files"), commit 76b9762 (verified)
import datetime
import json
import re
import time
from typing import Any, AsyncGenerator, Dict, Union
from app.config.config import settings
from app.database.services import (
add_error_log,
add_request_log,
)
from app.domain.openai_models import ChatRequest, ImageGenerationRequest
from app.service.client.api_client import OpenaiApiClient
from app.service.key.key_manager import KeyManager
from app.log.logger import get_openai_compatible_logger
# Module-level logger for this service, obtained from app.log.logger.
logger = get_openai_compatible_logger()
class OpenAICompatiableService:
    """Forward chat, image, embedding and model-list calls to an OpenAI-compatible API.

    Calls go through ``OpenaiApiClient``. Every completion attempt is recorded
    with ``add_request_log``, and failed attempts are also recorded with
    ``add_error_log``.
    """

    def __init__(self, base_url: str, key_manager: KeyManager = None):
        """Create the service.

        Args:
            base_url: Base URL of the OpenAI-compatible upstream API.
            key_manager: Optional key manager asked for a replacement key when a
                streaming attempt fails. Without it, streaming is not retried.
        """
        self.key_manager = key_manager
        self.base_url = base_url
        self.api_client = OpenaiApiClient(base_url, settings.TIME_OUT)

    @staticmethod
    def _extract_status_code(error_message: str) -> int:
        """Return the HTTP status code embedded in an upstream error message.

        Looks for the text ``status code <digits>``. Falls back to 500 when the
        message carries no such code.
        """
        match = re.search(r"status code (\d+)", error_message)
        return int(match.group(1)) if match else 500

    @staticmethod
    def _mask_key(api_key: str) -> str:
        """Return a redacted form of *api_key* that is safe to write to logs."""
        key = str(api_key)
        if len(key) <= 8:
            return "***"
        return f"{key[:4]}...{key[-4:]}"

    async def get_models(self, api_key: str) -> Dict[str, Any]:
        """Return the upstream model list, fetched with *api_key*."""
        return await self.api_client.get_models(api_key)

    async def create_chat_completion(
        self,
        request: ChatRequest,
        api_key: str,
    ) -> Union[Dict[str, Any], AsyncGenerator[str, None]]:
        """Create a chat completion.

        Fields whose value is ``None`` are dropped from the payload. ``top_k`` is
        never forwarded because it is not currently supported.

        Returns:
            When ``request.stream`` is set, an async generator of SSE ``data:``
            lines. Otherwise, the upstream response.
        """
        request_dict = request.model_dump()
        # Drop fields whose value is None.
        request_dict = {k: v for k, v in request_dict.items() if v is not None}
        # top_k is not supported. Use pop() rather than del: when top_k was None
        # it has already been filtered out above, and del would raise KeyError.
        request_dict.pop("top_k", None)
        if request.stream:
            return self._handle_stream_completion(request.model, request_dict, api_key)
        return await self._handle_normal_completion(request.model, request_dict, api_key)

    async def generate_images(
        self,
        request: ImageGenerationRequest,
    ) -> Dict[str, Any]:
        """Generate images with the configured ``settings.PAID_KEY``.

        Fields whose value is ``None`` are dropped before the request is forwarded.
        """
        request_dict = request.model_dump()
        # Drop fields whose value is None.
        request_dict = {k: v for k, v in request_dict.items() if v is not None}
        api_key = settings.PAID_KEY
        return await self.api_client.generate_images(request_dict, api_key)

    async def create_embeddings(
        self,
        input_text: str,
        model: str,
        api_key: str,
    ) -> Dict[str, Any]:
        """Create embeddings for *input_text* with *model* via the upstream client."""
        return await self.api_client.create_embeddings(input_text, model, api_key)

    async def _handle_normal_completion(
        self, model: str, request: dict, api_key: str
    ) -> Dict[str, Any]:
        """Run a non-streaming completion and record request and error logs.

        Args:
            model: Model name. Used only for the log entries.
            request: Payload forwarded to the upstream.
            api_key: API key used for the call.

        Returns:
            The upstream response.

        Raises:
            Exception: Whatever the upstream client raised, re-raised after an
                error log entry has been written.
        """
        start_time = time.perf_counter()
        request_datetime = datetime.datetime.now()
        is_success = False
        status_code = None
        try:
            response = await self.api_client.generate_content(request, api_key)
            is_success = True
            status_code = 200
            return response
        except Exception as e:
            error_log_msg = str(e)
            logger.error(f"Normal API call failed with error: {error_log_msg}")
            status_code = self._extract_status_code(error_log_msg)
            await add_error_log(
                gemini_key=api_key,
                model_name=model,
                error_type="openai-compatiable-non-stream",
                error_log=error_log_msg,
                error_code=status_code,
                request_msg=request,
            )
            raise
        finally:
            # Runs on both success and failure, so every call gets a request log.
            latency_ms = int((time.perf_counter() - start_time) * 1000)
            await add_request_log(
                model_name=model,
                api_key=api_key,
                is_success=is_success,
                status_code=status_code,
                latency_ms=latency_ms,
                request_time=request_datetime,
            )

    async def _handle_stream_completion(
        self, model: str, payload: dict, api_key: str
    ) -> AsyncGenerator[str, None]:
        """Stream a chat completion as SSE lines, retrying failed attempts.

        Each attempt is recorded with ``add_request_log``, and failed attempts
        are also recorded with ``add_error_log``. After a failure, the key
        manager (if any) is asked for a replacement key. At most
        ``settings.MAX_RETRIES`` attempts are made in total. Retrying stops
        early when there is no key manager or it returns no replacement key.

        Yields:
            Upstream ``data:`` lines, each followed by a blank line. If no
            attempt succeeded, an error event follows. ``data: [DONE]`` is
            always the last item.
        """
        retries = 0
        max_retries = settings.MAX_RETRIES
        is_success = False
        status_code = None
        while retries < max_retries:
            start_time = time.perf_counter()
            request_datetime = datetime.datetime.now()
            current_attempt_key = api_key
            try:
                # NOTE(review): if the upstream fails after some chunks have
                # already been yielded, the retry streams the completion again
                # from the start, so the client can receive duplicated output.
                async for line in self.api_client.stream_generate_content(
                    payload, current_attempt_key
                ):
                    if line.startswith("data:"):
                        yield line + "\n\n"
                logger.info("Streaming completed successfully")
                is_success = True
                status_code = 200
                break
            except Exception as e:
                retries += 1
                error_log_msg = str(e)
                logger.warning(
                    f"Streaming API call failed with error: {error_log_msg}. Attempt {retries} of {max_retries}"
                )
                status_code = self._extract_status_code(error_log_msg)
                await add_error_log(
                    gemini_key=current_attempt_key,
                    model_name=model,
                    error_type="openai-compatiable-stream",
                    error_log=error_log_msg,
                    error_code=status_code,
                    request_msg=payload,
                )
                if self.key_manager:
                    api_key = await self.key_manager.handle_api_failure(
                        current_attempt_key, retries
                    )
                    if api_key:
                        # Never log the full credential.
                        logger.info(
                            f"Switched to new API key: {self._mask_key(api_key)}"
                        )
                    else:
                        logger.error(
                            f"No valid API key available after {retries} retries."
                        )
                        break
                else:
                    logger.error("KeyManager not available for retry logic.")
                    break
                if retries >= max_retries:
                    logger.error(f"Max retries ({max_retries}) reached for streaming.")
                    break
            finally:
                # One request-log entry per attempt, whether or not it succeeded.
                latency_ms = int((time.perf_counter() - start_time) * 1000)
                await add_request_log(
                    model_name=model,
                    api_key=current_attempt_key,
                    is_success=is_success,
                    status_code=status_code,
                    latency_ms=latency_ms,
                    request_time=request_datetime,
                )
        if not is_success:
            # Report failure whenever no attempt succeeded, including when
            # retrying stopped early for lack of a key. Otherwise the client
            # would receive only a bare [DONE].
            yield f"data: {json.dumps({'error': 'Streaming failed after retries'})}\n\n"
        yield "data: [DONE]\n\n"