yym68686 commited on
Commit
8583e53
·
1 Parent(s): 44cd6f6

🐛 Bug: Fix the bug where weight polling cannot match the model.

Browse files
Files changed (1) hide show
  1. main.py +13 -7
main.py CHANGED
@@ -759,7 +759,9 @@ class ModelRequestHandler:
759
  weights = safe_get(config, 'api_keys', api_index, "weights")
760
 
761
  # 步骤 1: 提取 matching_providers 中的所有 provider 值
762
- all_providers = set(provider['provider'] for provider in matching_providers)
 
 
763
 
764
  intersection = None
765
  if weights and all_providers:
@@ -768,21 +770,25 @@ class ModelRequestHandler:
768
  for model_rule in weight_keys:
769
  provider_rules.extend(get_provider_rules(model_rule, config, request_model))
770
  provider_list = get_provider_list(provider_rules, config, request_model)
771
- weight_keys = set([provider['provider'] for provider in provider_list])
772
  # print("all_providers", all_providers)
773
- # print("weights", weight_keys)
 
 
774
  # 步骤 3: 计算交集
775
  intersection = all_providers.intersection(weight_keys)
 
776
 
777
  if weights and intersection:
778
- weights = dict(filter(lambda item: item[0] in intersection, weights.items()))
 
779
 
780
  if scheduling_algorithm == "weighted_round_robin":
781
- weighted_provider_name_list = weighted_round_robin(weights)
782
  elif scheduling_algorithm == "lottery":
783
- weighted_provider_name_list = lottery_scheduling(weights)
784
  else:
785
- weighted_provider_name_list = list(weights.keys())
786
  # print("weighted_provider_name_list", weighted_provider_name_list)
787
 
788
  new_matching_providers = []
 
759
  weights = safe_get(config, 'api_keys', api_index, "weights")
760
 
761
  # 步骤 1: 提取 matching_providers 中的所有 provider 值
762
+ # print("matching_providers", matching_providers)
763
+ # print(type(matching_providers[0]['model'][0].keys()), list(matching_providers[0]['model'][0].keys())[0], matching_providers[0]['model'][0].keys())
764
+ all_providers = set(provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in matching_providers)
765
 
766
  intersection = None
767
  if weights and all_providers:
 
770
  for model_rule in weight_keys:
771
  provider_rules.extend(get_provider_rules(model_rule, config, request_model))
772
  provider_list = get_provider_list(provider_rules, config, request_model)
773
+ weight_keys = set([provider['provider'] + "/" + list(provider['model'][0].keys())[0] for provider in provider_list])
774
  # print("all_providers", all_providers)
775
+ # print("weights", weights)
776
+ # print("weight_keys", weight_keys)
777
+
778
  # 步骤 3: 计算交集
779
  intersection = all_providers.intersection(weight_keys)
780
+ # print("intersection", intersection)
781
 
782
  if weights and intersection:
783
+ filtered_weights = {k.split("/")[0]: v for k, v in weights.items() if k in intersection}
784
+ # print("filtered_weights", filtered_weights)
785
 
786
  if scheduling_algorithm == "weighted_round_robin":
787
+ weighted_provider_name_list = weighted_round_robin(filtered_weights)
788
  elif scheduling_algorithm == "lottery":
789
+ weighted_provider_name_list = lottery_scheduling(filtered_weights)
790
  else:
791
+ weighted_provider_name_list = list(filtered_weights.keys())
792
  # print("weighted_provider_name_list", weighted_provider_name_list)
793
 
794
  new_matching_providers = []