webchat / common /token_bucket.py
hhz520's picture
Upload 170 files
61517de
raw
history blame contribute delete
No virus
1.45 kB
import threading
import time
class TokenBucket:
def __init__(self, tpm, timeout=None):
self.capacity = int(tpm) # 令牌桶容量
self.tokens = 0 # 初始令牌数为0
self.rate = int(tpm) / 60 # 令牌每秒生成速率
self.timeout = timeout # 等待令牌超时时间
self.cond = threading.Condition() # 条件变量
self.is_running = True
# 开启令牌生成线程
threading.Thread(target=self._generate_tokens).start()
def _generate_tokens(self):
"""生成令牌"""
while self.is_running:
with self.cond:
if self.tokens < self.capacity:
self.tokens += 1
self.cond.notify() # 通知获取令牌的线程
time.sleep(1 / self.rate)
def get_token(self):
"""获取令牌"""
with self.cond:
while self.tokens <= 0:
flag = self.cond.wait(self.timeout)
if not flag: # 超时
return False
self.tokens -= 1
return True
def close(self):
self.is_running = False
if __name__ == "__main__":
token_bucket = TokenBucket(20, None) # 创建一个每分钟生产20个tokens的令牌桶
# token_bucket = TokenBucket(20, 0.1)
for i in range(3):
if token_bucket.get_token():
print(f"第{i+1}次请求成功")
token_bucket.close()