🐛 Bug: Fix the bug where weight polling cannot match the model.
Browse files
main.py
CHANGED
@@ -759,7 +759,9 @@ class ModelRequestHandler:
|
|
759 |
weights = safe_get(config, 'api_keys', api_index, "weights")
|
760 |
|
761 |
# 步骤 1: 提取 matching_providers 中的所有 provider 值
|
762 |
-
|
|
|
|
|
763 |
|
764 |
intersection = None
|
765 |
if weights and all_providers:
|
@@ -768,21 +770,25 @@ class ModelRequestHandler:
|
|
768 |
for model_rule in weight_keys:
|
769 |
provider_rules.extend(get_provider_rules(model_rule, config, request_model))
|
770 |
provider_list = get_provider_list(provider_rules, config, request_model)
|
771 |
-
weight_keys = set([provider['provider'] for provider in provider_list])
|
772 |
# print("all_providers", all_providers)
|
773 |
-
# print("weights",
|
|
|
|
|
774 |
# 步骤 3: 计算交集
|
775 |
intersection = all_providers.intersection(weight_keys)
|
|
|
776 |
|
777 |
if weights and intersection:
|
778 |
-
|
|
|
779 |
|
780 |
if scheduling_algorithm == "weighted_round_robin":
|
781 |
-
weighted_provider_name_list = weighted_round_robin(
|
782 |
elif scheduling_algorithm == "lottery":
|
783 |
-
weighted_provider_name_list = lottery_scheduling(
|
784 |
else:
|
785 |
-
weighted_provider_name_list = list(
|
786 |
# print("weighted_provider_name_list", weighted_provider_name_list)
|
787 |
|
788 |
new_matching_providers = []
|
|
|
759 |
weights = safe_get(config, 'api_keys', api_index, "weights")
|
760 |
|
761 |
# 步骤 1: 提取 matching_providers 中的所有 provider 值
|
762 |
+
# print("matching_providers", matching_providers)
|
763 |
+
# print(type(matching_providers[0]['model'][0].keys()), list(matching_providers[0]['model'][0].keys())[0], matching_providers[0]['model'][0].keys())
|
764 |
+
all_providers = set(provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in matching_providers)
|
765 |
|
766 |
intersection = None
|
767 |
if weights and all_providers:
|
|
|
770 |
for model_rule in weight_keys:
|
771 |
provider_rules.extend(get_provider_rules(model_rule, config, request_model))
|
772 |
provider_list = get_provider_list(provider_rules, config, request_model)
|
773 |
+
weight_keys = set([provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in provider_list])
|
774 |
# print("all_providers", all_providers)
|
775 |
+
# print("weights", weights)
|
776 |
+
# print("weight_keys", weight_keys)
|
777 |
+
|
778 |
# 步骤 3: 计算交集
|
779 |
intersection = all_providers.intersection(weight_keys)
|
780 |
+
# print("intersection", intersection)
|
781 |
|
782 |
if weights and intersection:
|
783 |
+
filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
|
784 |
+
# print("filtered_weights", filtered_weights)
|
785 |
|
786 |
if scheduling_algorithm == "weighted_round_robin":
|
787 |
+
weighted_provider_name_list = weighted_round_robin(filtered_weights)
|
788 |
elif scheduling_algorithm == "lottery":
|
789 |
+
weighted_provider_name_list = lottery_scheduling(filtered_weights)
|
790 |
else:
|
791 |
+
weighted_provider_name_list = list(filtered_weights.keys())
|
792 |
# print("weighted_provider_name_list", weighted_provider_name_list)
|
793 |
|
794 |
new_matching_providers = []
|