Spaces:
Runtime error
Runtime error
| """Tests for WayyDB PubSub abstraction layer.""" | |
import asyncio
import sys
import time
from pathlib import Path

import pytest

# Make the repository root importable so that ``api.pubsub`` resolves when the
# tests are run from a checkout. This file is expected to live at
# <root>/tests/python/test_pubsub.py, so the root is two levels above the
# directory containing it. Resolving through pathlib (instead of string
# replacement on ``__file__``) also works when ``__file__`` is relative or uses
# platform-specific path separators.
_REPO_ROOT = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(_REPO_ROOT))

from api.pubsub import InMemoryPubSub, RedisPubSub, create_pubsub, Message  # noqa: E402
# The class-level marker lets pytest-asyncio run every coroutine test below even
# in its default "strict" mode; it is redundant but harmless in "auto" mode.
@pytest.mark.asyncio
class TestInMemoryPubSub:
    """Tests for InMemoryPubSub backend.

    Most tests receive a fresh backend from the ``pubsub`` fixture and are
    responsible for calling ``start()`` / ``stop()`` themselves.
    """

    @pytest.fixture
    def pubsub(self):
        """Return a new, not-yet-started backend limited to 100 buffered messages per channel."""
        return InMemoryPubSub(max_buffer_per_channel=100)

    async def test_start_stop(self, pubsub):
        """``start()`` sets the running flag and ``stop()`` clears it."""
        await pubsub.start()
        assert pubsub._running

        await pubsub.stop()
        assert not pubsub._running

    async def test_publish_returns_sequence(self, pubsub):
        """Consecutive publishes on one channel return sequence numbers 1, 2, ..."""
        await pubsub.start()

        seq1 = await pubsub.publish("ticks:AAPL", {"price": 150.0})
        seq2 = await pubsub.publish("ticks:AAPL", {"price": 151.0})

        assert seq1 == 1
        assert seq2 == 2
        await pubsub.stop()

    async def test_sequence_per_channel(self, pubsub):
        """Each channel keeps its own sequence counter."""
        await pubsub.start()

        seq_a = await pubsub.publish("ticks:AAPL", {"price": 150.0})
        seq_m = await pubsub.publish("ticks:MSFT", {"price": 380.0})

        assert seq_a == 1
        assert seq_m == 1  # Separate sequence per channel
        await pubsub.stop()

    async def test_subscribe_receives_messages(self, pubsub):
        """A subscriber gets every message published on its channel, in order."""
        await pubsub.start()
        received = []

        async def callback(msg):
            received.append(msg)

        await pubsub.subscribe("ticks:AAPL", callback, "test_sub")
        await pubsub.publish("ticks:AAPL", {"price": 150.0})
        await pubsub.publish("ticks:AAPL", {"price": 151.0})

        assert len(received) == 2
        assert received[0]["price"] == 150.0
        assert received[1]["price"] == 151.0
        # Messages include metadata
        assert received[0]["_seq"] == 1
        assert received[0]["_channel"] == "ticks:AAPL"
        await pubsub.stop()

    async def test_subscribe_wildcard(self, pubsub):
        """A ``ticks:*`` subscription matches ``ticks:`` channels only."""
        await pubsub.start()
        received = []

        async def callback(msg):
            received.append(msg)

        await pubsub.subscribe("ticks:*", callback, "wildcard_sub")
        await pubsub.publish("ticks:AAPL", {"symbol": "AAPL"})
        await pubsub.publish("ticks:MSFT", {"symbol": "MSFT"})
        await pubsub.publish("quotes:AAPL", {"symbol": "AAPL"})  # Not ticks:*

        assert len(received) == 2
        assert received[0]["symbol"] == "AAPL"
        assert received[1]["symbol"] == "MSFT"
        await pubsub.stop()

    async def test_unsubscribe(self, pubsub):
        """After ``unsubscribe`` the callback no longer receives messages."""
        await pubsub.start()
        received = []

        async def callback(msg):
            received.append(msg)

        await pubsub.subscribe("ticks:AAPL", callback, "test_sub")
        await pubsub.publish("ticks:AAPL", {"price": 150.0})
        assert len(received) == 1

        await pubsub.unsubscribe("ticks:AAPL", "test_sub")
        await pubsub.publish("ticks:AAPL", {"price": 151.0})
        assert len(received) == 1  # No new message
        await pubsub.stop()

    async def test_multiple_subscribers(self, pubsub):
        """Two subscribers on the same channel each receive the message."""
        await pubsub.start()
        received_a = []
        received_b = []

        async def cb_a(msg):
            received_a.append(msg)

        async def cb_b(msg):
            received_b.append(msg)

        await pubsub.subscribe("ticks:AAPL", cb_a, "sub_a")
        await pubsub.subscribe("ticks:AAPL", cb_b, "sub_b")
        await pubsub.publish("ticks:AAPL", {"price": 150.0})

        assert len(received_a) == 1
        assert len(received_b) == 1
        await pubsub.stop()

    async def test_backpressure_buffer_overflow(self):
        """Publishing past the buffer limit keeps the newest messages and counts drops."""
        pubsub = InMemoryPubSub(max_buffer_per_channel=5)
        await pubsub.start()

        for i in range(10):
            await pubsub.publish("ticks:AAPL", {"price": float(i)})

        # Buffer should only keep last 5
        buf = pubsub.get_channel_buffer("ticks:AAPL")
        assert len(buf) == 5
        assert buf[0].data["price"] == 5.0  # Oldest kept
        assert buf[-1].data["price"] == 9.0  # Newest

        stats = pubsub.get_stats()
        assert stats["messages_dropped"] == 5
        await pubsub.stop()

    async def test_channel_buffer_replay(self, pubsub):
        """``get_channel_buffer(since_seq=N)`` returns only messages with sequence > N."""
        await pubsub.start()

        for i in range(5):
            await pubsub.publish("ticks:AAPL", {"price": float(i)})

        # Replay from sequence 0 (all messages)
        all_msgs = pubsub.get_channel_buffer("ticks:AAPL", since_seq=0)
        assert len(all_msgs) == 5

        # Replay from sequence 3 (only seq 4 and 5)
        recent = pubsub.get_channel_buffer("ticks:AAPL", since_seq=3)
        assert len(recent) == 2
        assert recent[0].sequence == 4
        await pubsub.stop()

    async def test_dead_subscriber_removal(self, pubsub):
        """A subscriber whose callback raises is dropped; healthy ones keep working."""
        await pubsub.start()

        async def bad_callback(msg):
            raise ConnectionError("WebSocket closed")

        received_good = []

        async def good_callback(msg):
            received_good.append(msg)

        await pubsub.subscribe("ticks:AAPL", bad_callback, "bad_sub")
        await pubsub.subscribe("ticks:AAPL", good_callback, "good_sub")
        await pubsub.publish("ticks:AAPL", {"price": 150.0})

        # Bad subscriber should be removed, good one still works
        assert len(received_good) == 1
        stats = pubsub.get_stats()
        assert stats["active_subscriptions"] == 1
        await pubsub.stop()

    async def test_publish_batch(self, pubsub):
        """``publish_batch`` delivers every message and returns the last sequence number."""
        await pubsub.start()
        received = []

        async def callback(msg):
            received.append(msg)

        await pubsub.subscribe("ticks:AAPL", callback, "batch_sub")
        messages = [{"price": float(i)} for i in range(5)]
        last_seq = await pubsub.publish_batch("ticks:AAPL", messages)

        assert last_seq == 5
        assert len(received) == 5
        await pubsub.stop()

    async def test_stats(self, pubsub):
        """``get_stats`` reports backend name plus publish/delivery/subscription/channel counts."""
        await pubsub.start()

        async def noop(msg):
            pass

        await pubsub.subscribe("ticks:AAPL", noop, "s1")
        await pubsub.subscribe("ticks:MSFT", noop, "s2")
        await pubsub.publish("ticks:AAPL", {"price": 150.0})
        await pubsub.publish("ticks:MSFT", {"price": 380.0})

        stats = pubsub.get_stats()
        assert stats["backend"] == "in_memory"
        assert stats["messages_published"] == 2
        assert stats["messages_delivered"] == 2
        assert stats["active_subscriptions"] == 2
        assert stats["channels"] == 2
        await pubsub.stop()
