yym68686 committed

Commit cdf3ed9 · 1 Parent(s): 1778d52

✨ Feature: support setting a rate limit for each model individually

Files changed (5):
  1. README.md +6 -3
  2. README_CN.md +6 -3
  3. main.py +14 -13
  4. request.py +22 -23
  5. utils.py +47 -17
README.md CHANGED
@@ -90,9 +90,12 @@ providers:
       - gemini-1.5-flash-exp-0827 # Add this line, both gemini-1.5-flash-exp-0827 and gemini-1.5-flash can be requested
     tools: true
     preferences:
-      API_KEY_RATE_LIMIT: 15/min # Each API Key can make up to 15 requests per minute, optional. The default is 999999/min.
-      # API_KEY_RATE_LIMIT: 15/min,10/day # Supports multiple frequency constraints
-      API_KEY_COOLDOWN_PERIOD: 60 # Cooldown time in seconds for each API Key after it encounters a 429 error, optional. The default is 0 seconds; when set to 0, the cooldown mechanism is disabled.
+      api_key_rate_limit: 15/min # Each API Key can make up to 15 requests per minute, optional. The default is 999999/min. Supports multiple frequency constraints: 15/min,10/day
+      # api_key_rate_limit: # You can also set a different rate limit for each model
+      #   gpt-4o: 3/min
+      #   chatgpt-4o-latest: 2/min
+      #   default: 4/min # Models without their own entry fall back to the default limit
+      api_key_cooldown_period: 60 # Cooldown time in seconds for each API Key after it encounters a 429 error, optional. The default is 0 seconds; when set to 0, the cooldown mechanism is disabled. The cooldown only takes effect when there is more than one API key.
 
   - provider: vertex
     project_id: gen-lang-client-xxxxxxxxxxxxxx # Description: Your Google Cloud project ID. Format: String, usually composed of lowercase letters, numbers, and hyphens. How to obtain: You can find your project ID in the project selector of the Google Cloud Console.
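
Assembled, the new per-model form of `api_key_rate_limit` fits into a provider entry like this (a sketch based on the snippet above; the provider name, base_url, and API key are placeholders):

```yaml
providers:
  - provider: openai                 # placeholder provider name
    base_url: https://api.openai.com/v1/chat/completions
    api: sk-xxxxxxxxxxxxxxxxxxxx     # placeholder API key
    model:
      - gpt-4o
      - chatgpt-4o-latest
    preferences:
      api_key_rate_limit:            # mapping form: one limit per model
        gpt-4o: 3/min
        chatgpt-4o-latest: 2/min
        default: 4/min               # used by any model without its own entry
      api_key_cooldown_period: 60    # seconds; 0 disables the cooldown mechanism
```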
README_CN.md CHANGED
@@ -90,9 +90,12 @@ providers:
       - gemini-1.5-flash-exp-0827 # Add this line; both gemini-1.5-flash-exp-0827 and gemini-1.5-flash can be requested
     tools: true
     preferences:
-      API_KEY_RATE_LIMIT: 15/min # Maximum number of requests per minute for each API Key, optional. The default is 999999/min
-      # API_KEY_RATE_LIMIT: 15/min,10/day # Supports multiple frequency constraints
-      API_KEY_COOLDOWN_PERIOD: 60 # Cooldown time in seconds for each API Key after it encounters a 429 error, optional. The default is 0 seconds; when set to 0, the cooldown mechanism is disabled.
+      api_key_rate_limit: 15/min # Maximum number of requests per minute for each API Key, optional. The default is 999999/min. Supports multiple frequency constraints: 15/min,10/day
+      # api_key_rate_limit: # A different rate limit can be set for each model
+      #   gpt-4o: 3/min
+      #   chatgpt-4o-latest: 2/min
+      #   default: 4/min # Models without their own entry use the default limit
+      api_key_cooldown_period: 60 # Cooldown time in seconds for each API Key after it encounters a 429 error, optional. The default is 0 seconds; when set to 0, the cooldown mechanism is disabled. It only takes effect when there are multiple API keys.
 
   - provider: vertex
     project_id: gen-lang-client-xxxxxxxxxxxxxx # Description: your Google Cloud project ID. Format: a string, usually composed of lowercase letters, numbers, and hyphens. How to obtain: you can find your project ID in the project selector of the Google Cloud Console.
main.py CHANGED
@@ -655,20 +655,22 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
         engine = "gpt"
 
     model_dict = get_model_dict(provider)
-    if "claude" not in model_dict[request.model] \
-    and "gpt" not in model_dict[request.model] \
-    and "gemini" not in model_dict[request.model] \
+    original_model = model_dict[request.model]
+
+    if "claude" not in original_model \
+    and "gpt" not in original_model \
+    and "gemini" not in original_model \
     and parsed_url.netloc != 'api.cloudflare.com' \
     and parsed_url.netloc != 'api.cohere.com':
         engine = "openrouter"
 
-    if "claude" in model_dict[request.model] and engine == "vertex":
+    if "claude" in original_model and engine == "vertex":
         engine = "vertex-claude"
 
-    if "gemini" in model_dict[request.model] and engine == "vertex":
+    if "gemini" in original_model and engine == "vertex":
         engine = "vertex-gemini"
 
-    if "o1-preview" in model_dict[request.model] or "o1-mini" in model_dict[request.model]:
+    if "o1-preview" in original_model or "o1-mini" in original_model:
         engine = "o1"
         request.stream = False
 
@@ -702,17 +704,16 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
     logger.info(json.dumps(payload, indent=4, ensure_ascii=False))
 
     current_info = request_info.get()
-    model = model_dict[request.model]
 
     timeout_value = None
     # Try an exact match first
 
-    if model in app.state.timeouts:
-        timeout_value = app.state.timeouts[model]
+    if original_model in app.state.timeouts:
+        timeout_value = app.state.timeouts[original_model]
     else:
         # If there is no exact match, fall back to fuzzy matching
         for timeout_model in app.state.timeouts:
-            if timeout_model in model:
+            if timeout_model in original_model:
                 timeout_value = app.state.timeouts[timeout_model]
                 break
 
@@ -723,11 +724,11 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
     try:
         async with app.state.client_manager.get_client(timeout_value) as client:
             if request.stream:
-                generator = fetch_response_stream(client, url, headers, payload, engine, model)
+                generator = fetch_response_stream(client, url, headers, payload, engine, original_model)
                 wrapped_generator, first_response_time = await error_handling_wrapper(generator)
                 response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
             else:
-                generator = fetch_response(client, url, headers, payload, engine, model)
+                generator = fetch_response(client, url, headers, payload, engine, original_model)
                 wrapped_generator, first_response_time = await error_handling_wrapper(generator)
                 first_element = await anext(wrapped_generator)
                 first_element = first_element.lstrip("data: ")
@@ -1013,7 +1014,7 @@ class ModelRequestHandler:
     num_matching_providers = len(matching_providers)
     index = 0
 
-    cooling_time = safe_get(provider, "preferences", "API_KEY_COOLDOWN_PERIOD", default=0)
+    cooling_time = safe_get(provider, "preferences", "api_key_cooldown_period", default=0)
     api_key_count = provider_api_circular_list[channel_id].get_items_count()
     if cooling_time > 0 and api_key_count > 1:
         current_api = await provider_api_circular_list[channel_id].after_next_current()
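
The timeout lookup above resolves `original_model` by exact match first, then substring fallback, which is the same pattern the new per-model rate limits use in utils.py. A minimal, self-contained sketch of that resolution order (the table contents and the default value here are hypothetical illustration values):

```python
# Sketch of the exact-then-fuzzy lookup used for app.state.timeouts above.
timeouts = {"gpt-4o": 30, "claude": 60}

def resolve_timeout(model: str, default: int = 120) -> int:
    if model in timeouts:                 # 1. exact match wins
        return timeouts[model]
    for key, value in timeouts.items():   # 2. otherwise, first substring match
        if key in model:
            return value
    return default                        # 3. fall back to the default

assert resolve_timeout("gpt-4o") == 30              # exact match
assert resolve_timeout("claude-3-5-sonnet") == 60   # fuzzy: "claude" is in the name
assert resolve_timeout("gemini-1.5-flash") == 120   # no match: default
```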
request.py CHANGED
@@ -125,9 +125,9 @@ async def get_gemini_payload(request, engine, provider):
     gemini_stream = "streamGenerateContent"
     url = provider['base_url']
     if url.endswith("v1beta"):
-        url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
+        url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next(model))
     if url.endswith("v1"):
-        url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
+        url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next(model))
 
     messages = []
     systemInstruction = None
@@ -596,8 +596,10 @@ async def get_gpt_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json',
     }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
     url = provider['base_url']
 
     messages = []
@@ -637,8 +639,6 @@ async def get_gpt_payload(request, engine, provider):
     else:
         messages.append({"role": msg.role, "content": content})
 
-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
     payload = {
         "model": model,
         "messages": messages,
@@ -663,8 +663,10 @@ async def get_openrouter_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json'
     }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
 
     url = provider['base_url']
 
@@ -696,8 +698,6 @@ async def get_openrouter_payload(request, engine, provider):
     else:
         messages.append({"role": msg.role, "content": content})
 
-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
     payload = {
         "model": model,
         "messages": messages,
@@ -730,8 +730,10 @@ async def get_cohere_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json'
     }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
 
     url = provider['base_url']
 
@@ -759,8 +761,6 @@ async def get_cohere_payload(request, engine, provider):
     else:
         messages.append({"role": role_map[msg.role], "message": content})
 
-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
     chat_history = messages[:-1]
     query = messages[-1].get("message")
     payload = {
@@ -798,11 +798,11 @@ async def get_cloudflare_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json'
     }
-    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
-
     model_dict = get_model_dict(provider)
     model = model_dict[request.model]
+    if provider.get("api"):
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
+
     url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=model)
 
     msg = request.messages[-1]
@@ -816,7 +816,6 @@ async def get_cloudflare_payload(request, engine, provider):
     content = msg.content
     name = msg.name
 
-    model = model_dict[request.model]
     payload = {
         "prompt": content,
     }
@@ -848,8 +847,10 @@ async def get_o1_payload(request, engine, provider):
     headers = {
         'Content-Type': 'application/json'
     }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
 
     url = provider['base_url']
 
@@ -871,8 +872,6 @@ async def get_o1_payload(request, engine, provider):
     elif msg.role != "system":
         messages.append({"role": msg.role, "content": content})
 
-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
     payload = {
         "model": model,
         "messages": messages,
@@ -925,7 +924,7 @@ async def get_claude_payload(request, engine, provider):
     model = model_dict[request.model]
     headers = {
         "content-type": "application/json",
-        "x-api-key": f"{await provider_api_circular_list[provider['provider']].next()}",
+        "x-api-key": f"{await provider_api_circular_list[provider['provider']].next(model)}",
         "anthropic-version": "2023-06-01",
         "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16",
     }
@@ -1068,7 +1067,7 @@ async def get_dalle_payload(request, engine, provider):
         "Content-Type": "application/json",
     }
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
     url = provider['base_url']
     url = BaseAPI(url).image_url
 
@@ -1088,7 +1087,7 @@ async def get_whisper_payload(request, engine, provider):
         # "Content-Type": "multipart/form-data",
     }
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
     url = provider['base_url']
     url = BaseAPI(url).audio_transcriptions
 
@@ -1115,7 +1114,7 @@ async def get_moderation_payload(request, engine, provider):
         "Content-Type": "application/json",
     }
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
     url = provider['base_url']
     url = BaseAPI(url).moderations
 
@@ -1132,7 +1131,7 @@ async def get_embedding_payload(request, engine, provider):
         "Content-Type": "application/json",
     }
     if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
     url = provider['base_url']
     url = BaseAPI(url).embeddings
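The pattern repeated across these payload builders: resolve the upstream model name first, then pass it to the key selector so per-model limits can be honored when choosing a key. A condensed sketch of that pattern (not a function that exists under this name in request.py; get_model_dict and provider_api_circular_list are the project's own helpers):

```python
# Condensed sketch of the new call pattern shared by the payload builders.
async def build_headers(request, provider) -> dict:
    model_dict = get_model_dict(provider)
    model = model_dict[request.model]   # map the requested alias to the upstream model name
    headers = {"Content-Type": "application/json"}
    if provider.get("api"):
        # next(model) skips keys whose rate-limit window for *this* model is exhausted
        key = await provider_api_circular_list[provider["provider"]].next(model)
        headers["Authorization"] = f"Bearer {key}"
    return headers
```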
utils.py CHANGED
@@ -80,13 +80,21 @@ async def get_user_rate_limit(app, api_index: str = None):
 import asyncio
 
 class ThreadSafeCircularList:
-    def __init__(self, items, rate_limit="99999/min"):
+    def __init__(self, items, rate_limit={"default": "999999/min"}):
         self.items = items
         self.index = 0
         self.lock = asyncio.Lock()
-        self.requests = defaultdict(list)  # Tracks the request timestamps of each API key
+        # Two-level dict: the first level is the item (API key), the second level is the model
+        self.requests = defaultdict(lambda: defaultdict(list))
         self.cooling_until = defaultdict(float)
-        self.rate_limits = parse_rate_limit(rate_limit)  # Now returns a list of limit constraints
+        self.rate_limits = {}
+        if isinstance(rate_limit, dict):
+            for rate_limit_model, rate_limit_value in rate_limit.items():
+                self.rate_limits[rate_limit_model] = parse_rate_limit(rate_limit_value)
+        elif isinstance(rate_limit, str):
+            self.rate_limits["default"] = parse_rate_limit(rate_limit)
+        else:
+            logger.error(f"Error ThreadSafeCircularList: Unknown rate_limit type: {type(rate_limit)}, rate_limit: {rate_limit}")
 
     async def set_cooling(self, item: str, cooling_time: int = 60):
         """Put an item into the cooling state
@@ -102,36 +110,58 @@ class ThreadSafeCircularList:
         # self.requests[item] = []
         logger.warning(f"API key {item} has entered the cooling state for {cooling_time} seconds")
 
-    async def is_rate_limited(self, item) -> bool:
+    async def is_rate_limited(self, item, model: str = None) -> bool:
         now = time()
         # Check whether the key is still cooling down
        if now < self.cooling_until[item]:
             return True
 
+        # Determine the applicable rate limit
+        if model:
+            model_key = model
+        else:
+            model_key = "default"
+
+        rate_limit = None
+        # Try an exact match first
+        if model and model in self.rate_limits:
+            rate_limit = self.rate_limits[model]
+        else:
+            # If there is no exact match, fall back to fuzzy matching
+            for limit_model in self.rate_limits:
+                if limit_model != "default" and model and limit_model in model:
+                    rate_limit = self.rate_limits[limit_model]
+                    break
+
+        # If nothing matched, use the default
+        if rate_limit is None:
+            rate_limit = self.rate_limits.get("default", [(999999, 60)])  # default limit
+
         # Check every rate-limit constraint
-        for limit_count, limit_period in self.rate_limits:
-            # Count the requests inside the current window instead of mutating the request list
-            recent_requests = sum(1 for req in self.requests[item] if req > now - limit_period)
+        for limit_count, limit_period in rate_limit:
+            # Count using this model's own request records
+            recent_requests = sum(1 for req in self.requests[item][model_key] if req > now - limit_period)
             if recent_requests >= limit_count:
-                logger.warning(f"API key {item} has reached the rate limit ({limit_count} requests per {limit_period} seconds)")
+                logger.warning(f"API key {item} has reached the rate limit for model {model_key} ({limit_count} requests per {limit_period} seconds)")
                 return True
 
-        # Drop request records older than the longest time window
-        max_period = max(period for _, period in self.rate_limits)
-        self.requests[item] = [req for req in self.requests[item] if req > now - max_period]
+        # Drop request records older than the longest time window
+        max_period = max(period for _, period in rate_limit)
+        self.requests[item][model_key] = [req for req in self.requests[item][model_key] if req > now - max_period]
 
-        # All constraints passed; record the new request
-        self.requests[item].append(now)
+        # Record the new request
+        self.requests[item][model_key].append(now)
         return False
 
-    async def next(self):
+    async def next(self, model: str = None):
         async with self.lock:
             start_index = self.index
             while True:
                 item = self.items[self.index]
                 self.index = (self.index + 1) % len(self.items)
 
-                if not await self.is_rate_limited(item):
+                if not await self.is_rate_limited(item, model):
                     return item
 
                 # If every API key has been checked and all are rate-limited
@@ -220,12 +250,12 @@ def update_config(config_data, use_config_url=False):
     if isinstance(provider_api, str):
         provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(
             [provider_api],
-            safe_get(provider, "preferences", "API_KEY_RATE_LIMIT", default="999999/min")
+            safe_get(provider, "preferences", "api_key_rate_limit", default={"default": "999999/min"})
         )
     if isinstance(provider_api, list):
         provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(
             provider_api,
-            safe_get(provider, "preferences", "API_KEY_RATE_LIMIT", default="999999/min")
+            safe_get(provider, "preferences", "api_key_rate_limit", default={"default": "999999/min"})
         )
 
     if not provider.get("model"):
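
Taken together, these changes make rate-limit bookkeeping per (key, model) instead of per key, so a key exhausted for one model stays usable for others. A small usage sketch of the selection behavior, assuming it runs next to utils.py; parse_rate_limit's output is assumed to be a list of (count, period-in-seconds) tuples, which matches how is_rate_limited consumes it:

```python
import asyncio
from utils import ThreadSafeCircularList  # assumes this script sits next to utils.py

async def demo():
    keys = ThreadSafeCircularList(
        ["key-A", "key-B"],
        rate_limit={"gpt-4o": "1/min", "default": "999999/min"},
    )
    print(await keys.next("gpt-4o"))   # "key-A": round-robin starts at the first key
    print(await keys.next("gpt-4o"))   # "key-B": rotation advances to the next key
    # key-A's gpt-4o window (1/min) is now full, but its buckets are per model,
    # so the same key is still returned for a different model:
    print(await keys.next("gemini-1.5-flash"))  # "key-A"

asyncio.run(demo())
```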