yym68686 committed on
Commit
eefdfa1
·
1 Parent(s): bc31591

🐛 Bug: Fixed the bug with limited concurrency and removed unnecessary asynchronous mutex locks.

Browse files

💰 Sponsors: Thanks to @PowerHunter for the ¥1000 sponsorship; the sponsorship information has been added to the README.

Files changed (3) hide show
  1. README.md +2 -2
  2. README_CN.md +2 -2
  3. main.py +33 -46
README.md CHANGED
@@ -315,8 +315,8 @@ curl -X POST http://127.0.0.1:8000/v1/chat/completions \
315
  ## Sponsors
316
 
317
  We thank the following sponsors for their support:
318
- <!-- ¥600 -->
319
- - @PowerHunter: ¥600
320
 
321
  ## How to sponsor us
322
 
 
315
  ## Sponsors
316
 
317
  We thank the following sponsors for their support:
318
+ <!-- ¥1000 -->
319
+ - @PowerHunter: ¥1000
320
 
321
  ## How to sponsor us
322
 
README_CN.md CHANGED
@@ -315,8 +315,8 @@ curl -X POST http://127.0.0.1:8000/v1/chat/completions \
315
  ## 赞助商
316
 
317
  我们感谢以下赞助商的支持:
318
- <!-- ¥600 -->
319
- - @PowerHunter:¥600
320
 
321
  ## 如何赞助我们
322
 
 
315
  ## 赞助商
316
 
317
  我们感谢以下赞助商的支持:
318
+ <!-- ¥1000 -->
319
+ - @PowerHunter:¥1000
320
 
321
  ## 如何赞助我们
322
 
main.py CHANGED
@@ -160,30 +160,24 @@ async def parse_request_body(request: Request):
160
  return None
161
 
162
  class ChannelManager:
163
- def __init__(self, cooldown_period: int = 300): # 默认冷却时间5分钟
164
- self._excluded_models: Dict[str, datetime] = {}
165
- self._lock = asyncio.Lock()
166
  self.cooldown_period = cooldown_period
167
 
168
  async def exclude_model(self, provider: str, model: str):
169
- """将特定渠道下的特定模型添加到排除列表"""
170
- async with self._lock:
171
- model_key = f"{provider}/{model}"
172
- self._excluded_models[model_key] = datetime.now()
173
 
174
  async def is_model_excluded(self, provider: str, model: str) -> bool:
175
- """检查特定渠道下的特定模型是否被排除"""
176
- async with self._lock:
177
- model_key = f"{provider}/{model}"
178
- if model_key not in self._excluded_models:
179
- return False
180
-
181
- excluded_time = self._excluded_models[model_key]
182
- if datetime.now() - excluded_time > timedelta(seconds=self.cooldown_period):
183
- # 已超过冷却时间,移除限制
184
- del self._excluded_models[model_key]
185
- return False
186
- return True
187
 
188
  async def get_available_providers(self, providers: list) -> list:
189
  """过滤出可用的providers,仅排除不可用的模型"""
@@ -541,39 +535,32 @@ class ClientManager:
541
  def __init__(self, pool_size=100):
542
  self.pool_size = pool_size
543
  self.clients = {} # {timeout_value: AsyncClient}
544
- self.locks = {} # {timeout_value: Lock}
545
 
546
  async def init(self, default_config):
547
  self.default_config = default_config
548
 
549
  @asynccontextmanager
550
  async def get_client(self, timeout_value):
551
- # 对同一超时值的客户端加锁
552
- if timeout_value not in self.locks:
553
- self.locks[timeout_value] = asyncio.Lock()
554
-
555
- async with self.locks[timeout_value]:
556
- # 获取或创建指定超时值的客户端
557
- if timeout_value not in self.clients:
558
- timeout = httpx.Timeout(
559
- connect=15.0,
560
- read=timeout_value,
561
- write=30.0,
562
- pool=self.pool_size
563
- )
564
- self.clients[timeout_value] = httpx.AsyncClient(
565
- timeout=timeout,
566
- limits=httpx.Limits(max_connections=self.pool_size),
567
- **self.default_config
568
- )
569
 
570
- try:
571
- yield self.clients[timeout_value]
572
- except Exception as e:
573
- # 如果客户端出现问题,关闭并重新创建
574
- await self.clients[timeout_value].aclose()
575
- del self.clients[timeout_value]
576
- raise e
577
 
578
  async def close(self):
579
  for client in self.clients.values():
@@ -791,7 +778,7 @@ def lottery_scheduling(weights):
791
  def get_provider_rules(model_rule, config, request_model):
792
  provider_rules = []
793
  if model_rule == "all":
794
- # 如果模型名为 all,则返回所有模型
795
  for provider in config["providers"]:
796
  model_dict = get_model_dict(provider)
797
  for model in model_dict.keys():
 
160
  return None
161
 
162
  class ChannelManager:
163
+ def __init__(self, cooldown_period=300):
164
+ self._excluded_models = defaultdict(lambda: None)
 
165
  self.cooldown_period = cooldown_period
166
 
167
  async def exclude_model(self, provider: str, model: str):
168
+ model_key = f"{provider}/{model}"
169
+ self._excluded_models[model_key] = datetime.now()
 
 
170
 
171
  async def is_model_excluded(self, provider: str, model: str) -> bool:
172
+ model_key = f"{provider}/{model}"
173
+ excluded_time = self._excluded_models[model_key]
174
+ if not excluded_time:
175
+ return False
176
+
177
+ if datetime.now() - excluded_time > timedelta(seconds=self.cooldown_period):
178
+ del self._excluded_models[model_key]
179
+ return False
180
+ return True
 
 
 
181
 
182
  async def get_available_providers(self, providers: list) -> list:
183
  """过滤出可用的providers,仅排除不可用的模型"""
 
535
  def __init__(self, pool_size=100):
536
  self.pool_size = pool_size
537
  self.clients = {} # {timeout_value: AsyncClient}
 
538
 
539
  async def init(self, default_config):
540
  self.default_config = default_config
541
 
542
  @asynccontextmanager
543
  async def get_client(self, timeout_value):
544
+ # 直接获取或创建客户端,不使用锁
545
+ if timeout_value not in self.clients:
546
+ timeout = httpx.Timeout(
547
+ connect=15.0,
548
+ read=timeout_value,
549
+ write=30.0,
550
+ pool=self.pool_size
551
+ )
552
+ self.clients[timeout_value] = httpx.AsyncClient(
553
+ timeout=timeout,
554
+ limits=httpx.Limits(max_connections=self.pool_size),
555
+ **self.default_config
556
+ )
 
 
 
 
 
557
 
558
+ try:
559
+ yield self.clients[timeout_value]
560
+ except Exception as e:
561
+ await self.clients[timeout_value].aclose()
562
+ del self.clients[timeout_value]
563
+ raise e
 
564
 
565
  async def close(self):
566
  for client in self.clients.values():
 
778
  def get_provider_rules(model_rule, config, request_model):
779
  provider_rules = []
780
  if model_rule == "all":
781
+ # 如果模型名为 all,则返回所有模型
782
  for provider in config["providers"]:
783
  model_dict = get_model_dict(provider)
784
  for model in model_dict.keys():