class TestCreatePubSub:
    """Checks which backend class ``create_pubsub`` returns for a given URL argument."""

    def test_creates_inmemory_by_default(self):
        """``None`` selects the in-memory backend."""
        backend = create_pubsub(None)
        assert isinstance(backend, InMemoryPubSub)

    def test_creates_inmemory_for_empty_string(self):
        """An empty URL string also selects the in-memory backend."""
        backend = create_pubsub("")
        assert isinstance(backend, InMemoryPubSub)

    def test_creates_redis_with_url(self):
        """A ``redis://`` URL selects the Redis backend."""
        backend = create_pubsub("redis://localhost:6379")
        assert isinstance(backend, RedisPubSub)
# The class-level marker lets pytest-asyncio run every coroutine test below even
# in its default "strict" mode; it is redundant but harmless in "auto" mode.
@pytest.mark.asyncio
class TestPubSubPerformance:
    """Performance tests for InMemoryPubSub.

    Intervals are measured with ``time.perf_counter()``, a monotonic
    high-resolution clock, so the computed rates are not skewed by wall-clock
    adjustments. The throughput floors asserted here depend on the machine
    running the tests.
    """

    async def test_publish_throughput(self):
        """Test raw publish throughput without subscribers."""
        pubsub = InMemoryPubSub(max_buffer_per_channel=100000)
        await pubsub.start()

        num_msgs = 50000
        started = time.perf_counter()
        for i in range(num_msgs):
            await pubsub.publish(f"ticks:SYM-{i % 100}", {"price": float(i)})
        elapsed = time.perf_counter() - started

        rate = num_msgs / elapsed
        print(f"\nPublish throughput (no subscribers): {rate:.0f} msgs/sec")
        assert rate > 50000  # Should handle at least 50K msgs/sec
        await pubsub.stop()

    async def test_publish_with_subscribers_throughput(self):
        """Test publish throughput with active subscribers."""
        pubsub = InMemoryPubSub(max_buffer_per_channel=100000)
        await pubsub.start()
        counter = {"count": 0}

        async def counting_callback(msg):
            counter["count"] += 1

        # Subscribe to 10 channels
        for i in range(10):
            await pubsub.subscribe(f"ticks:SYM-{i}", counting_callback, f"sub_{i}")

        num_msgs = 10000
        started = time.perf_counter()
        for i in range(num_msgs):
            await pubsub.publish(f"ticks:SYM-{i % 10}", {"price": float(i)})
        elapsed = time.perf_counter() - started

        rate = num_msgs / elapsed
        print(f"\nPublish throughput (10 subscribers): {rate:.0f} msgs/sec")
        print(f"Messages delivered: {counter['count']}")
        assert rate > 5000
        assert counter["count"] == num_msgs
        await pubsub.stop()

    async def test_many_channels(self):
        """Test performance with many channels."""
        pubsub = InMemoryPubSub(max_buffer_per_channel=1000)
        await pubsub.start()

        num_channels = 1000
        msgs_per_channel = 10
        started = time.perf_counter()
        for ch in range(num_channels):
            for m in range(msgs_per_channel):
                await pubsub.publish(f"ticks:SYM-{ch}", {"price": float(m)})
        elapsed = time.perf_counter() - started

        total = num_channels * msgs_per_channel
        rate = total / elapsed
        print(f"\nMany channels throughput: {rate:.0f} msgs/sec ({num_channels} channels)")

        # Only bookkeeping is asserted here; the rate is reported, not enforced.
        stats = pubsub.get_stats()
        assert stats["channels"] == num_channels
        assert stats["messages_published"] == total
        await pubsub.stop()

    async def test_wildcard_subscriber_performance(self):
        """Test wildcard subscriber doesn't destroy performance."""
        pubsub = InMemoryPubSub(max_buffer_per_channel=100000)
        await pubsub.start()
        counter = {"count": 0}

        async def callback(msg):
            counter["count"] += 1

        # Wildcard subscriber for all ticks
        await pubsub.subscribe("ticks:*", callback, "wildcard")

        num_msgs = 10000
        started = time.perf_counter()
        for i in range(num_msgs):
            await pubsub.publish(f"ticks:SYM-{i % 50}", {"price": float(i)})
        elapsed = time.perf_counter() - started

        rate = num_msgs / elapsed
        print(f"\nWildcard subscriber throughput: {rate:.0f} msgs/sec")
        assert counter["count"] == num_msgs
        assert rate > 3000
        await pubsub.stop()
