✨ Feature: Support prefix-matched API keys: an API key is treated as valid as long as it starts with an API key that exists in the configuration file.
Browse files
main.py
CHANGED
@@ -418,14 +418,20 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
418 |
)
|
419 |
else:
|
420 |
token = None
|
|
|
|
|
421 |
if token:
|
422 |
try:
|
423 |
api_list = app.state.api_list
|
424 |
api_index = api_list.index(token)
|
425 |
-
enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
|
426 |
except ValueError:
|
|
|
|
|
427 |
# token不在api_list中,使用默认值(不开启)
|
428 |
pass
|
|
|
|
|
|
|
429 |
else:
|
430 |
# 如果token为None,检查全局设置
|
431 |
enable_moderation = config.get('ENABLE_MODERATION', False)
|
@@ -473,7 +479,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
473 |
|
474 |
|
475 |
if enable_moderation and moderated_content:
|
476 |
-
moderation_response = await self.moderate_content(moderated_content,
|
477 |
is_flagged = moderation_response.get('results', [{}])[0].get('flagged', False)
|
478 |
|
479 |
if is_flagged:
|
@@ -518,11 +524,11 @@ class StatsMiddleware(BaseHTTPMiddleware):
|
|
518 |
# print("current_request_info", current_request_info)
|
519 |
request_info.reset(current_request_info)
|
520 |
|
521 |
-
async def moderate_content(self, content,
|
522 |
moderation_request = ModerationRequest(input=content)
|
523 |
|
524 |
# 直接调用 moderations 函数
|
525 |
-
response = await moderations(moderation_request,
|
526 |
|
527 |
# 读取流式响应的内容
|
528 |
moderation_result = b""
|
@@ -640,7 +646,7 @@ async def ensure_config(request: Request, call_next):
|
|
640 |
return await call_next(request)
|
641 |
|
642 |
# 在 process_request 函数中更新成功和失败计数
|
643 |
-
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None
|
644 |
url = provider['base_url']
|
645 |
parsed_url = urlparse(url)
|
646 |
# print("parsed_url", parsed_url)
|
@@ -745,17 +751,14 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
|
|
745 |
# response = JSONResponse(first_element)
|
746 |
|
747 |
# 更新成功计数和首次响应时间
|
748 |
-
await update_channel_stats(current_info["request_id"], provider['provider'], request.model,
|
749 |
-
# await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
|
750 |
current_info["first_response_time"] = first_response_time
|
751 |
current_info["success"] = True
|
752 |
current_info["provider"] = provider['provider']
|
753 |
return response
|
754 |
|
755 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
|
756 |
-
await update_channel_stats(current_info["request_id"], provider['provider'], request.model,
|
757 |
-
# await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
|
758 |
-
|
759 |
raise e
|
760 |
|
761 |
def weighted_round_robin(weights):
|
@@ -950,11 +953,8 @@ class ModelRequestHandler:
|
|
950 |
self.last_provider_indices = defaultdict(lambda: -1)
|
951 |
self.locks = defaultdict(asyncio.Lock)
|
952 |
|
953 |
-
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest],
|
954 |
config = app.state.config
|
955 |
-
api_list = app.state.api_list
|
956 |
-
api_index = api_list.index(token)
|
957 |
-
|
958 |
request_model = request.model
|
959 |
if not safe_get(config, 'api_keys', api_index, 'model'):
|
960 |
raise HTTPException(status_code=404, detail=f"No matching model found: {request_model}")
|
@@ -988,7 +988,7 @@ class ModelRequestHandler:
|
|
988 |
index += 1
|
989 |
provider = matching_providers[current_index]
|
990 |
try:
|
991 |
-
response = await process_request(request, provider, endpoint
|
992 |
return response
|
993 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
|
994 |
|
@@ -1058,9 +1058,12 @@ async def rate_limit_dependency(request: Request, credentials: HTTPAuthorization
|
|
1058 |
try:
|
1059 |
api_index = api_list.index(token)
|
1060 |
except ValueError:
|
1061 |
-
|
1062 |
-
api_index = None
|
1063 |
-
|
|
|
|
|
|
|
1064 |
|
1065 |
# 使用 IP 地址和 token(如果有)作为限制键
|
1066 |
client_ip = request.client.host
|
@@ -1073,32 +1076,44 @@ async def rate_limit_dependency(request: Request, credentials: HTTPAuthorization
|
|
1073 |
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
1074 |
api_list = app.state.api_list
|
1075 |
token = credentials.credentials
|
1076 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1077 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
1078 |
-
return
|
1079 |
|
1080 |
def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
1081 |
api_list = app.state.api_list
|
1082 |
token = credentials.credentials
|
1083 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1084 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
1085 |
-
for api_key in app.state.api_keys_db:
|
1086 |
-
|
1087 |
-
|
1088 |
-
|
1089 |
return token
|
1090 |
|
1091 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
1092 |
-
async def request_model(request: RequestModel,
|
1093 |
-
return await model_handler.request_model(request,
|
1094 |
|
1095 |
@app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
1096 |
async def options_handler():
|
1097 |
return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
|
1098 |
|
1099 |
@app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
|
1100 |
-
async def list_models(
|
1101 |
-
models = post_all_models(
|
1102 |
return JSONResponse(content={
|
1103 |
"object": "list",
|
1104 |
"data": models
|
@@ -1107,23 +1122,23 @@ async def list_models(token: str = Depends(verify_api_key)):
|
|
1107 |
@app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
|
1108 |
async def images_generations(
|
1109 |
request: ImageGenerationRequest,
|
1110 |
-
|
1111 |
):
|
1112 |
-
return await model_handler.request_model(request,
|
1113 |
|
1114 |
@app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
|
1115 |
async def embeddings(
|
1116 |
request: EmbeddingRequest,
|
1117 |
-
|
1118 |
):
|
1119 |
-
return await model_handler.request_model(request,
|
1120 |
|
1121 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
1122 |
async def moderations(
|
1123 |
request: ModerationRequest,
|
1124 |
-
|
1125 |
):
|
1126 |
-
return await model_handler.request_model(request,
|
1127 |
|
1128 |
from fastapi import UploadFile, File, Form, HTTPException
|
1129 |
import io
|
@@ -1131,7 +1146,7 @@ import io
|
|
1131 |
async def audio_transcriptions(
|
1132 |
file: UploadFile = File(...),
|
1133 |
model: str = Form(...),
|
1134 |
-
|
1135 |
):
|
1136 |
try:
|
1137 |
# 读取上传的文件内容
|
@@ -1144,7 +1159,7 @@ async def audio_transcriptions(
|
|
1144 |
model=model
|
1145 |
)
|
1146 |
|
1147 |
-
return await model_handler.request_model(request,
|
1148 |
except UnicodeDecodeError:
|
1149 |
raise HTTPException(status_code=400, detail="Invalid audio file encoding")
|
1150 |
except Exception as e:
|
|
|
418 |
)
|
419 |
else:
|
420 |
token = None
|
421 |
+
|
422 |
+
api_index = None
|
423 |
if token:
|
424 |
try:
|
425 |
api_list = app.state.api_list
|
426 |
api_index = api_list.index(token)
|
|
|
427 |
except ValueError:
|
428 |
+
# 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
|
429 |
+
api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
|
430 |
# token不在api_list中,使用默认值(不开启)
|
431 |
pass
|
432 |
+
|
433 |
+
if api_index is not None:
|
434 |
+
enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
|
435 |
else:
|
436 |
# 如果token为None,检查全局设置
|
437 |
enable_moderation = config.get('ENABLE_MODERATION', False)
|
|
|
479 |
|
480 |
|
481 |
if enable_moderation and moderated_content:
|
482 |
+
moderation_response = await self.moderate_content(moderated_content, api_index)
|
483 |
is_flagged = moderation_response.get('results', [{}])[0].get('flagged', False)
|
484 |
|
485 |
if is_flagged:
|
|
|
524 |
# print("current_request_info", current_request_info)
|
525 |
request_info.reset(current_request_info)
|
526 |
|
527 |
+
async def moderate_content(self, content, api_index):
|
528 |
moderation_request = ModerationRequest(input=content)
|
529 |
|
530 |
# 直接调用 moderations 函数
|
531 |
+
response = await moderations(moderation_request, api_index)
|
532 |
|
533 |
# 读取流式响应的内容
|
534 |
moderation_result = b""
|
|
|
646 |
return await call_next(request)
|
647 |
|
648 |
# 在 process_request 函数中更新成功和失败计数
|
649 |
+
async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None):
|
650 |
url = provider['base_url']
|
651 |
parsed_url = urlparse(url)
|
652 |
# print("parsed_url", parsed_url)
|
|
|
751 |
# response = JSONResponse(first_element)
|
752 |
|
753 |
# 更新成功计数和首次响应时间
|
754 |
+
await update_channel_stats(current_info["request_id"], provider['provider'], request.model, current_info["api_key"], success=True)
|
|
|
755 |
current_info["first_response_time"] = first_response_time
|
756 |
current_info["success"] = True
|
757 |
current_info["provider"] = provider['provider']
|
758 |
return response
|
759 |
|
760 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
|
761 |
+
await update_channel_stats(current_info["request_id"], provider['provider'], request.model, current_info["api_key"], success=False)
|
|
|
|
|
762 |
raise e
|
763 |
|
764 |
def weighted_round_robin(weights):
|
|
|
953 |
self.last_provider_indices = defaultdict(lambda: -1)
|
954 |
self.locks = defaultdict(asyncio.Lock)
|
955 |
|
956 |
+
async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], api_index: int = None, endpoint=None):
|
957 |
config = app.state.config
|
|
|
|
|
|
|
958 |
request_model = request.model
|
959 |
if not safe_get(config, 'api_keys', api_index, 'model'):
|
960 |
raise HTTPException(status_code=404, detail=f"No matching model found: {request_model}")
|
|
|
988 |
index += 1
|
989 |
provider = matching_providers[current_index]
|
990 |
try:
|
991 |
+
response = await process_request(request, provider, endpoint)
|
992 |
return response
|
993 |
except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
|
994 |
|
|
|
1058 |
try:
|
1059 |
api_index = api_list.index(token)
|
1060 |
except ValueError:
|
1061 |
+
# 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
|
1062 |
+
api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
|
1063 |
+
if api_index is None:
|
1064 |
+
print("error: Invalid or missing API Key:", token)
|
1065 |
+
api_index = None
|
1066 |
+
token = None
|
1067 |
|
1068 |
# 使用 IP 地址和 token(如果有)作为限制键
|
1069 |
client_ip = request.client.host
|
|
|
1076 |
def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
1077 |
api_list = app.state.api_list
|
1078 |
token = credentials.credentials
|
1079 |
+
api_index = None
|
1080 |
+
try:
|
1081 |
+
api_index = api_list.index(token)
|
1082 |
+
except ValueError:
|
1083 |
+
# 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
|
1084 |
+
api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
|
1085 |
+
if api_index is None:
|
1086 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
1087 |
+
return api_index
|
1088 |
|
1089 |
def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
|
1090 |
api_list = app.state.api_list
|
1091 |
token = credentials.credentials
|
1092 |
+
api_index = None
|
1093 |
+
try:
|
1094 |
+
api_index = api_list.index(token)
|
1095 |
+
except ValueError:
|
1096 |
+
# 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
|
1097 |
+
api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
|
1098 |
+
if api_index is None:
|
1099 |
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
1100 |
+
# for api_key in app.state.api_keys_db:
|
1101 |
+
# if token.startswith(api_key['api']):
|
1102 |
+
if app.state.api_keys_db[api_index].get('role') != "admin":
|
1103 |
+
raise HTTPException(status_code=403, detail="Permission denied")
|
1104 |
return token
|
1105 |
|
1106 |
@app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
1107 |
+
async def request_model(request: RequestModel, api_index: int = Depends(verify_api_key)):
|
1108 |
+
return await model_handler.request_model(request, api_index)
|
1109 |
|
1110 |
@app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
|
1111 |
async def options_handler():
|
1112 |
return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
|
1113 |
|
1114 |
@app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
|
1115 |
+
async def list_models(api_index: int = Depends(verify_api_key)):
|
1116 |
+
models = post_all_models(api_index, app.state.config)
|
1117 |
return JSONResponse(content={
|
1118 |
"object": "list",
|
1119 |
"data": models
|
|
|
1122 |
@app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
|
1123 |
async def images_generations(
|
1124 |
request: ImageGenerationRequest,
|
1125 |
+
api_index: int = Depends(verify_api_key)
|
1126 |
):
|
1127 |
+
return await model_handler.request_model(request, api_index, endpoint="/v1/images/generations")
|
1128 |
|
1129 |
@app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
|
1130 |
async def embeddings(
|
1131 |
request: EmbeddingRequest,
|
1132 |
+
api_index: int = Depends(verify_api_key)
|
1133 |
):
|
1134 |
+
return await model_handler.request_model(request, api_index, endpoint="/v1/embeddings")
|
1135 |
|
1136 |
@app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
|
1137 |
async def moderations(
|
1138 |
request: ModerationRequest,
|
1139 |
+
api_index: int = Depends(verify_api_key)
|
1140 |
):
|
1141 |
+
return await model_handler.request_model(request, api_index, endpoint="/v1/moderations")
|
1142 |
|
1143 |
from fastapi import UploadFile, File, Form, HTTPException
|
1144 |
import io
|
|
|
1146 |
async def audio_transcriptions(
|
1147 |
file: UploadFile = File(...),
|
1148 |
model: str = Form(...),
|
1149 |
+
api_index: int = Depends(verify_api_key)
|
1150 |
):
|
1151 |
try:
|
1152 |
# 读取上传的文件内容
|
|
|
1159 |
model=model
|
1160 |
)
|
1161 |
|
1162 |
+
return await model_handler.request_model(request, api_index, endpoint="/v1/audio/transcriptions")
|
1163 |
except UnicodeDecodeError:
|
1164 |
raise HTTPException(status_code=400, detail="Invalid audio file encoding")
|
1165 |
except Exception as e:
|
utils.py
CHANGED
@@ -63,7 +63,7 @@ class InMemoryRateLimiter:
|
|
63 |
|
64 |
rate_limiter = InMemoryRateLimiter()
|
65 |
|
66 |
-
async def get_user_rate_limit(app, api_index:
|
67 |
# 这里应该实现根据 token 获取用户速率限制的逻辑
|
68 |
# 示例: 返回 (次数, 秒数)
|
69 |
config = app.state.config
|
@@ -457,13 +457,10 @@ async def error_handling_wrapper(generator):
|
|
457 |
except StopAsyncIteration:
|
458 |
raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
|
459 |
|
460 |
-
def post_all_models(
|
461 |
all_models = []
|
462 |
unique_models = set()
|
463 |
|
464 |
-
if token not in api_list:
|
465 |
-
raise HTTPException(status_code=403, detail="Invalid or missing API Key")
|
466 |
-
api_index = api_list.index(token)
|
467 |
if config['api_keys'][api_index]['model']:
|
468 |
for model in config['api_keys'][api_index]['model']:
|
469 |
if model == "all":
|
|
|
63 |
|
64 |
rate_limiter = InMemoryRateLimiter()
|
65 |
|
66 |
+
async def get_user_rate_limit(app, api_index: int = None):
|
67 |
# 这里应该实现根据 token 获取用户速率限制的逻辑
|
68 |
# 示例: 返回 (次数, 秒数)
|
69 |
config = app.state.config
|
|
|
457 |
except StopAsyncIteration:
|
458 |
raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
|
459 |
|
460 |
+
def post_all_models(api_index, config):
|
461 |
all_models = []
|
462 |
unique_models = set()
|
463 |
|
|
|
|
|
|
|
464 |
if config['api_keys'][api_index]['model']:
|
465 |
for model in config['api_keys'][api_index]['model']:
|
466 |
if model == "all":
|