|
import threading |
|
import time |
|
|
|
|
|
class TokenBucket:
    """Thread-safe token-bucket rate limiter refilled by a background thread.

    Tokens are added one at a time at a rate of ``tpm`` per minute, up to a
    maximum of ``capacity`` (== ``int(tpm)``) stored tokens. The bucket
    starts empty; the generator adds its first token as soon as it starts.

    Args:
        tpm: Tokens per minute. Truncated with ``int()``; must be >= 1.
        timeout: Seconds ``get_token`` waits for a token before giving up,
            or ``None`` to wait until a token arrives or the bucket is closed.

    Raises:
        ValueError: If ``int(tpm)`` is not positive (the refill interval
            ``60 / tpm`` would be undefined or negative).

    Can be used as a context manager; ``close()`` is called on exit.
    """

    def __init__(self, tpm, timeout=None):
        capacity = int(tpm)
        if capacity <= 0:
            raise ValueError(f"tpm must be a positive integer, got {tpm!r}")
        self.capacity = capacity
        self.tokens = 0
        self.rate = capacity / 60  # tokens per second
        self.timeout = timeout
        self.cond = threading.Condition()
        self.is_running = True
        # Used instead of time.sleep so close() wakes the generator immediately.
        self._stop_event = threading.Event()
        # Daemon thread: a forgotten close() must not keep the process alive.
        self._thread = threading.Thread(
            target=self._generate_tokens,
            name="TokenBucket-generator",
            daemon=True,
        )
        self._thread.start()

    def _generate_tokens(self):
        """Add one token per ``1 / rate`` seconds until the bucket is closed."""
        while self.is_running:
            with self.cond:
                if self.tokens < self.capacity:
                    self.tokens += 1
                    # Wake one consumer blocked in get_token().
                    self.cond.notify()
            # Returns early as soon as close() sets the event.
            self._stop_event.wait(1 / self.rate)

    def get_token(self):
        """Take one token, waiting up to ``self.timeout`` seconds for it.

        Returns:
            True if a token was consumed; False if the timeout expired, or
            the bucket was closed while empty.
        """
        with self.cond:
            # wait_for tracks the remaining timeout across wakeups, so
            # competing consumers cannot stretch the total wait.
            self.cond.wait_for(
                lambda: self.tokens > 0 or not self.is_running,
                timeout=self.timeout,
            )
            if self.tokens <= 0:
                return False
            self.tokens -= 1
            return True

    def close(self):
        """Stop the generator thread and release any blocked consumers.

        Tokens already in the bucket can still be taken after closing.
        """
        with self.cond:
            self.is_running = False
            # Blocked get_token() calls re-check their predicate and return.
            self.cond.notify_all()
        self._stop_event.set()
        self._thread.join()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
|
|
|
|
|
if __name__ == "__main__":
    # Demo: 20 tokens per minute means one new token roughly every 3 s,
    # so the three requests below complete over several seconds.
    token_bucket = TokenBucket(20, None)
    try:
        for i in range(3):
            if token_bucket.get_token():
                print(f"第{i+1}次请求成功")
    finally:
        # Stop the background token generator even if a request raises or
        # the user interrupts the (blocking) wait with Ctrl-C.
        token_bucket.close()
|
|