# The class-level marker lets pytest-asyncio run every coroutine test below even
# in its default "strict" mode; it is redundant but harmless in "auto" mode.
@pytest.mark.asyncio
class TestPubSubStress:
    """Stress tests for PubSub under adverse conditions."""

    async def test_concurrent_publish_subscribe(self):
        """Test concurrent publishers and subscribers."""
        pubsub = InMemoryPubSub(max_buffer_per_channel=10000)
        await pubsub.start()
        results = {"total_received": 0}

        async def subscriber_cb(msg):
            results["total_received"] += 1

        # Register 5 subscribers on different channels
        for i in range(5):
            await pubsub.subscribe(f"ticks:SYM-{i}", subscriber_cb, f"sub_{i}")

        # Concurrent publishers: one coroutine per channel, run together.
        async def publisher(channel_idx, count):
            for j in range(count):
                await pubsub.publish(f"ticks:SYM-{channel_idx}", {"price": float(j)})

        await asyncio.gather(*(publisher(i, 200) for i in range(5)))

        assert results["total_received"] == 1000  # 5 channels x 200 msgs
        await pubsub.stop()

    async def test_subscribe_unsubscribe_churn(self):
        """Test rapid subscribe/unsubscribe cycles."""
        pubsub = InMemoryPubSub()
        await pubsub.start()

        async def noop(msg):
            pass

        for cycle in range(100):
            sub_id = f"churn_{cycle}"
            await pubsub.subscribe("ticks:AAPL", noop, sub_id)
            await pubsub.publish("ticks:AAPL", {"cycle": cycle})
            await pubsub.unsubscribe("ticks:AAPL", sub_id)

        # Every cycle unsubscribed again, so nothing should remain registered.
        stats = pubsub.get_stats()
        assert stats["active_subscriptions"] == 0
        assert stats["messages_published"] == 100
        await pubsub.stop()

    async def test_buffer_overflow_under_load(self):
        """Test backpressure behavior under high load."""
        pubsub = InMemoryPubSub(max_buffer_per_channel=100)
        await pubsub.start()

        # Publish 1000 messages to a channel with buffer size 100
        for i in range(1000):
            await pubsub.publish("ticks:AAPL", {"price": float(i)})

        stats = pubsub.get_stats()
        buf = pubsub.get_channel_buffer("ticks:AAPL")
        assert len(buf) == 100  # Buffer capped
        assert stats["messages_published"] == 1000
        assert stats["messages_dropped"] == 900
        # Buffer should have the latest 100 messages
        assert buf[-1].data["price"] == 999.0
        await pubsub.stop()