wayydb-api / tests /python /test_pubsub.py
rcgalbo's picture
Deploy wayyDB to HuggingFace Spaces
bf20cb7
"""Tests for WayyDB PubSub abstraction layer."""
import asyncio
import time
import pytest
import sys
sys.path.insert(0, str(__file__).replace("/tests/python/test_pubsub.py", ""))
from api.pubsub import InMemoryPubSub, RedisPubSub, create_pubsub, Message
class TestInMemoryPubSub:
"""Tests for InMemoryPubSub backend."""
@pytest.fixture
def pubsub(self):
return InMemoryPubSub(max_buffer_per_channel=100)
@pytest.mark.asyncio
async def test_start_stop(self, pubsub):
await pubsub.start()
assert pubsub._running
await pubsub.stop()
assert not pubsub._running
@pytest.mark.asyncio
async def test_publish_returns_sequence(self, pubsub):
await pubsub.start()
seq1 = await pubsub.publish("ticks:AAPL", {"price": 150.0})
seq2 = await pubsub.publish("ticks:AAPL", {"price": 151.0})
assert seq1 == 1
assert seq2 == 2
await pubsub.stop()
@pytest.mark.asyncio
async def test_sequence_per_channel(self, pubsub):
await pubsub.start()
seq_a = await pubsub.publish("ticks:AAPL", {"price": 150.0})
seq_m = await pubsub.publish("ticks:MSFT", {"price": 380.0})
assert seq_a == 1
assert seq_m == 1 # Separate sequence per channel
await pubsub.stop()
@pytest.mark.asyncio
async def test_subscribe_receives_messages(self, pubsub):
await pubsub.start()
received = []
async def callback(msg):
received.append(msg)
await pubsub.subscribe("ticks:AAPL", callback, "test_sub")
await pubsub.publish("ticks:AAPL", {"price": 150.0})
await pubsub.publish("ticks:AAPL", {"price": 151.0})
assert len(received) == 2
assert received[0]["price"] == 150.0
assert received[1]["price"] == 151.0
# Messages include metadata
assert received[0]["_seq"] == 1
assert received[0]["_channel"] == "ticks:AAPL"
await pubsub.stop()
@pytest.mark.asyncio
async def test_subscribe_wildcard(self, pubsub):
await pubsub.start()
received = []
async def callback(msg):
received.append(msg)
await pubsub.subscribe("ticks:*", callback, "wildcard_sub")
await pubsub.publish("ticks:AAPL", {"symbol": "AAPL"})
await pubsub.publish("ticks:MSFT", {"symbol": "MSFT"})
await pubsub.publish("quotes:AAPL", {"symbol": "AAPL"}) # Not ticks:*
assert len(received) == 2
assert received[0]["symbol"] == "AAPL"
assert received[1]["symbol"] == "MSFT"
await pubsub.stop()
@pytest.mark.asyncio
async def test_unsubscribe(self, pubsub):
await pubsub.start()
received = []
async def callback(msg):
received.append(msg)
await pubsub.subscribe("ticks:AAPL", callback, "test_sub")
await pubsub.publish("ticks:AAPL", {"price": 150.0})
assert len(received) == 1
await pubsub.unsubscribe("ticks:AAPL", "test_sub")
await pubsub.publish("ticks:AAPL", {"price": 151.0})
assert len(received) == 1 # No new message
await pubsub.stop()
@pytest.mark.asyncio
async def test_multiple_subscribers(self, pubsub):
await pubsub.start()
received_a = []
received_b = []
async def cb_a(msg):
received_a.append(msg)
async def cb_b(msg):
received_b.append(msg)
await pubsub.subscribe("ticks:AAPL", cb_a, "sub_a")
await pubsub.subscribe("ticks:AAPL", cb_b, "sub_b")
await pubsub.publish("ticks:AAPL", {"price": 150.0})
assert len(received_a) == 1
assert len(received_b) == 1
await pubsub.stop()
@pytest.mark.asyncio
async def test_backpressure_buffer_overflow(self):
pubsub = InMemoryPubSub(max_buffer_per_channel=5)
await pubsub.start()
for i in range(10):
await pubsub.publish("ticks:AAPL", {"price": float(i)})
# Buffer should only keep last 5
buf = pubsub.get_channel_buffer("ticks:AAPL")
assert len(buf) == 5
assert buf[0].data["price"] == 5.0 # Oldest kept
assert buf[-1].data["price"] == 9.0 # Newest
stats = pubsub.get_stats()
assert stats["messages_dropped"] == 5
await pubsub.stop()
@pytest.mark.asyncio
async def test_channel_buffer_replay(self, pubsub):
await pubsub.start()
for i in range(5):
await pubsub.publish("ticks:AAPL", {"price": float(i)})
# Replay from sequence 0 (all messages)
all_msgs = pubsub.get_channel_buffer("ticks:AAPL", since_seq=0)
assert len(all_msgs) == 5
# Replay from sequence 3 (only seq 4 and 5)
recent = pubsub.get_channel_buffer("ticks:AAPL", since_seq=3)
assert len(recent) == 2
assert recent[0].sequence == 4
await pubsub.stop()
@pytest.mark.asyncio
async def test_dead_subscriber_removal(self, pubsub):
await pubsub.start()
async def bad_callback(msg):
raise ConnectionError("WebSocket closed")
received_good = []
async def good_callback(msg):
received_good.append(msg)
await pubsub.subscribe("ticks:AAPL", bad_callback, "bad_sub")
await pubsub.subscribe("ticks:AAPL", good_callback, "good_sub")
await pubsub.publish("ticks:AAPL", {"price": 150.0})
# Bad subscriber should be removed, good one still works
assert len(received_good) == 1
stats = pubsub.get_stats()
assert stats["active_subscriptions"] == 1
await pubsub.stop()
@pytest.mark.asyncio
async def test_publish_batch(self, pubsub):
await pubsub.start()
received = []
async def callback(msg):
received.append(msg)
await pubsub.subscribe("ticks:AAPL", callback, "batch_sub")
messages = [{"price": float(i)} for i in range(5)]
last_seq = await pubsub.publish_batch("ticks:AAPL", messages)
assert last_seq == 5
assert len(received) == 5
await pubsub.stop()
@pytest.mark.asyncio
async def test_stats(self, pubsub):
await pubsub.start()
async def noop(msg):
pass
await pubsub.subscribe("ticks:AAPL", noop, "s1")
await pubsub.subscribe("ticks:MSFT", noop, "s2")
await pubsub.publish("ticks:AAPL", {"price": 150.0})
await pubsub.publish("ticks:MSFT", {"price": 380.0})
stats = pubsub.get_stats()
assert stats["backend"] == "in_memory"
assert stats["messages_published"] == 2
assert stats["messages_delivered"] == 2
assert stats["active_subscriptions"] == 2
assert stats["channels"] == 2
await pubsub.stop()
class TestCreatePubSub:
"""Tests for the factory function."""
def test_creates_inmemory_by_default(self):
ps = create_pubsub(None)
assert isinstance(ps, InMemoryPubSub)
def test_creates_inmemory_for_empty_string(self):
ps = create_pubsub("")
assert isinstance(ps, InMemoryPubSub)
def test_creates_redis_with_url(self):
ps = create_pubsub("redis://localhost:6379")
assert isinstance(ps, RedisPubSub)
class TestPubSubPerformance:
"""Performance tests for InMemoryPubSub."""
@pytest.mark.asyncio
async def test_publish_throughput(self):
"""Test raw publish throughput without subscribers."""
pubsub = InMemoryPubSub(max_buffer_per_channel=100000)
await pubsub.start()
num_msgs = 50000
start = time.time()
for i in range(num_msgs):
await pubsub.publish(f"ticks:SYM-{i % 100}", {"price": float(i)})
elapsed = time.time() - start
rate = num_msgs / elapsed
print(f"\nPublish throughput (no subscribers): {rate:.0f} msgs/sec")
assert rate > 50000 # Should handle at least 50K msgs/sec
await pubsub.stop()
@pytest.mark.asyncio
async def test_publish_with_subscribers_throughput(self):
"""Test publish throughput with active subscribers."""
pubsub = InMemoryPubSub(max_buffer_per_channel=100000)
await pubsub.start()
counter = {"count": 0}
async def counting_callback(msg):
counter["count"] += 1
# Subscribe to 10 channels
for i in range(10):
await pubsub.subscribe(f"ticks:SYM-{i}", counting_callback, f"sub_{i}")
num_msgs = 10000
start = time.time()
for i in range(num_msgs):
await pubsub.publish(f"ticks:SYM-{i % 10}", {"price": float(i)})
elapsed = time.time() - start
rate = num_msgs / elapsed
print(f"\nPublish throughput (10 subscribers): {rate:.0f} msgs/sec")
print(f"Messages delivered: {counter['count']}")
assert rate > 5000
assert counter["count"] == num_msgs
await pubsub.stop()
@pytest.mark.asyncio
async def test_many_channels(self):
"""Test performance with many channels."""
pubsub = InMemoryPubSub(max_buffer_per_channel=1000)
await pubsub.start()
num_channels = 1000
msgs_per_channel = 10
start = time.time()
for ch in range(num_channels):
for m in range(msgs_per_channel):
await pubsub.publish(f"ticks:SYM-{ch}", {"price": float(m)})
elapsed = time.time() - start
total = num_channels * msgs_per_channel
rate = total / elapsed
print(f"\nMany channels throughput: {rate:.0f} msgs/sec ({num_channels} channels)")
stats = pubsub.get_stats()
assert stats["channels"] == num_channels
assert stats["messages_published"] == total
await pubsub.stop()
@pytest.mark.asyncio
async def test_wildcard_subscriber_performance(self):
"""Test wildcard subscriber doesn't destroy performance."""
pubsub = InMemoryPubSub(max_buffer_per_channel=100000)
await pubsub.start()
counter = {"count": 0}
async def callback(msg):
counter["count"] += 1
# Wildcard subscriber for all ticks
await pubsub.subscribe("ticks:*", callback, "wildcard")
num_msgs = 10000
start = time.time()
for i in range(num_msgs):
await pubsub.publish(f"ticks:SYM-{i % 50}", {"price": float(i)})
elapsed = time.time() - start
rate = num_msgs / elapsed
print(f"\nWildcard subscriber throughput: {rate:.0f} msgs/sec")
assert counter["count"] == num_msgs
assert rate > 3000
await pubsub.stop()
class TestPubSubStress:
"""Stress tests for PubSub under adverse conditions."""
@pytest.mark.asyncio
async def test_concurrent_publish_subscribe(self):
"""Test concurrent publishers and subscribers."""
pubsub = InMemoryPubSub(max_buffer_per_channel=10000)
await pubsub.start()
results = {"total_received": 0}
async def subscriber_cb(msg):
results["total_received"] += 1
# Register 5 subscribers on different channels
for i in range(5):
await pubsub.subscribe(f"ticks:SYM-{i}", subscriber_cb, f"sub_{i}")
# Concurrent publishers
async def publisher(channel_idx, count):
for j in range(count):
await pubsub.publish(f"ticks:SYM-{channel_idx}", {"price": float(j)})
tasks = [publisher(i, 200) for i in range(5)]
await asyncio.gather(*tasks)
assert results["total_received"] == 1000 # 5 channels x 200 msgs
await pubsub.stop()
@pytest.mark.asyncio
async def test_subscribe_unsubscribe_churn(self):
"""Test rapid subscribe/unsubscribe cycles."""
pubsub = InMemoryPubSub()
await pubsub.start()
async def noop(msg):
pass
for cycle in range(100):
sub_id = f"churn_{cycle}"
await pubsub.subscribe("ticks:AAPL", noop, sub_id)
await pubsub.publish("ticks:AAPL", {"cycle": cycle})
await pubsub.unsubscribe("ticks:AAPL", sub_id)
stats = pubsub.get_stats()
assert stats["active_subscriptions"] == 0
assert stats["messages_published"] == 100
await pubsub.stop()
@pytest.mark.asyncio
async def test_buffer_overflow_under_load(self):
"""Test backpressure behavior under high load."""
pubsub = InMemoryPubSub(max_buffer_per_channel=100)
await pubsub.start()
# Publish 1000 messages to a channel with buffer size 100
for i in range(1000):
await pubsub.publish("ticks:AAPL", {"price": float(i)})
stats = pubsub.get_stats()
buf = pubsub.get_channel_buffer("ticks:AAPL")
assert len(buf) == 100 # Buffer capped
assert stats["messages_published"] == 1000
assert stats["messages_dropped"] == 900
# Buffer should have the latest 100 messages
assert buf[-1].data["price"] == 999.0
await pubsub.stop()