yym68686 committed on
Commit
ecf5650
·
1 Parent(s): 1c49e1a

✨ Feature: Treat an API key as valid when it starts with any API key listed in the configuration file (prefix match).

Browse files
Files changed (2) hide show
  1. main.py +52 -37
  2. utils.py +2 -5
main.py CHANGED
@@ -418,14 +418,20 @@ class StatsMiddleware(BaseHTTPMiddleware):
418
  )
419
  else:
420
  token = None
 
 
421
  if token:
422
  try:
423
  api_list = app.state.api_list
424
  api_index = api_list.index(token)
425
- enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
426
  except ValueError:
 
 
427
  # token不在api_list中,使用默认值(不开启)
428
  pass
 
 
 
429
  else:
430
  # 如果token为None,检查全局设置
431
  enable_moderation = config.get('ENABLE_MODERATION', False)
@@ -473,7 +479,7 @@ class StatsMiddleware(BaseHTTPMiddleware):
473
 
474
 
475
  if enable_moderation and moderated_content:
476
- moderation_response = await self.moderate_content(moderated_content, token)
477
  is_flagged = moderation_response.get('results', [{}])[0].get('flagged', False)
478
 
479
  if is_flagged:
@@ -518,11 +524,11 @@ class StatsMiddleware(BaseHTTPMiddleware):
518
  # print("current_request_info", current_request_info)
519
  request_info.reset(current_request_info)
520
 
521
- async def moderate_content(self, content, token):
522
  moderation_request = ModerationRequest(input=content)
523
 
524
  # 直接调用 moderations 函数
525
- response = await moderations(moderation_request, token)
526
 
527
  # 读取流式响应的内容
528
  moderation_result = b""
@@ -640,7 +646,7 @@ async def ensure_config(request: Request, call_next):
640
  return await call_next(request)
641
 
642
  # 在 process_request 函数中更新成功和失败计数
643
- async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None, token=None):
644
  url = provider['base_url']
645
  parsed_url = urlparse(url)
646
  # print("parsed_url", parsed_url)
@@ -745,17 +751,14 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
745
  # response = JSONResponse(first_element)
746
 
747
  # 更新成功计数和首次响应时间
748
- await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
749
- # await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=True)
750
  current_info["first_response_time"] = first_response_time
751
  current_info["success"] = True
752
  current_info["provider"] = provider['provider']
753
  return response
754
 
755
  except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
756
- await update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
757
- # await app.middleware_stack.app.update_channel_stats(current_info["request_id"], provider['provider'], request.model, token, success=False)
758
-
759
  raise e
760
 
761
  def weighted_round_robin(weights):
@@ -950,11 +953,8 @@ class ModelRequestHandler:
950
  self.last_provider_indices = defaultdict(lambda: -1)
951
  self.locks = defaultdict(asyncio.Lock)
952
 
953
- async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], token: str, endpoint=None):
954
  config = app.state.config
955
- api_list = app.state.api_list
956
- api_index = api_list.index(token)
957
-
958
  request_model = request.model
959
  if not safe_get(config, 'api_keys', api_index, 'model'):
960
  raise HTTPException(status_code=404, detail=f"No matching model found: {request_model}")
@@ -988,7 +988,7 @@ class ModelRequestHandler:
988
  index += 1
989
  provider = matching_providers[current_index]
990
  try:
991
- response = await process_request(request, provider, endpoint, token)
992
  return response
993
  except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
994
 
