✨ Feature: support setting a rate limit for each model individually

Files changed:
- README.md +6 -3
- README_CN.md +6 -3
- main.py +14 -13
- request.py +22 -23
- utils.py +47 -17
README.md
@@ -90,9 +90,12 @@ providers:
       - gemini-1.5-flash-exp-0827 # Add this line, both gemini-1.5-flash-exp-0827 and gemini-1.5-flash can be requested
     tools: true
     preferences:
-
-      #
-
+      api_key_rate_limit: 15/min # Each API Key can request up to 15 times per minute, optional. The default is 999999/min. Supports multiple frequency constraints: 15/min,10/day
+      # api_key_rate_limit: # You can set a different frequency limit for each model
+      #   gpt-4o: 3/min
+      #   chatgpt-4o-latest: 2/min
+      #   default: 4/min # If a model does not set a frequency limit, the frequency limit of default is used
+      api_key_cooldown_period: 60 # Each API Key will be cooled down for 60 seconds after encountering a 429 error. Optional, the default is 0 seconds. When set to 0, the cooldown mechanism is not enabled. The cooldown mechanism only takes effect when there are multiple API keys.

   - provider: vertex
     project_id: gen-lang-client-xxxxxxxxxxxxxx # Description: Your Google Cloud project ID. Format: String, usually composed of lowercase letters, numbers, and hyphens. How to obtain: You can find your project ID in the project selector of the Google Cloud Console.
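Note: the `count/period` values above may carry several comma-separated constraints (e.g. `15/min,10/day`), and the new per-model form simply maps model names to such strings. The repository's `parse_rate_limit` helper is not part of this diff, so the following is only a minimal sketch, with hypothetical names, of how such a value could be turned into the `(count, period_in_seconds)` pairs that `is_rate_limited` in utils.py iterates over:

```python
# Hypothetical sketch -- the real parse_rate_limit in utils.py is not shown in this diff.
PERIOD_SECONDS = {"sec": 1, "min": 60, "hour": 3600, "day": 86400}

def parse_rate_limit_sketch(value: str) -> list[tuple[int, int]]:
    """Turn '15/min,10/day' into [(15, 60), (10, 86400)]."""
    limits = []
    for constraint in value.split(","):
        count, _, period = constraint.strip().partition("/")
        limits.append((int(count), PERIOD_SECONDS[period.strip()]))
    return limits

assert parse_rate_limit_sketch("15/min,10/day") == [(15, 60), (10, 86400)]
```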
README_CN.md
@@ -90,9 +90,12 @@ providers:
       - gemini-1.5-flash-exp-0827 # 加上这一行,gemini-1.5-flash-exp-0827 和 gemini-1.5-flash 都可以被请求
     tools: true
     preferences:
-
-      #
-
+      api_key_rate_limit: 15/min # 每个 API Key 每分钟最多请求次数,选填。默认为 999999/min。支持多个频率约束条件:15/min,10/day
+      # api_key_rate_limit: # 可以为每个模型设置不同的频率限制
+      #   gpt-4o: 3/min
+      #   chatgpt-4o-latest: 2/min
+      #   default: 4/min # 如果模型没有设置频率限制,使用 default 的频率限制
+      api_key_cooldown_period: 60 # 每个 API Key 遭遇 429 错误后的冷却时间,单位为秒,选填。默认为 0 秒, 当设置为 0 秒时,不启用冷却机制。当存在多个 API key 时才会生效。

   - provider: vertex
     project_id: gen-lang-client-xxxxxxxxxxxxxx # 描述: 您的Google Cloud项目ID。格式: 字符串,通常由小写字母、数字和连字符组成。获取方式: 在Google Cloud Console的项目选择器中可以找到您的项目ID。
main.py
@@ -655,20 +655,22 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
    engine = "gpt"

    model_dict = get_model_dict(provider)
-    if "claude" not in model_dict[request.model] \
-    and "gpt" not in model_dict[request.model] \
-    and "gemini" not in model_dict[request.model] \
+    original_model = model_dict[request.model]
+
+    if "claude" not in original_model \
+    and "gpt" not in original_model \
+    and "gemini" not in original_model \
    and parsed_url.netloc != 'api.cloudflare.com' \
    and parsed_url.netloc != 'api.cohere.com':
        engine = "openrouter"

-    if "claude" in model_dict[request.model] and engine == "vertex":
+    if "claude" in original_model and engine == "vertex":
        engine = "vertex-claude"

-    if "gemini" in model_dict[request.model] and engine == "vertex":
+    if "gemini" in original_model and engine == "vertex":
        engine = "vertex-gemini"

-    if "o1-preview" in model_dict[request.model] or "o1-mini" in model_dict[request.model]:
+    if "o1-preview" in original_model or "o1-mini" in original_model:
        engine = "o1"
        request.stream = False

@@ -702,17 +704,16 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
    logger.info(json.dumps(payload, indent=4, ensure_ascii=False))

    current_info = request_info.get()
-    model = model_dict[request.model]

    timeout_value = None
    # Try an exact match first

-    if model in app.state.timeouts:
-        timeout_value = app.state.timeouts[model]
+    if original_model in app.state.timeouts:
+        timeout_value = app.state.timeouts[original_model]
    else:
        # If there is no exact match, try a fuzzy match
        for timeout_model in app.state.timeouts:
-            if timeout_model in model:
+            if timeout_model in original_model:
                timeout_value = app.state.timeouts[timeout_model]
                break

@@ -723,11 +724,11 @@ async def process_request(request: Union[RequestModel, ImageGenerationRequest, A
    try:
        async with app.state.client_manager.get_client(timeout_value) as client:
            if request.stream:
-                generator = fetch_response_stream(client, url, headers, payload, engine, model)
+                generator = fetch_response_stream(client, url, headers, payload, engine, original_model)
                wrapped_generator, first_response_time = await error_handling_wrapper(generator)
                response = StarletteStreamingResponse(wrapped_generator, media_type="text/event-stream")
            else:
-                generator = fetch_response(client, url, headers, payload, engine, model)
+                generator = fetch_response(client, url, headers, payload, engine, original_model)
                wrapped_generator, first_response_time = await error_handling_wrapper(generator)
                first_element = await anext(wrapped_generator)
                first_element = first_element.lstrip("data: ")

@@ -1013,7 +1014,7 @@ class ModelRequestHandler:
    num_matching_providers = len(matching_providers)
    index = 0

-    cooling_time = safe_get(provider, "preferences", "
+    cooling_time = safe_get(provider, "preferences", "api_key_cooldown_period", default=0)
    api_key_count = provider_api_circular_list[channel_id].get_items_count()
    if cooling_time > 0 and api_key_count > 1:
        current_api = await provider_api_circular_list[channel_id].after_next_current()
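The timeout lookup above (exact match on `original_model`, then substring match) is the same pattern this commit reuses for per-model rate limits in `is_rate_limited` in utils.py below. A standalone sketch of that lookup, with illustrative names not taken from the diff:

```python
# Illustrative helper, not part of the commit: exact key match first, then
# substring ("fuzzy") match, mirroring the timeout and rate-limit lookups.
def lookup_by_model(table: dict, model: str, default=None):
    if model in table:                # exact match first
        return table[model]
    for key, value in table.items():  # otherwise fuzzy match
        if key in model:              # e.g. "gpt-4o" matches "gpt-4o-2024-08-06"
            return value
    return default

timeouts = {"gpt-4o": 30, "claude": 60}
assert lookup_by_model(timeouts, "gpt-4o-2024-08-06") == 30
assert lookup_by_model(timeouts, "mistral-large", default=20) == 20
```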
request.py
@@ -125,9 +125,9 @@ async def get_gemini_payload(request, engine, provider):
    gemini_stream = "streamGenerateContent"
    url = provider['base_url']
    if url.endswith("v1beta"):
-        url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
+        url = "https://generativelanguage.googleapis.com/v1beta/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next(model))
    if url.endswith("v1"):
-        url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next())
+        url = "https://generativelanguage.googleapis.com/v1/models/{model}:{stream}?key={api_key}".format(model=model, stream=gemini_stream, api_key=await provider_api_circular_list[provider['provider']].next(model))

    messages = []
    systemInstruction = None

@@ -596,8 +596,10 @@ async def get_gpt_payload(request, engine, provider):
    headers = {
        'Content-Type': 'application/json',
    }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
    url = provider['base_url']

    messages = []

@@ -637,8 +639,6 @@ async def get_gpt_payload(request, engine, provider):
    else:
        messages.append({"role": msg.role, "content": content})

-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
    payload = {
        "model": model,
        "messages": messages,

@@ -663,8 +663,10 @@ async def get_openrouter_payload(request, engine, provider):
    headers = {
        'Content-Type': 'application/json'
    }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

    url = provider['base_url']

@@ -696,8 +698,6 @@ async def get_openrouter_payload(request, engine, provider):
    else:
        messages.append({"role": msg.role, "content": content})

-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
    payload = {
        "model": model,
        "messages": messages,

@@ -730,8 +730,10 @@ async def get_cohere_payload(request, engine, provider):
    headers = {
        'Content-Type': 'application/json'
    }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

    url = provider['base_url']

@@ -759,8 +761,6 @@ async def get_cohere_payload(request, engine, provider):
    else:
        messages.append({"role": role_map[msg.role], "message": content})

-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
    chat_history = messages[:-1]
    query = messages[-1].get("message")
    payload = {

@@ -798,11 +798,11 @@ async def get_cloudflare_payload(request, engine, provider):
    headers = {
        'Content-Type': 'application/json'
    }
-    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
-
    model_dict = get_model_dict(provider)
    model = model_dict[request.model]
+    if provider.get("api"):
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
+
    url = "https://api.cloudflare.com/client/v4/accounts/{cf_account_id}/ai/run/{cf_model_id}".format(cf_account_id=provider['cf_account_id'], cf_model_id=model)

    msg = request.messages[-1]

@@ -816,7 +816,6 @@ async def get_cloudflare_payload(request, engine, provider):
    content = msg.content
    name = msg.name

-    model = model_dict[request.model]
    payload = {
        "prompt": content,
    }

@@ -848,8 +847,10 @@ async def get_o1_payload(request, engine, provider):
    headers = {
        'Content-Type': 'application/json'
    }
+    model_dict = get_model_dict(provider)
+    model = model_dict[request.model]
    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"

    url = provider['base_url']

@@ -871,8 +872,6 @@ async def get_o1_payload(request, engine, provider):
    elif msg.role != "system":
        messages.append({"role": msg.role, "content": content})

-    model_dict = get_model_dict(provider)
-    model = model_dict[request.model]
    payload = {
        "model": model,
        "messages": messages,

@@ -925,7 +924,7 @@ async def get_claude_payload(request, engine, provider):
    model = model_dict[request.model]
    headers = {
        "content-type": "application/json",
-        "x-api-key": f"{await provider_api_circular_list[provider['provider']].next()}",
+        "x-api-key": f"{await provider_api_circular_list[provider['provider']].next(model)}",
        "anthropic-version": "2023-06-01",
        "anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15" if "claude-3-5-sonnet" in model else "tools-2024-05-16",
    }

@@ -1068,7 +1067,7 @@ async def get_dalle_payload(request, engine, provider):
        "Content-Type": "application/json",
    }
    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
    url = provider['base_url']
    url = BaseAPI(url).image_url

@@ -1088,7 +1087,7 @@ async def get_whisper_payload(request, engine, provider):
        # "Content-Type": "multipart/form-data",
    }
    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
    url = provider['base_url']
    url = BaseAPI(url).audio_transcriptions

@@ -1115,7 +1114,7 @@ async def get_moderation_payload(request, engine, provider):
        "Content-Type": "application/json",
    }
    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
    url = provider['base_url']
    url = BaseAPI(url).moderations

@@ -1132,7 +1131,7 @@ async def get_embedding_payload(request, engine, provider):
        "Content-Type": "application/json",
    }
    if provider.get("api"):
-        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next()}"
+        headers['Authorization'] = f"Bearer {await provider_api_circular_list[provider['provider']].next(model)}"
    url = provider['base_url']
    url = BaseAPI(url).embeddings
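Every payload builder in request.py now resolves the provider-side `model` before the auth header is built, so the key rotator can be told which model the key is being drawn for. A minimal sketch of that calling convention, using a hypothetical stub in place of `ThreadSafeCircularList`:

```python
import asyncio

class StubCircularList:
    """Hypothetical stand-in for ThreadSafeCircularList, recording which
    model each key was drawn for, as the payload builders now report."""
    def __init__(self, items):
        self.items, self.i, self.drawn_for = items, 0, []

    async def next(self, model: str = None):
        self.drawn_for.append(model)  # per-model limits would be checked here
        item = self.items[self.i % len(self.items)]
        self.i += 1
        return item

async def build_headers(keys: StubCircularList, model: str) -> dict:
    # model is resolved first, then passed into next(model), matching the diff
    return {"Authorization": f"Bearer {await keys.next(model)}"}

keys = StubCircularList(["sk-a", "sk-b"])
assert asyncio.run(build_headers(keys, "gpt-4o")) == {"Authorization": "Bearer sk-a"}
assert keys.drawn_for == ["gpt-4o"]
```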
utils.py
@@ -80,13 +80,21 @@ async def get_user_rate_limit(app, api_index: str = None):
import asyncio

class ThreadSafeCircularList:
-    def __init__(self, items, rate_limit="999999/min"):
+    def __init__(self, items, rate_limit={"default": "999999/min"}):
        self.items = items
        self.index = 0
        self.lock = asyncio.Lock()
-        self.requests = defaultdict(list)
+        # Changed to a two-level dict: the first level is the item, the second level is the model
+        self.requests = defaultdict(lambda: defaultdict(list))
        self.cooling_until = defaultdict(float)
-        self.rate_limits = parse_rate_limit(rate_limit)
+        self.rate_limits = {}
+        if isinstance(rate_limit, dict):
+            for rate_limit_model, rate_limit_value in rate_limit.items():
+                self.rate_limits[rate_limit_model] = parse_rate_limit(rate_limit_value)
+        elif isinstance(rate_limit, str):
+            self.rate_limits["default"] = parse_rate_limit(rate_limit)
+        else:
+            logger.error(f"Error ThreadSafeCircularList: Unknown rate_limit type: {type(rate_limit)}, rate_limit: {rate_limit}")

    async def set_cooling(self, item: str, cooling_time: int = 60):
        """Put an item into the cooldown state

@@ -102,36 +110,58 @@ class ThreadSafeCircularList:
        # self.requests[item] = []
        logger.warning(f"API key {item} 已进入冷却状态,冷却时间 {cooling_time} 秒")

-    async def is_rate_limited(self, item) -> bool:
+    async def is_rate_limited(self, item, model: str = None) -> bool:
        now = time()
        # Check whether the item is still cooling down
        if now < self.cooling_until[item]:
            return True

+        # Determine the applicable rate limit
+
+        if model:
+            model_key = model
+        else:
+            model_key = "default"
+
+        rate_limit = None
+        # Try an exact match first
+        if model and model in self.rate_limits:
+            rate_limit = self.rate_limits[model]
+        else:
+            # If there is no exact match, try a fuzzy match
+            for limit_model in self.rate_limits:
+                if limit_model != "default" and model and limit_model in model:
+                    rate_limit = self.rate_limits[limit_model]
+                    break
+
+        # If nothing matched, fall back to the default
+        if rate_limit is None:
+            rate_limit = self.rate_limits.get("default", [(999999, 60)])  # default limit
+
        # Check every rate limit constraint
-        for limit_count, limit_period in self.rate_limits:
-            #
-            recent_requests = sum(1 for req in self.requests[item] if req > now - limit_period)
+        for limit_count, limit_period in rate_limit:
+            # Count against the request history of this specific model
+            recent_requests = sum(1 for req in self.requests[item][model_key] if req > now - limit_period)
            if recent_requests >= limit_count:
-                logger.warning(f"API key {item} 已达到速率限制 ({limit_count}/{limit_period}秒)")
+                logger.warning(f"API key {item} 对模型 {model_key} 已达到速率限制 ({limit_count}/{limit_period}秒)")
                return True

-        #
-        max_period = max(period for _, period in self.rate_limits)
-        self.requests[item] = [req for req in self.requests[item] if req > now - max_period]
+        # Drop request records that are older than the longest window
+        max_period = max(period for _, period in rate_limit)
+        self.requests[item][model_key] = [req for req in self.requests[item][model_key] if req > now - max_period]

-        #
-        self.requests[item].append(now)
+        # Record the new request
+        self.requests[item][model_key].append(now)
        return False

-    async def next(self):
+    async def next(self, model: str = None):
        async with self.lock:
            start_index = self.index
            while True:
                item = self.items[self.index]
                self.index = (self.index + 1) % len(self.items)

-                if not await self.is_rate_limited(item):
+                if not await self.is_rate_limited(item, model):
                    return item

        # If all API keys have been checked and every one is rate limited

@@ -220,12 +250,12 @@ def update_config(config_data, use_config_url=False):
    if isinstance(provider_api, str):
        provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(
            [provider_api],
-            safe_get(provider, "preferences", "
+            safe_get(provider, "preferences", "api_key_rate_limit", default={"default": "999999/min"})
        )
    if isinstance(provider_api, list):
        provider_api_circular_list[provider['provider']] = ThreadSafeCircularList(
            provider_api,
-            safe_get(provider, "preferences", "
+            safe_get(provider, "preferences", "api_key_rate_limit", default={"default": "999999/min"})
        )

    if not provider.get("model"):
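Taken together, `ThreadSafeCircularList` now keeps one sliding window per (API key, model) pair and falls back to the `default` entry when a model has no limit of its own. A condensed, self-contained sketch of that bookkeeping (synchronous, with `parse_rate_limit`, fuzzy matching, and logging stripped out) behaves like this:

```python
from collections import defaultdict
from time import time

class PerModelLimiter:
    """Simplified sketch of the per-(key, model) sliding window in the diff."""
    def __init__(self, rate_limits):
        self.rate_limits = rate_limits  # e.g. {"gpt-4o": [(3, 60)], "default": [(4, 60)]}
        self.requests = defaultdict(lambda: defaultdict(list))  # key -> model -> timestamps

    def is_rate_limited(self, key, model=None):
        now = time()
        model_key = model or "default"
        limits = self.rate_limits.get(model_key, self.rate_limits.get("default", [(999999, 60)]))
        for count, period in limits:
            recent = sum(1 for t in self.requests[key][model_key] if t > now - period)
            if recent >= count:
                return True                        # over the limit, do not record
        self.requests[key][model_key].append(now)  # admit and record the request
        return False

limiter = PerModelLimiter({"gpt-4o": [(3, 60)], "default": [(4, 60)]})
# Three gpt-4o requests pass; the fourth inside the window is rejected:
assert [limiter.is_rate_limited("sk-a", "gpt-4o") for _ in range(4)] == [False, False, False, True]
# Another model falls back to the default 4/min limit, tracked separately:
assert [limiter.is_rate_limited("sk-a", "gpt-3.5") for _ in range(5)] == [False] * 4 + [True]
```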