|
|
"""Tests for rate limiting functionality.""" |
|
|
|
|
|
import asyncio |
|
|
import time |
|
|
|
|
|
import pytest |
|
|
|
|
|
from src.tools.rate_limiter import RateLimiter, get_pubmed_limiter, reset_pubmed_limiter |
|
|
|
|
|
|
|
|
class TestRateLimiter: |
|
|
"""Test suite for rate limiter.""" |
|
|
|
|
|
def test_create_limiter_without_api_key(self) -> None: |
|
|
"""Should create 3/sec limiter without API key.""" |
|
|
limiter = RateLimiter(rate="3/second") |
|
|
assert limiter.rate == "3/second" |
|
|
|
|
|
def test_create_limiter_with_api_key(self) -> None: |
|
|
"""Should create 10/sec limiter with API key.""" |
|
|
limiter = RateLimiter(rate="10/second") |
|
|
assert limiter.rate == "10/second" |
|
|
|
|
|
@pytest.mark.asyncio |
|
|
async def test_limiter_allows_requests_under_limit(self) -> None: |
|
|
"""Should allow requests under the rate limit.""" |
|
|
limiter = RateLimiter(rate="10/second") |
|
|
|
|
|
|
|
|
for _ in range(3): |
|
|
allowed = await limiter.acquire() |
|
|
assert allowed is True |
|
|
|
|
|
@pytest.mark.asyncio |
|
|
async def test_limiter_blocks_when_exceeded(self) -> None: |
|
|
"""Should wait when rate limit exceeded.""" |
|
|
limiter = RateLimiter(rate="2/second") |
|
|
|
|
|
|
|
|
await limiter.acquire() |
|
|
await limiter.acquire() |
|
|
|
|
|
|
|
|
start = time.monotonic() |
|
|
await limiter.acquire() |
|
|
elapsed = time.monotonic() - start |
|
|
|
|
|
|
|
|
assert elapsed >= 0.3 |
|
|
|
|
|
@pytest.mark.asyncio |
|
|
async def test_limiter_resets_after_window(self) -> None: |
|
|
"""Rate limit should reset after time window.""" |
|
|
limiter = RateLimiter(rate="5/second") |
|
|
|
|
|
|
|
|
for _ in range(5): |
|
|
await limiter.acquire() |
|
|
|
|
|
|
|
|
await asyncio.sleep(1.1) |
|
|
|
|
|
|
|
|
start = time.monotonic() |
|
|
await limiter.acquire() |
|
|
elapsed = time.monotonic() - start |
|
|
|
|
|
assert elapsed < 0.1 |
|
|
|
|
|
|
|
|
class TestGetPubmedLimiter: |
|
|
"""Test PubMed-specific limiter factory.""" |
|
|
|
|
|
@pytest.fixture(autouse=True) |
|
|
def setup_teardown(self): |
|
|
"""Reset limiter before and after each test.""" |
|
|
reset_pubmed_limiter() |
|
|
yield |
|
|
reset_pubmed_limiter() |
|
|
|
|
|
def test_limiter_without_api_key(self) -> None: |
|
|
"""Should return 3/sec limiter without key.""" |
|
|
limiter = get_pubmed_limiter(api_key=None) |
|
|
assert "3" in limiter.rate |
|
|
|
|
|
def test_limiter_with_api_key(self) -> None: |
|
|
"""Should return 10/sec limiter with key.""" |
|
|
limiter = get_pubmed_limiter(api_key="my-api-key") |
|
|
assert "10" in limiter.rate |
|
|
|
|
|
def test_limiter_is_singleton(self) -> None: |
|
|
"""Same API key should return same limiter instance.""" |
|
|
limiter1 = get_pubmed_limiter(api_key="key1") |
|
|
limiter2 = get_pubmed_limiter(api_key="key1") |
|
|
assert limiter1 is limiter2 |
|
|
|
|
|
def test_different_keys_different_limiters(self) -> None: |
|
|
"""Different API keys should return different limiters.""" |
|
|
limiter1 = get_pubmed_limiter(api_key="key1") |
|
|
limiter2 = get_pubmed_limiter(api_key="key2") |
|
|
|
|
|
|
|
|
|
|
|
assert limiter1 is limiter2 |
|
|
|