@@ -1058,9 +1058,12 @@ async def rate_limit_dependency(request: Request, credentials: HTTPAuthorization
1058
  try:
1059
  api_index = api_list.index(token)
1060
  except ValueError:
1061
- print("error: Invalid or missing API Key:", token)
1062
- api_index = None
1063
- token = None
 
 
 
1064
 
1065
  # 使用 IP 地址和 token(如果有)作为限制键
1066
  client_ip = request.client.host
@@ -1073,32 +1076,44 @@ async def rate_limit_dependency(request: Request, credentials: HTTPAuthorization
1073
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
1074
  api_list = app.state.api_list
1075
  token = credentials.credentials
1076
- if token not in api_list:
 
 
 
 
 
 
1077
  raise HTTPException(status_code=403, detail="Invalid or missing API Key")
1078
- return token
1079
 
1080
  def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
1081
  api_list = app.state.api_list
1082
  token = credentials.credentials
1083
- if token not in api_list:
 
 
 
 
 
 
1084
  raise HTTPException(status_code=403, detail="Invalid or missing API Key")
1085
- for api_key in app.state.api_keys_db:
1086
- if api_key['api'] == token:
1087
- if api_key.get('role') != "admin":
1088
- raise HTTPException(status_code=403, detail="Permission denied")
1089
  return token
1090
 
1091
  @app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
1092
- async def request_model(request: RequestModel, token: str = Depends(verify_api_key)):
1093
- return await model_handler.request_model(request, token)
1094
 
1095
  @app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
1096
  async def options_handler():
1097
  return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
1098
 
1099
  @app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
1100
- async def list_models(token: str = Depends(verify_api_key)):
1101
- models = post_all_models(token, app.state.config, app.state.api_list)
1102
  return JSONResponse(content={
1103
  "object": "list",
1104
  "data": models
@@ -1107,23 +1122,23 @@ async def list_models(token: str = Depends(verify_api_key)):
1107
  @app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
1108
  async def images_generations(
1109
  request: ImageGenerationRequest,
1110
- token: str = Depends(verify_api_key)
1111
  ):
1112
- return await model_handler.request_model(request, token, endpoint="/v1/images/generations")
1113
 
1114
  @app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
1115
  async def embeddings(
1116
  request: EmbeddingRequest,
1117
- token: str = Depends(verify_api_key)
1118
  ):
1119
- return await model_handler.request_model(request, token, endpoint="/v1/embeddings")
1120
 
1121
  @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
1122
  async def moderations(
1123
  request: ModerationRequest,
1124
- token: str = Depends(verify_api_key)
1125
  ):
1126
- return await model_handler.request_model(request, token, endpoint="/v1/moderations")
1127
 
1128
  from fastapi import UploadFile, File, Form, HTTPException
1129
  import io
@@ -1131,7 +1146,7 @@ import io
1131
  async def audio_transcriptions(
1132
  file: UploadFile = File(...),
1133
  model: str = Form(...),
1134
- token: str = Depends(verify_api_key)
1135
  ):
1136
  try:
1137
  # 读取上传的文件内容
@@ -1144,7 +1159,7 @@ async def audio_transcriptions(
1144
  model=model
1145
  )
1146
 
1147
- return await model_handler.request_model(request, token, endpoint="/v1/audio/transcriptions")
1148
  except UnicodeDecodeError:
1149
  raise HTTPException(status_code=400, detail="Invalid audio file encoding")
1150
  except Exception as e:
 
418
  )
419
  else:
420
  token = None
421
+
422
+ api_index = None
423
  if token:
424
  try:
425
  api_list = app.state.api_list
426
  api_index = api_list.index(token)
 
427
  except ValueError:
428
+ # 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
429
+ api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
430
  # token不在api_list中,使用默认值(不开启)
431
  pass
432
+
433
+ if api_index is not None:
434
+ enable_moderation = safe_get(config, 'api_keys', api_index, "preferences", "ENABLE_MODERATION", default=False)
435
  else:
436
  # 如果token为None,检查全局设置
437
  enable_moderation = config.get('ENABLE_MODERATION', False)
 
479
 
480
 
481
  if enable_moderation and moderated_content:
482
+ moderation_response = await self.moderate_content(moderated_content, api_index)
483
  is_flagged = moderation_response.get('results', [{}])[0].get('flagged', False)
484
 
485
  if is_flagged:
 
524
  # print("current_request_info", current_request_info)
525
  request_info.reset(current_request_info)
526
 
527
+ async def moderate_content(self, content, api_index):
528
  moderation_request = ModerationRequest(input=content)
529
 
530
  # 直接调用 moderations 函数
531
+ response = await moderations(moderation_request, api_index)
532
 
533
  # 读取流式响应的内容
534
  moderation_result = b""
 
646
  return await call_next(request)
647
 
648
  # 在 process_request 函数中更新成功和失败计数
649
+ async def process_request(request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], provider: Dict, endpoint=None):
650
  url = provider['base_url']
651
  parsed_url = urlparse(url)
652
  # print("parsed_url", parsed_url)
 
751
  # response = JSONResponse(first_element)
752
 
753
  # 更新成功计数和首次响应时间
754
+ await update_channel_stats(current_info["request_id"], provider['provider'], request.model, current_info["api_key"], success=True)
 
755
  current_info["first_response_time"] = first_response_time
756
  current_info["success"] = True
757
  current_info["provider"] = provider['provider']
758
  return response
759
 
760
  except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
761
+ await update_channel_stats(current_info["request_id"], provider['provider'], request.model, current_info["api_key"], success=False)
 
 
762
  raise e
763
 
764
  def weighted_round_robin(weights):
 
953
  self.last_provider_indices = defaultdict(lambda: -1)
954
  self.locks = defaultdict(asyncio.Lock)
955
 
956
+ async def request_model(self, request: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest], api_index: int = None, endpoint=None):
957
  config = app.state.config
 
 
 
958
  request_model = request.model
959
  if not safe_get(config, 'api_keys', api_index, 'model'):
960
  raise HTTPException(status_code=404, detail=f"No matching model found: {request_model}")
 
988
  index += 1
989
  provider = matching_providers[current_index]
990
  try:
991
+ response = await process_request(request, provider, endpoint)
992
  return response
993
  except (Exception, HTTPException, asyncio.CancelledError, httpx.ReadError, httpx.RemoteProtocolError, httpx.ReadTimeout) as e:
994
 
 
1058
  try:
1059
  api_index = api_list.index(token)
1060
  except ValueError:
1061
+ # 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
1062
+ api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
1063
+ if api_index is None:
1064
+ print("error: Invalid or missing API Key:", token)
1065
+ api_index = None
1066
+ token = None
1067
 
1068
  # 使用 IP 地址和 token(如果有)作为限制键
1069
  client_ip = request.client.host
 
1076
  def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
1077
  api_list = app.state.api_list
1078
  token = credentials.credentials
1079
+ api_index = None
1080
+ try:
1081
+ api_index = api_list.index(token)
1082
+ except ValueError:
1083
+ # 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
1084
+ api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
1085
+ if api_index is None:
1086
  raise HTTPException(status_code=403, detail="Invalid or missing API Key")
1087
+ return api_index
1088
 
1089
  def verify_admin_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
1090
  api_list = app.state.api_list
1091
  token = credentials.credentials
1092
+ api_index = None
1093
+ try:
1094
+ api_index = api_list.index(token)
1095
+ except ValueError:
1096
+ # 如果 token 不在 api_list 中,检查是否以 api_list 中的任何一个开头
1097
+ api_index = next((i for i, api in enumerate(api_list) if token.startswith(api)), None)
1098
+ if api_index is None:
1099
  raise HTTPException(status_code=403, detail="Invalid or missing API Key")
1100
+ # for api_key in app.state.api_keys_db:
1101
+ # if token.startswith(api_key['api']):
1102
+ if app.state.api_keys_db[api_index].get('role') != "admin":
1103
+ raise HTTPException(status_code=403, detail="Permission denied")
1104
  return token
1105
 
1106
  @app.post("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
1107
+ async def request_model(request: RequestModel, api_index: int = Depends(verify_api_key)):
1108
+ return await model_handler.request_model(request, api_index)
1109
 
1110
  @app.options("/v1/chat/completions", dependencies=[Depends(rate_limit_dependency)])
1111
  async def options_handler():
1112
  return JSONResponse(status_code=200, content={"detail": "OPTIONS allowed"})
1113
 
1114
  @app.get("/v1/models", dependencies=[Depends(rate_limit_dependency)])
1115
+ async def list_models(api_index: int = Depends(verify_api_key)):
1116
+ models = post_all_models(api_index, app.state.config)
1117
  return JSONResponse(content={
1118
  "object": "list",
1119
  "data": models
 
1122
  @app.post("/v1/images/generations", dependencies=[Depends(rate_limit_dependency)])
1123
  async def images_generations(
1124
  request: ImageGenerationRequest,
1125
+ api_index: int = Depends(verify_api_key)
1126
  ):
1127
+ return await model_handler.request_model(request, api_index, endpoint="/v1/images/generations")
1128
 
1129
  @app.post("/v1/embeddings", dependencies=[Depends(rate_limit_dependency)])
1130
  async def embeddings(
1131
  request: EmbeddingRequest,
1132
+ api_index: int = Depends(verify_api_key)
1133
  ):
1134
+ return await model_handler.request_model(request, api_index, endpoint="/v1/embeddings")
1135
 
1136
  @app.post("/v1/moderations", dependencies=[Depends(rate_limit_dependency)])
1137
  async def moderations(
1138
  request: ModerationRequest,
1139
+ api_index: int = Depends(verify_api_key)
1140
  ):
1141
+ return await model_handler.request_model(request, api_index, endpoint="/v1/moderations")
1142
 
1143
  from fastapi import UploadFile, File, Form, HTTPException
1144
  import io
 
1146
  async def audio_transcriptions(
1147
  file: UploadFile = File(...),
1148
  model: str = Form(...),
1149
+ api_index: int = Depends(verify_api_key)
1150
  ):
1151
  try:
1152
  # 读取上传的文件内容
 
1159
  model=model
1160
  )
1161
 
1162
+ return await model_handler.request_model(request, api_index, endpoint="/v1/audio/transcriptions")
1163
  except UnicodeDecodeError:
1164
  raise HTTPException(status_code=400, detail="Invalid audio file encoding")
1165
  except Exception as e:
utils.py CHANGED
@@ -63,7 +63,7 @@ class InMemoryRateLimiter:
63
 
64
  rate_limiter = InMemoryRateLimiter()
65
 
66
- async def get_user_rate_limit(app, api_index: str = None):
67
  # 这里应该实现根据 token 获取用户速率限制的逻辑
68
  # 示例: 返回 (次数, 秒数)
69
  config = app.state.config
@@ -457,13 +457,10 @@ async def error_handling_wrapper(generator):
457
  except StopAsyncIteration:
458
  raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
459
 
460
- def post_all_models(token, config, api_list):
461
  all_models = []
462
  unique_models = set()
463
 
464
- if token not in api_list:
465
- raise HTTPException(status_code=403, detail="Invalid or missing API Key")
466
- api_index = api_list.index(token)
467
  if config['api_keys'][api_index]['model']:
468
  for model in config['api_keys'][api_index]['model']:
469
  if model == "all":
 
63
 
64
  rate_limiter = InMemoryRateLimiter()
65
 
66
+ async def get_user_rate_limit(app, api_index: int = None):
67
  # 这里应该实现根据 token 获取用户速率限制的逻辑
68
  # 示例: 返回 (次数, 秒数)
69
  config = app.state.config
 
457
  except StopAsyncIteration:
458
  raise HTTPException(status_code=400, detail="data: {'error': 'No data returned'}")
459
 
460
+ def post_all_models(api_index, config):
461
  all_models = []
462
  unique_models = set()
463
 
 
 
 
464
  if config['api_keys'][api_index]['model']:
465
  for model in config['api_keys'][api_index]['model']:
466
  if model == "all":