diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..9a6c4f3d73b2081824b43d205949dbac20064621
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,32 @@
+# Python
+__pycache__/
+*.py[cod]
+*.so
+.venv/
+venv/
+.env
+*.egg-info/
+dist/
+build/
+.pytest_cache/
+.mypy_cache/
+.ruff_cache/
+htmlcov/
+.coverage
+
+# Node
+node_modules/
+.next/
+dist/
+.env.local
+
+# Jupyter
+.ipynb_checkpoints/
+
+# OS
+.DS_Store
+*.swp
+
+# Playwright
+.playwright-mcp/
+
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..269b8c7bc6e623bb9b4674fefaa171c603ad00e2
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,123 @@
+cmake_minimum_required(VERSION 3.20)
+project(wayy_db VERSION 0.1.0 LANGUAGES CXX)
+
+set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+# Options
+option(WAYY_BUILD_PYTHON "Build Python bindings" ON)
+option(WAYY_BUILD_TESTS "Build unit tests" ON)
+option(WAYY_BUILD_BENCHMARKS "Build benchmarks" OFF)
+option(WAYY_USE_AVX2 "Enable AVX2 SIMD optimizations" ON)
+option(WAYY_USE_LZ4 "Enable LZ4 compression" OFF)
+
+# Compiler flags
+if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
+ add_compile_options(-Wall -Wextra -Wpedantic)
+ if(WAYY_USE_AVX2)
+ add_compile_options(-mavx2 -mfma)
+ endif()
+endif()
+
+# Core library
+add_library(wayy_core STATIC
+ src/types.cpp
+ src/column.cpp
+ src/string_column.cpp
+ src/hash_index.cpp
+ src/table.cpp
+ src/database.cpp
+ src/mmap_file.cpp
+ src/wal.cpp
+ src/ops/aggregations.cpp
+ src/ops/joins.cpp
+ src/ops/window.cpp
+)
+
+target_include_directories(wayy_core PUBLIC
+    $
+
+  WayyDB
+
+  High-performance columnar time-series database for quantitative finance
+
+  kdb+ functionality • Pythonic API • Zero-copy NumPy • SIMD-accelerated
+
+  Built with C++20 and Python by Wayy Research
diff --git a/api/__pycache__/main.cpython-310.pyc b/api/__pycache__/main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..667d1f8244c5c666dd855f3ce50659c2da7a789f Binary files /dev/null and b/api/__pycache__/main.cpython-310.pyc differ diff --git a/api/__pycache__/pubsub.cpython-310.pyc b/api/__pycache__/pubsub.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6106f46805d0b3270d758395cd655fbf2c9c50de Binary files /dev/null and b/api/__pycache__/pubsub.cpython-310.pyc differ diff --git a/api/__pycache__/streaming.cpython-310.pyc b/api/__pycache__/streaming.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..160cb3bb0f261a865483c8129147ec528a76dc77 Binary files /dev/null and b/api/__pycache__/streaming.cpython-310.pyc differ diff --git a/api/kvstore.py b/api/kvstore.py new file mode 100644 index 0000000000000000000000000000000000000000..8feb4f721fc12b0246ff04b16af35c32746ace2c --- /dev/null +++ b/api/kvstore.py @@ -0,0 +1,150 @@ +""" +KV Store - In-memory key-value store with TTL for wayyDB. + +Provides Redis-like KV semantics for future multi-process scaling. +Background eviction runs every 60 seconds. +""" + +import asyncio +import logging +import time +from fnmatch import fnmatch +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class KVEntry: + """A stored value with optional TTL.""" + __slots__ = ("value", "expires_at", "created_at") + + def __init__(self, value: Any, ttl: Optional[float] = None): + now = time.time() + self.value = value + self.expires_at = now + ttl if ttl else float("inf") + self.created_at = now + + @property + def is_expired(self) -> bool: + return time.time() > self.expires_at + + @property + def ttl_remaining(self) -> Optional[float]: + if self.expires_at == float("inf"): + return None + remaining = self.expires_at - time.time() + return max(0, remaining) + + +class KVStore: + """ + In-memory KV store with TTL and background eviction. + + Thread-safe for single-process async use (GIL + event loop). + """ + + def __init__(self) -> None: + self._data: Dict[str, KVEntry] = {} + self._eviction_task: Optional[asyncio.Task] = None + self._sets: int = 0 + self._gets: int = 0 + self._deletes: int = 0 + self._evictions: int = 0 + + async def start(self) -> None: + """Start the background eviction task.""" + if self._eviction_task is None: + self._eviction_task = asyncio.create_task(self._eviction_loop()) + logger.info("KVStore eviction task started") + + async def stop(self) -> None: + """Stop the background eviction task.""" + if self._eviction_task: + self._eviction_task.cancel() + try: + await self._eviction_task + except asyncio.CancelledError: + pass + self._eviction_task = None + + def set(self, key: str, value: Any, ttl: Optional[float] = None) -> None: + """Set a key with optional TTL (seconds).""" + self._data[key] = KVEntry(value, ttl) + self._sets += 1 + + def get(self, key: str) -> Optional[Any]: + """Get a value by key. Returns None if missing or expired.""" + self._gets += 1 + entry = self._data.get(key) + if entry is None: + return None + if entry.is_expired: + del self._data[key] + self._evictions += 1 + return None + return entry.value + + def delete(self, key: str) -> bool: + """Delete a key. 
Returns True if existed.""" + existed = key in self._data + if existed: + del self._data[key] + self._deletes += 1 + return existed + + def keys(self, pattern: Optional[str] = None) -> List[str]: + """List keys, optionally filtered by glob pattern.""" + now = time.time() + result = [] + for k, v in self._data.items(): + if v.expires_at > now: + if pattern is None or fnmatch(k, pattern): + result.append(k) + return result + + def stats(self) -> Dict[str, Any]: + """Get store statistics.""" + now = time.time() + active = sum(1 for v in self._data.values() if v.expires_at > now) + return { + "total_keys": len(self._data), + "active_keys": active, + "sets": self._sets, + "gets": self._gets, + "deletes": self._deletes, + "evictions": self._evictions, + } + + async def _eviction_loop(self) -> None: + """Background loop to evict expired entries every 60s.""" + while True: + try: + await asyncio.sleep(60) + count = self._evict_expired() + if count > 0: + logger.debug(f"KVStore evicted {count} expired entries") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"KVStore eviction error: {e}") + + def _evict_expired(self) -> int: + """Evict all expired entries. Returns count evicted.""" + now = time.time() + expired = [k for k, v in self._data.items() if now > v.expires_at] + for k in expired: + del self._data[k] + self._evictions += len(expired) + return len(expired) + + +# Global singleton +_kv_store: Optional[KVStore] = None + + +def get_kv_store() -> KVStore: + """Get the global KV store instance.""" + global _kv_store + if _kv_store is None: + _kv_store = KVStore() + return _kv_store diff --git a/api/main.py b/api/main.py new file mode 100644 index 0000000000000000000000000000000000000000..0160d1d3129ea19de6116461304cda5467796c58 --- /dev/null +++ b/api/main.py @@ -0,0 +1,1031 @@ +""" +WayyDB REST API - High-performance columnar time-series database service + +Features: +- REST API for table operations, aggregations, joins, window functions +- WebSocket streaming ingestion for real-time tick data +- WebSocket pub/sub for streaming updates to clients +- Efficient batching and append operations +""" +import os +import re +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager +from typing import Any, Optional, List + +import numpy as np +from fastapi import FastAPI, HTTPException, Query, Request, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, ValidationError + +# Import wayyDB +import wayy_db as wdb + +# Import streaming module +from api.streaming import ( + get_streaming_manager, + start_streaming, + stop_streaming, + StreamingManager, +) + +# Import KV store +from api.kvstore import get_kv_store + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Thread pool for running CPU-bound wayyDB operations +executor = ThreadPoolExecutor(max_workers=4) + +# Global database instance +db: Optional[wdb.Database] = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Initialize database and streaming on startup.""" + global db + data_path = os.environ.get("WAYY_DATA_PATH", "/data/wayydb") + os.makedirs(data_path, exist_ok=True) + db = wdb.Database(data_path) + + # Initialize streaming manager with database reference + streaming = get_streaming_manager() + streaming.set_database(db) + await start_streaming() + + # Start KV store eviction + kv = get_kv_store() + await 
kv.start() + + logger.info(f"WayyDB started with data path: {data_path}") + + yield + + # Cleanup + await kv.stop() + await stop_streaming() + if db: + db.save() + logger.info("WayyDB shutdown complete") + + +app = FastAPI( + title="WayyDB API", + description="High-performance columnar time-series database with kdb+-like functionality", + version="0.1.0", + lifespan=lifespan, +) + +# CORS - configurable via CORS_ORIGINS env var +ALLOWED_ORIGINS = os.getenv("CORS_ORIGINS", "http://localhost:3000").split(",") + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE"], + allow_headers=["Content-Type", "Authorization"], +) + + +# --- Pydantic Models --- + +class TableCreate(BaseModel): + name: str + sorted_by: Optional[str] = None + + +class ColumnData(BaseModel): + name: str + dtype: str # "int64", "float64", "timestamp", "symbol", "bool" + data: list + + +class TableData(BaseModel): + name: str + columns: list[ColumnData] + sorted_by: Optional[str] = None + + +class AggregationResult(BaseModel): + column: str + operation: str + result: float + + +class JoinRequest(BaseModel): + left_table: str + right_table: str + on: list[str] + as_of: str + window_before: Optional[int] = None # For window join + window_after: Optional[int] = None + + +class WindowRequest(BaseModel): + table: str + column: str + operation: str # mavg, msum, mstd, mmin, mmax, ema + window: Optional[int] = None + alpha: Optional[float] = None # For EMA + + +class AppendData(BaseModel): + """Data to append to an existing table.""" + columns: list[ColumnData] + + +class RowData(BaseModel): + """A single row as key-value pairs.""" + data: dict[str, Any] + + +class TableCreateOLTP(BaseModel): + """Create a table with OLTP schema definition.""" + name: str + columns: list[dict] # [{"name": "id", "dtype": "string"}, ...] + primary_key: Optional[str] = None + sorted_by: Optional[str] = None + + +class IngestTick(BaseModel): + """A single tick for streaming ingestion.""" + symbol: str + price: float + timestamp: Optional[int] = None # Nanoseconds since epoch + volume: Optional[float] = 0.0 + bid: Optional[float] = None + ask: Optional[float] = None + + +class IngestBatch(BaseModel): + """Batch of ticks for streaming ingestion.""" + ticks: list[IngestTick] + + +class SubscribeRequest(BaseModel): + """Subscription filter for WebSocket.""" + symbols: Optional[list[str]] = None # None = all symbols + + +# --- Helper Functions --- + +def dtype_from_string(s: str) -> wdb.DType: + mapping = { + "int64": wdb.DType.Int64, + "float64": wdb.DType.Float64, + "timestamp": wdb.DType.Timestamp, + "symbol": wdb.DType.Symbol, + "bool": wdb.DType.Bool, + } + # These types exist in C++ headers but aren't yet exposed in pybind11 bindings + # "string": _DTYPE_STRING, + # "decimal6": wdb.DType.Decimal6, + if s.lower() not in mapping: + raise ValueError(f"Unknown dtype: {s}. 
Available: {list(mapping.keys())}") + return mapping[s.lower()] + + +# String DType not yet in pybind11 bindings — use sentinel for safe comparisons +_DTYPE_STRING = getattr(wdb.DType, "String", None) + + +TABLE_NAME_RE = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]{0,63}$') + + +def validate_table_name(name: str) -> str: + if not TABLE_NAME_RE.match(name): + raise HTTPException(400, f"Invalid table name: {name}") + return name + + +def numpy_dtype_for(dtype: wdb.DType): + mapping = { + wdb.DType.Int64: np.int64, + wdb.DType.Float64: np.float64, + wdb.DType.Timestamp: np.int64, + wdb.DType.Symbol: np.uint32, + wdb.DType.Bool: np.uint8, + } + return mapping[dtype] + + +async def run_in_executor(func, *args): + """Run blocking wayyDB operations in thread pool.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, func, *args) + + +# --- Routes --- + +@app.get("/") +async def root(): + return { + "service": "WayyDB API", + "version": "0.1.0", + "status": "healthy", + } + + +@app.get("/health") +async def health(): + return {"status": "healthy", "tables": len(db.tables()) if db else 0} + + +# --- Table Operations --- + +@app.get("/tables") +async def list_tables(): + """List all tables in the database.""" + return {"tables": db.tables()} + + +@app.post("/tables") +async def create_table(table: TableCreate): + """Create a new empty table.""" + if db.has_table(table.name): + raise HTTPException(400, f"Table '{table.name}' already exists") + + t = db.create_table(table.name) + if table.sorted_by: + t.set_sorted_by(table.sorted_by) + db.save() + return {"created": table.name} + + +@app.post("/tables/upload") +async def upload_table(table_data: TableData): + """Upload a complete table with data.""" + if db.has_table(table_data.name): + raise HTTPException(400, f"Table '{table_data.name}' already exists") + + t = wdb.Table(table_data.name) + + for col in table_data.columns: + dtype = dtype_from_string(col.dtype) + np_dtype = numpy_dtype_for(dtype) + arr = np.array(col.data, dtype=np_dtype) + t.add_column_from_numpy(col.name, arr, dtype) + + if table_data.sorted_by: + t.set_sorted_by(table_data.sorted_by) + + db.add_table(t) + db.save() + + return { + "created": table_data.name, + "rows": t.num_rows, + "columns": t.column_names(), + } + + +@app.get("/tables/{name}") +async def get_table_info(name: str): + """Get table metadata.""" + if not db.has_table(name): + raise HTTPException(404, f"Table '{name}' not found") + + t = db[name] + return { + "name": t.name, + "num_rows": t.num_rows, + "num_columns": t.num_columns, + "columns": t.column_names(), + "sorted_by": t.sorted_by, + } + + +@app.get("/tables/{name}/data") +async def get_table_data( + name: str, + limit: int = Query(default=100, le=10000), + offset: int = Query(default=0, ge=0), +): + """Get table data as JSON.""" + if not db.has_table(name): + raise HTTPException(404, f"Table '{name}' not found") + + t = db[name] + end = min(offset + limit, t.num_rows) + + result = {} + for col_name in t.column_names(): + col = t[col_name] + arr = col.to_numpy()[offset:end] + result[col_name] = arr.tolist() + + return { + "table": name, + "offset": offset, + "limit": limit, + "total_rows": t.num_rows, + "data": result, + } + + +@app.delete("/tables/{name}") +async def delete_table(name: str): + """Delete a table.""" + if not db.has_table(name): + raise HTTPException(404, f"Table '{name}' not found") + + db.drop_table(name) + return {"deleted": name} + + +# --- Aggregations --- + +@app.get("/tables/{name}/agg/{column}/{operation}") +async 
def aggregate(name: str, column: str, operation: str): + """ + Run aggregation on a column. + Operations: sum, avg, min, max, std + """ + if not db.has_table(name): + raise HTTPException(404, f"Table '{name}' not found") + + t = db[name] + if not t.has_column(column): + raise HTTPException(404, f"Column '{column}' not found") + + col = t[column] + + ops_map = { + "sum": wdb.ops.sum, + "avg": wdb.ops.avg, + "min": wdb.ops.min, + "max": wdb.ops.max, + "std": wdb.ops.std, + } + + if operation not in ops_map: + raise HTTPException(400, f"Unknown operation: {operation}") + + # Run in thread pool for concurrency + result = await run_in_executor(ops_map[operation], col) + + return AggregationResult(column=column, operation=operation, result=result) + + +# --- Joins --- + +@app.post("/join/aj") +async def as_of_join(req: JoinRequest): + """ + As-of join: find most recent right row for each left row. + Both tables must be sorted by the as_of column. + """ + if not db.has_table(req.left_table): + raise HTTPException(404, f"Table '{req.left_table}' not found") + if not db.has_table(req.right_table): + raise HTTPException(404, f"Table '{req.right_table}' not found") + + left = db[req.left_table] + right = db[req.right_table] + + def do_join(): + return wdb.ops.aj(left, right, req.on, req.as_of) + + result = await run_in_executor(do_join) + + # Return as dict + data = {} + for col_name in result.column_names(): + data[col_name] = result[col_name].to_numpy().tolist() + + return { + "join_type": "as_of", + "rows": result.num_rows, + "columns": result.column_names(), + "data": data, + } + + +@app.post("/join/wj") +async def window_join(req: JoinRequest): + """ + Window join: find all right rows within time window. + """ + if not db.has_table(req.left_table): + raise HTTPException(404, f"Table '{req.left_table}' not found") + if not db.has_table(req.right_table): + raise HTTPException(404, f"Table '{req.right_table}' not found") + + if req.window_before is None or req.window_after is None: + raise HTTPException(400, "window_before and window_after required for window join") + + left = db[req.left_table] + right = db[req.right_table] + + def do_join(): + return wdb.ops.wj(left, right, req.on, req.as_of, + req.window_before, req.window_after) + + result = await run_in_executor(do_join) + + data = {} + for col_name in result.column_names(): + data[col_name] = result[col_name].to_numpy().tolist() + + return { + "join_type": "window", + "rows": result.num_rows, + "columns": result.column_names(), + "data": data, + } + + +# --- Window Functions --- + +@app.post("/window") +async def window_function(req: WindowRequest): + """ + Apply window function to a column. 
+ Operations: mavg, msum, mstd, mmin, mmax, ema, diff, pct_change + """ + if not db.has_table(req.table): + raise HTTPException(404, f"Table '{req.table}' not found") + + t = db[req.table] + if not t.has_column(req.column): + raise HTTPException(404, f"Column '{req.column}' not found") + + col = t[req.column] + + def do_window(): + if req.operation == "mavg": + return wdb.ops.mavg(col, req.window) + elif req.operation == "msum": + return wdb.ops.msum(col, req.window) + elif req.operation == "mstd": + return wdb.ops.mstd(col, req.window) + elif req.operation == "mmin": + return wdb.ops.mmin(col, req.window) + elif req.operation == "mmax": + return wdb.ops.mmax(col, req.window) + elif req.operation == "ema": + return wdb.ops.ema(col, req.alpha) + elif req.operation == "diff": + return wdb.ops.diff(col, req.window or 1) + elif req.operation == "pct_change": + return wdb.ops.pct_change(col, req.window or 1) + else: + raise ValueError(f"Unknown operation: {req.operation}") + + result = await run_in_executor(do_window) + + return { + "table": req.table, + "column": req.column, + "operation": req.operation, + "result": result.tolist(), + } + + +# --- Append API --- + +@app.post("/tables/{name}/append") +async def append_to_table(name: str, data: AppendData): + """ + Append rows to an existing table. + + This is more efficient than re-uploading the entire table. + The new data must have the same columns as the existing table. + """ + if not db.has_table(name): + raise HTTPException(404, f"Table '{name}' not found") + + existing = db[name] + existing_cols = set(existing.column_names()) + + # Validate columns match + new_cols = {col.name for col in data.columns} + if existing_cols != new_cols: + raise HTTPException( + 400, + f"Column mismatch. Expected: {sorted(existing_cols)}, got: {sorted(new_cols)}" + ) + + # Get existing data + existing_data = {} + for col_name in existing.column_names(): + existing_data[col_name] = existing[col_name].to_numpy() + + # Prepare new data + new_data = {} + for col in data.columns: + dtype = dtype_from_string(col.dtype) + np_dtype = numpy_dtype_for(dtype) + new_data[col.name] = np.array(col.data, dtype=np_dtype) + + # Concatenate + combined = {} + for col_name in existing_cols: + combined[col_name] = np.concatenate([existing_data[col_name], new_data[col_name]]) + + # Get sorted_by before dropping + sorted_by = existing.sorted_by + + # Drop and recreate + db.drop_table(name) + new_table = wdb.from_dict(combined, name=name, sorted_by=sorted_by) + db.add_table(new_table) + db.save() + + return { + "appended": name, + "new_rows": len(data.columns[0].data) if data.columns else 0, + "total_rows": new_table.num_rows, + } + + +# --- OLTP / CRUD API --- + +@app.post("/api/v1/{db_name}/tables") +async def create_oltp_table(db_name: str, schema: TableCreateOLTP): + """Create a table with typed columns and optional primary key.""" + validate_table_name(schema.name) + + if db.has_table(schema.name): + raise HTTPException(400, f"Table '{schema.name}' already exists") + + t = db.create_table(schema.name) + + # Add columns based on schema + for col_def in schema.columns: + col_name = col_def["name"] + dtype_str = col_def["dtype"] + dtype = dtype_from_string(dtype_str) + np_dtype = numpy_dtype_for(dtype) + arr = np.array([], dtype=np_dtype) + t.add_column_from_numpy(col_name, arr, dtype) + + if schema.sorted_by: + t.set_sorted_by(schema.sorted_by) + if schema.primary_key: + t.set_primary_key(schema.primary_key) + + db.save() + return {"created": schema.name, "columns": [c["name"] for 
c in schema.columns]} + + +@app.post("/api/v1/{db_name}/tables/{table_name}/rows") +async def insert_row(db_name: str, table_name: str, row: RowData): + """Insert a single row into a table.""" + if not db.has_table(table_name): + raise HTTPException(404, f"Table '{table_name}' not found") + + t = db[table_name] + try: + row_idx = t.append_row(row.data) + except Exception as e: + raise HTTPException(400, str(e)) + + return {"table": table_name, "row_index": row_idx} + + +@app.put("/api/v1/{db_name}/tables/{table_name}/rows/{pk}") +async def update_row(db_name: str, table_name: str, pk: str, row: RowData): + """Update a row by primary key.""" + if not db.has_table(table_name): + raise HTTPException(404, f"Table '{table_name}' not found") + + t = db[table_name] + if not t.primary_key: + raise HTTPException(400, "Table has no primary key set") + + pk_dtype = t.column_dtype(t.primary_key) + + try: + if pk_dtype == _DTYPE_STRING: + ok = t.update_row(pk, row.data) + else: + ok = t.update_row(int(pk), row.data) + except Exception as e: + raise HTTPException(400, str(e)) + + if not ok: + raise HTTPException(404, f"Row with pk={pk} not found") + + return {"table": table_name, "pk": pk, "updated": True} + + +@app.delete("/api/v1/{db_name}/tables/{table_name}/rows/{pk}") +async def delete_row(db_name: str, table_name: str, pk: str): + """Soft-delete a row by primary key.""" + if not db.has_table(table_name): + raise HTTPException(404, f"Table '{table_name}' not found") + + t = db[table_name] + if not t.primary_key: + raise HTTPException(400, "Table has no primary key set") + + pk_dtype = t.column_dtype(t.primary_key) + + if pk_dtype == _DTYPE_STRING: + ok = t.delete_row(pk) + else: + ok = t.delete_row(int(pk)) + + if not ok: + raise HTTPException(404, f"Row with pk={pk} not found") + + return {"table": table_name, "pk": pk, "deleted": True} + + +def _read_row_at(t, row_idx: int) -> dict[str, Any]: + """Read a single row from a table by index, returning a dict.""" + row = {} + for col_name in t.column_names(): + if t.has_string_column(col_name): + scol = t.string_column(col_name) + row[col_name] = scol.get(row_idx) + else: + col = t.column(col_name) + arr = col.to_numpy() + val = arr[row_idx] + # Convert numpy types to Python native for JSON serialization + row[col_name] = val.item() if hasattr(val, "item") else val + return row + + +@app.get("/api/v1/{db_name}/tables/{table_name}/rows/{pk}") +async def get_row_by_pk(db_name: str, table_name: str, pk: str): + """Get a single row by primary key.""" + if not db.has_table(table_name): + raise HTTPException(404, f"Table '{table_name}' not found") + + t = db[table_name] + if not t.primary_key: + raise HTTPException(400, "Table has no primary key set") + + pk_dtype = t.column_dtype(t.primary_key) + + if pk_dtype == _DTYPE_STRING: + row_idx = t.find_row(pk) + else: + row_idx = t.find_row(int(pk)) + + if row_idx is None: + raise HTTPException(404, f"Row with pk={pk} not found") + + return {"data": _read_row_at(t, row_idx)} + + +@app.get("/api/v1/{db_name}/tables/{table_name}/rows") +async def filter_rows(db_name: str, table_name: str, request: Request): + """Filter rows by query parameters (col=val). 
Returns matching row data.""" + if not db.has_table(table_name): + raise HTTPException(404, f"Table '{table_name}' not found") + + t = db[table_name] + params = dict(request.query_params) + limit = int(params.pop("limit", "500")) + + # Intersect filter results across all query params + row_indices = None + for col, val in params.items(): + if not t.has_column(col) and not t.has_string_column(col): + continue + try: + col_dtype = t.column_dtype(col) + if col_dtype == _DTYPE_STRING: + matches = set(t.where_eq(col, val)) + else: + matches = set(t.where_eq(col, int(val))) + except Exception: + continue + row_indices = matches if row_indices is None else row_indices & matches + + # If no filters, return all valid rows + if row_indices is None: + row_indices = set(range(t.num_rows)) + + # Sort and limit + sorted_indices = sorted(row_indices)[:limit] + + rows = [_read_row_at(t, idx) for idx in sorted_indices] + return {"data": rows, "count": len(rows)} + + +@app.post("/api/v1/{db_name}/checkpoint") +async def checkpoint(db_name: str): + """Flush WAL, save all tables to disk, truncate WAL.""" + db.checkpoint() + return {"checkpoint": "ok"} + + +# --- Streaming Ingestion API --- + +@app.post("/ingest/{table}") +async def ingest_tick(table: str, tick: IngestTick): + """ + Ingest a single tick via REST. + + For high-throughput, use the WebSocket endpoint instead. + """ + validate_table_name(table) + streaming = get_streaming_manager() + await streaming.ingest_tick( + table=table, + symbol=tick.symbol, + price=tick.price, + timestamp=tick.timestamp, + volume=tick.volume or 0.0, + bid=tick.bid or tick.price, + ask=tick.ask or tick.price, + ) + return {"ingested": 1, "table": table} + + +@app.post("/ingest/{table}/batch") +async def ingest_batch(table: str, batch: IngestBatch): + """ + Ingest a batch of ticks via REST. + + For high-throughput, use the WebSocket endpoint instead. + """ + validate_table_name(table) + streaming = get_streaming_manager() + ticks = [ + { + "symbol": t.symbol, + "price": t.price, + "timestamp": t.timestamp, + "volume": t.volume or 0.0, + "bid": t.bid or t.price, + "ask": t.ask or t.price, + } + for t in batch.ticks + ] + await streaming.ingest_batch(table=table, ticks=ticks) + return {"ingested": len(ticks), "table": table} + + +# --- WebSocket Endpoints --- + +@app.websocket("/ws/ingest/{table}") +async def ws_ingest(websocket: WebSocket, table: str): + """ + WebSocket endpoint for streaming tick ingestion. 
+ + Send JSON messages with tick data: + { + "symbol": "BTC-USD", + "price": 42150.50, + "timestamp": 1704067200000000000, // Optional, nanoseconds + "volume": 1.5, // Optional + "bid": 42150.00, // Optional + "ask": 42151.00 // Optional + } + + Or batches: + { + "batch": [ + {"symbol": "BTC-USD", "price": 42150.50, ...}, + {"symbol": "ETH-USD", "price": 2250.25, ...} + ] + } + """ + await websocket.accept() + streaming = get_streaming_manager() + + logger.info(f"Ingestion WebSocket connected for table: {table}") + + try: + while True: + data = await websocket.receive_json() + + if "batch" in data: + # Batch ingestion + ticks = data["batch"] + await streaming.ingest_batch(table=table, ticks=ticks) + await websocket.send_json({"ack": len(ticks)}) + else: + # Single tick + await streaming.ingest_tick( + table=table, + symbol=data["symbol"], + price=data["price"], + timestamp=data.get("timestamp"), + volume=data.get("volume", 0.0), + bid=data.get("bid", data["price"]), + ask=data.get("ask", data["price"]), + ) + await websocket.send_json({"ack": 1}) + + except WebSocketDisconnect: + logger.info(f"Ingestion WebSocket disconnected for table: {table}") + except Exception as e: + logger.error(f"Ingestion WebSocket error: {e}") + await websocket.close(code=1011, reason=str(e)) + + +@app.websocket("/ws/subscribe/{table}") +async def ws_subscribe(websocket: WebSocket, table: str): + """ + WebSocket endpoint for subscribing to real-time updates. + + Optionally send a filter message after connecting: + {"symbols": ["BTC-USD", "ETH-USD"]} + + Receives updates as: + { + "symbol": "BTC-USD", + "price": 42150.50, + "bid": 42150.00, + "ask": 42151.00, + "volume": 1.5, + "timestamp": 1704067200000000000, + "table": "ticks" + } + + Or batches during high-throughput: + {"batch": [...]} + """ + await websocket.accept() + streaming = get_streaming_manager() + + # Default: subscribe to all symbols + symbols = None + + # Check for initial filter message (non-blocking) + try: + # Wait briefly for filter message + data = await asyncio.wait_for(websocket.receive_json(), timeout=0.5) + if "symbols" in data: + symbols = data["symbols"] + logger.info(f"Subscription filter: {symbols}") + except asyncio.TimeoutError: + pass + except Exception: + pass + + subscriber = await streaming.subscribe(websocket, table, symbols) + logger.info(f"Subscription WebSocket connected for table: {table}, symbols: {symbols or 'all'}") + + try: + # Keep connection alive, handle any incoming messages + while True: + try: + data = await websocket.receive_json() + # Handle filter updates + if "symbols" in data: + subscriber.symbols = set(data["symbols"]) if data["symbols"] else set() + await websocket.send_json({"filter_updated": list(subscriber.symbols) or "all"}) + except WebSocketDisconnect: + raise + except Exception: + pass + + except WebSocketDisconnect: + logger.info(f"Subscription WebSocket disconnected for table: {table}") + finally: + await streaming.unsubscribe(websocket, table) + + +# --- Streaming Stats --- + +@app.get("/streaming/stats") +async def streaming_stats(): + """Get streaming ingestion and pub/sub statistics.""" + streaming = get_streaming_manager() + return streaming.get_stats() + + +@app.get("/streaming/quote/{table}/{symbol}") +async def get_quote(table: str, symbol: str): + """Get the latest quote for a symbol from the streaming cache.""" + streaming = get_streaming_manager() + quote = streaming.get_latest_quote(table, symbol) + if not quote: + raise HTTPException(404, f"No quote for {symbol} in {table}") + return 
quote + + +@app.get("/streaming/quotes/{table}") +async def get_all_quotes(table: str): + """Get all latest quotes for a table from the streaming cache.""" + streaming = get_streaming_manager() + return streaming.get_all_quotes(table) + + +@app.get("/streaming/pubsub") +async def pubsub_stats(): + """Get pub/sub backend statistics (channels, sequences, backend type).""" + streaming = get_streaming_manager() + stats = streaming.get_stats() + return stats.get("pubsub", {"backend": "none", "info": "PubSub not configured"}) + + +# --- KV Store API --- + +class KVSetRequest(BaseModel): + """Request body for setting a KV entry.""" + value: Any + ttl: Optional[float] = None # TTL in seconds, None = no expiry + + +@app.post("/kv/{key}") +async def kv_set(key: str, req: KVSetRequest): + """Set a key-value pair with optional TTL.""" + kv = get_kv_store() + kv.set(key, req.value, ttl=req.ttl) + return {"key": key, "ttl": req.ttl} + + +@app.get("/kv/{key}") +async def kv_get(key: str): + """Get a value by key.""" + kv = get_kv_store() + value = kv.get(key) + if value is None: + raise HTTPException(404, f"Key '{key}' not found or expired") + return {"key": key, "value": value} + + +@app.delete("/kv/{key}") +async def kv_delete(key: str): + """Delete a key.""" + kv = get_kv_store() + existed = kv.delete(key) + if not existed: + raise HTTPException(404, f"Key '{key}' not found") + return {"deleted": key} + + +@app.get("/kv") +async def kv_list(pattern: Optional[str] = None): + """List keys, optionally filtered by glob pattern.""" + kv = get_kv_store() + keys = kv.keys(pattern) + return {"keys": keys, "count": len(keys)} + + +@app.get("/kv-stats") +async def kv_stats(): + """Get KV store statistics.""" + kv = get_kv_store() + return kv.stats() + + +# --- General Pub/Sub API --- + +class PubSubPublishRequest(BaseModel): + """Request body for publishing to a channel.""" + data: Any + + +@app.post("/pubsub/publish/{channel}") +async def pubsub_publish(channel: str, req: PubSubPublishRequest): + """Publish a message to a channel.""" + streaming = get_streaming_manager() + # Use the streaming manager's broadcast mechanism + # For general pub/sub, we broadcast to WebSocket subscribers + await streaming.broadcast_to_channel(channel, req.data) + return {"channel": channel, "published": True} + + +@app.websocket("/ws/pubsub") +async def ws_pubsub(websocket: WebSocket): + """ + General pub/sub WebSocket endpoint. 
+ + Send subscription request after connecting: + {"action": "subscribe", "channels": ["prices:*", "trades"]} + + Receives messages as: + {"channel": "prices:BTC-USD", "data": {...}} + """ + await websocket.accept() + streaming = get_streaming_manager() + + subscribed_channels: list[str] = [] + + logger.info("PubSub WebSocket connected") + + try: + while True: + data = await websocket.receive_json() + + action = data.get("action") + if action == "subscribe": + channels = data.get("channels", []) + subscribed_channels.extend(channels) + await websocket.send_json({ + "type": "subscribed", + "channels": subscribed_channels, + }) + elif action == "ping": + await websocket.send_json({"type": "pong"}) + + except WebSocketDisconnect: + logger.info("PubSub WebSocket disconnected") + except Exception as e: + logger.error(f"PubSub WebSocket error: {e}") diff --git a/api/pubsub.py b/api/pubsub.py new file mode 100644 index 0000000000000000000000000000000000000000..eeae1239e0daab139449fcd4ed8baf0dec0648bb --- /dev/null +++ b/api/pubsub.py @@ -0,0 +1,547 @@ +""" +WayyDB PubSub Abstraction Layer + +Provides a pluggable pub/sub transport for real-time tick distribution. +Two backends: + - InMemoryPubSub: Default, zero dependencies, single-process + - RedisPubSub: Optional, requires redis-py, multi-process capable + +Configure via REDIS_URL environment variable: + - Not set or empty: uses InMemoryPubSub + - Set to redis://...: uses RedisPubSub + +Channel naming convention: + ticks:{symbol} - Trade ticks for a symbol + quotes:{symbol} - Quote updates for a symbol + ticks:* - All trade ticks + {table}:{symbol} - Generic table:symbol pattern +""" + +import asyncio +import logging +import time +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine, Dict, List, Optional, Set + +logger = logging.getLogger(__name__) + +# Type alias for async callback +AsyncCallback = Callable[[dict], Coroutine[Any, Any, None]] + + +@dataclass +class Message: + """A pub/sub message with metadata.""" + channel: str + data: dict + sequence: int + timestamp: float = field(default_factory=time.time) + + +class PubSubBackend(ABC): + """Abstract pub/sub backend interface. + + Implementations must provide publish, subscribe, and unsubscribe. + This abstraction allows swapping between in-memory, Redis, NATS, etc. + """ + + @abstractmethod + async def publish(self, channel: str, data: dict) -> int: + """Publish a message to a channel. + + Args: + channel: Channel name (e.g., "ticks:AAPL") + data: Message payload + + Returns: + Sequence number of the published message + """ + ... + + @abstractmethod + async def subscribe( + self, + channel: str, + callback: AsyncCallback, + subscriber_id: str = "", + ) -> None: + """Subscribe to a channel with a callback. + + Args: + channel: Channel name or pattern (e.g., "ticks:AAPL" or "ticks:*") + callback: Async function called with each message dict + subscriber_id: Unique identifier for this subscriber + """ + ... + + @abstractmethod + async def unsubscribe(self, channel: str, subscriber_id: str = "") -> None: + """Unsubscribe from a channel. + + Args: + channel: Channel name or pattern + subscriber_id: The subscriber to remove + """ + ... + + @abstractmethod + async def publish_batch(self, channel: str, messages: List[dict]) -> int: + """Publish a batch of messages to a channel. 
+ + Args: + channel: Channel name + messages: List of message payloads + + Returns: + Sequence number of the last message + """ + ... + + @abstractmethod + def get_stats(self) -> dict: + """Get pub/sub statistics.""" + ... + + @abstractmethod + async def start(self) -> None: + """Start the backend (connect, initialize).""" + ... + + @abstractmethod + async def stop(self) -> None: + """Stop the backend (disconnect, cleanup).""" + ... + + +class InMemoryPubSub(PubSubBackend): + """In-process pub/sub using asyncio. + + Features: + - Channel-based routing with wildcard support + - Per-channel sequence numbers + - Ring buffer for backpressure (drops oldest on overflow) + - Concurrent broadcast via asyncio.gather + - Message replay from buffer + """ + + def __init__( + self, + max_buffer_per_channel: int = 10000, + broadcast_timeout: float = 5.0, + ): + self._subscribers: Dict[str, Dict[str, AsyncCallback]] = defaultdict(dict) + self._sequence: Dict[str, int] = defaultdict(int) + self._buffers: Dict[str, deque] = {} + self._max_buffer = max_buffer_per_channel + self._broadcast_timeout = broadcast_timeout + self._stats = { + "messages_published": 0, + "messages_delivered": 0, + "messages_dropped": 0, + "active_subscriptions": 0, + "channels": 0, + } + self._running = False + + async def start(self) -> None: + self._running = True + logger.info("InMemoryPubSub started") + + async def stop(self) -> None: + self._running = False + self._subscribers.clear() + self._buffers.clear() + logger.info("InMemoryPubSub stopped") + + async def publish(self, channel: str, data: dict) -> int: + self._sequence[channel] += 1 + seq = self._sequence[channel] + + msg = Message(channel=channel, data=data, sequence=seq) + + # Buffer the message + if channel not in self._buffers: + self._buffers[channel] = deque(maxlen=self._max_buffer) + buf = self._buffers[channel] + if len(buf) >= self._max_buffer: + self._stats["messages_dropped"] += 1 + buf.append(msg) + + self._stats["messages_published"] += 1 + self._stats["channels"] = len(self._buffers) + + # Deliver to subscribers + await self._deliver(channel, data, seq) + + return seq + + async def publish_batch(self, channel: str, messages: List[dict]) -> int: + last_seq = 0 + for data in messages: + last_seq = await self.publish(channel, data) + return last_seq + + async def subscribe( + self, + channel: str, + callback: AsyncCallback, + subscriber_id: str = "", + ) -> None: + if not subscriber_id: + subscriber_id = f"sub_{id(callback)}" + + self._subscribers[channel][subscriber_id] = callback + self._stats["active_subscriptions"] = sum( + len(subs) for subs in self._subscribers.values() + ) + logger.debug(f"Subscribed {subscriber_id} to {channel}") + + async def unsubscribe(self, channel: str, subscriber_id: str = "") -> None: + if channel in self._subscribers: + if subscriber_id and subscriber_id in self._subscribers[channel]: + del self._subscribers[channel][subscriber_id] + elif not subscriber_id: + self._subscribers[channel].clear() + + if not self._subscribers[channel]: + del self._subscribers[channel] + + self._stats["active_subscriptions"] = sum( + len(subs) for subs in self._subscribers.values() + ) + + async def _deliver(self, channel: str, data: dict, sequence: int) -> None: + """Deliver message to all matching subscribers concurrently.""" + enriched = {**data, "_seq": sequence, "_channel": channel} + + # Collect all matching callbacks + callbacks: List[AsyncCallback] = [] + + # Exact match subscribers + if channel in self._subscribers: + 
callbacks.extend(self._subscribers[channel].values()) + + # Wildcard subscribers (e.g., "ticks:*" matches "ticks:AAPL") + for pattern, subs in self._subscribers.items(): + if pattern.endswith(":*"): + prefix = pattern[:-1] # "ticks:" + if channel.startswith(prefix) and pattern != channel: + callbacks.extend(subs.values()) + + if not callbacks: + return + + # Concurrent delivery with timeout + dead_callbacks: List[AsyncCallback] = [] + + async def safe_deliver(cb: AsyncCallback) -> None: + try: + await asyncio.wait_for(cb(enriched), timeout=self._broadcast_timeout) + self._stats["messages_delivered"] += 1 + except asyncio.TimeoutError: + logger.warning(f"Subscriber timed out on {channel}") + dead_callbacks.append(cb) + except Exception: + dead_callbacks.append(cb) + + await asyncio.gather(*(safe_deliver(cb) for cb in callbacks)) + + # Remove dead subscribers + for dead_cb in dead_callbacks: + for pattern, subs in list(self._subscribers.items()): + to_remove = [ + sid for sid, cb in subs.items() if cb is dead_cb + ] + for sid in to_remove: + del subs[sid] + logger.debug(f"Removed dead subscriber {sid} from {pattern}") + + if dead_callbacks: + self._stats["active_subscriptions"] = sum( + len(subs) for subs in self._subscribers.values() + ) + + def get_channel_buffer(self, channel: str, since_seq: int = 0) -> List[Message]: + """Get buffered messages for replay. + + Args: + channel: Channel name + since_seq: Only return messages with sequence > since_seq + + Returns: + List of messages for replay + """ + if channel not in self._buffers: + return [] + return [m for m in self._buffers[channel] if m.sequence > since_seq] + + def get_stats(self) -> dict: + return { + "backend": "in_memory", + **self._stats, + "buffer_sizes": {ch: len(buf) for ch, buf in self._buffers.items()}, + } + + +class RedisPubSub(PubSubBackend): + """Redis-backed pub/sub for multi-process deployments. + + Uses Redis pub/sub for real-time delivery and Redis Streams + for message persistence and replay. + + Requires: pip install redis[hiredis] + Configure via REDIS_URL environment variable. + """ + + def __init__(self, redis_url: str, max_stream_len: int = 100000): + self._redis_url = redis_url + self._max_stream_len = max_stream_len + self._redis = None + self._pubsub = None + self._subscribers: Dict[str, Dict[str, AsyncCallback]] = defaultdict(dict) + self._sequence: Dict[str, int] = defaultdict(int) + self._listener_task: Optional[asyncio.Task] = None + self._running = False + self._stats = { + "messages_published": 0, + "messages_delivered": 0, + "messages_dropped": 0, + "active_subscriptions": 0, + "channels": 0, + "redis_connected": False, + } + + async def start(self) -> None: + try: + import redis.asyncio as aioredis + except ImportError: + raise ImportError( + "redis package required for RedisPubSub. 
" + "Install with: pip install redis[hiredis]" + ) + + self._redis = aioredis.from_url( + self._redis_url, + decode_responses=True, + socket_connect_timeout=5, + retry_on_timeout=True, + ) + + # Test connection + await self._redis.ping() + self._stats["redis_connected"] = True + + self._pubsub = self._redis.pubsub() + self._running = True + self._listener_task = asyncio.create_task(self._listen_loop()) + + logger.info(f"RedisPubSub connected to {self._redis_url}") + + async def stop(self) -> None: + self._running = False + + if self._listener_task: + self._listener_task.cancel() + try: + await self._listener_task + except asyncio.CancelledError: + pass + + if self._pubsub: + await self._pubsub.unsubscribe() + await self._pubsub.close() + + if self._redis: + await self._redis.close() + + self._stats["redis_connected"] = False + logger.info("RedisPubSub stopped") + + async def publish(self, channel: str, data: dict) -> int: + import json + + self._sequence[channel] += 1 + seq = self._sequence[channel] + + enriched = {**data, "_seq": seq, "_ts": time.time()} + payload = json.dumps(enriched) + + # Publish to Redis pub/sub channel + await self._redis.publish(f"wayy:{channel}", payload) + + # Also write to Redis Stream for persistence/replay + stream_key = f"wayy:stream:{channel}" + await self._redis.xadd( + stream_key, + {"data": payload}, + maxlen=self._max_stream_len, + ) + + self._stats["messages_published"] += 1 + return seq + + async def publish_batch(self, channel: str, messages: List[dict]) -> int: + import json + + pipe = self._redis.pipeline() + last_seq = 0 + + for data in messages: + self._sequence[channel] += 1 + seq = self._sequence[channel] + last_seq = seq + + enriched = {**data, "_seq": seq, "_ts": time.time()} + payload = json.dumps(enriched) + + pipe.publish(f"wayy:{channel}", payload) + + stream_key = f"wayy:stream:{channel}" + pipe.xadd(stream_key, {"data": payload}, maxlen=self._max_stream_len) + + await pipe.execute() + self._stats["messages_published"] += len(messages) + return last_seq + + async def subscribe( + self, + channel: str, + callback: AsyncCallback, + subscriber_id: str = "", + ) -> None: + if not subscriber_id: + subscriber_id = f"sub_{id(callback)}" + + is_new_channel = channel not in self._subscribers or not self._subscribers[channel] + self._subscribers[channel][subscriber_id] = callback + + if is_new_channel and self._pubsub: + if channel.endswith(":*"): + await self._pubsub.psubscribe(f"wayy:{channel}") + else: + await self._pubsub.subscribe(f"wayy:{channel}") + + self._stats["active_subscriptions"] = sum( + len(subs) for subs in self._subscribers.values() + ) + self._stats["channels"] = len(self._subscribers) + + async def unsubscribe(self, channel: str, subscriber_id: str = "") -> None: + if channel in self._subscribers: + if subscriber_id and subscriber_id in self._subscribers[channel]: + del self._subscribers[channel][subscriber_id] + elif not subscriber_id: + self._subscribers[channel].clear() + + if not self._subscribers[channel]: + del self._subscribers[channel] + if self._pubsub: + if channel.endswith(":*"): + await self._pubsub.punsubscribe(f"wayy:{channel}") + else: + await self._pubsub.unsubscribe(f"wayy:{channel}") + + self._stats["active_subscriptions"] = sum( + len(subs) for subs in self._subscribers.values() + ) + + async def _listen_loop(self) -> None: + """Background task that listens for Redis pub/sub messages.""" + import json + + while self._running: + try: + message = await self._pubsub.get_message( + ignore_subscribe_messages=True, 
timeout=0.1 + ) + if message is None: + await asyncio.sleep(0.01) + continue + + if message["type"] not in ("message", "pmessage"): + continue + + raw_channel = message.get("channel", "") + # Strip "wayy:" prefix + if raw_channel.startswith("wayy:"): + channel = raw_channel[5:] + else: + channel = raw_channel + + data = json.loads(message["data"]) + + # Deliver to local subscribers + await self._deliver_local(channel, data) + + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"Redis listener error: {e}") + await asyncio.sleep(1.0) + + async def _deliver_local(self, channel: str, data: dict) -> None: + """Deliver a received message to local subscribers.""" + callbacks: List[AsyncCallback] = [] + + if channel in self._subscribers: + callbacks.extend(self._subscribers[channel].values()) + + # Wildcard matching + for pattern, subs in self._subscribers.items(): + if pattern.endswith(":*"): + prefix = pattern[:-1] + if channel.startswith(prefix) and pattern != channel: + callbacks.extend(subs.values()) + + for cb in callbacks: + try: + await asyncio.wait_for(cb(data), timeout=5.0) + self._stats["messages_delivered"] += 1 + except Exception: + self._stats["messages_dropped"] += 1 + + async def replay( + self, channel: str, since_id: str = "0-0", count: int = 1000 + ) -> List[dict]: + """Replay messages from Redis Stream. + + Args: + channel: Channel name + since_id: Redis Stream ID to start from + count: Maximum messages to return + + Returns: + List of message dicts + """ + import json + + stream_key = f"wayy:stream:{channel}" + messages = await self._redis.xrange(stream_key, min=since_id, count=count) + + return [json.loads(entry["data"]) for _id, entry in messages] + + def get_stats(self) -> dict: + return { + "backend": "redis", + "redis_url": self._redis_url.split("@")[-1] if "@" in self._redis_url else self._redis_url, + **self._stats, + } + + +def create_pubsub(redis_url: Optional[str] = None) -> PubSubBackend: + """Factory function to create the appropriate PubSub backend. + + Args: + redis_url: Redis URL. If None/empty, uses InMemoryPubSub. 
+ + Returns: + PubSubBackend instance + """ + if redis_url: + logger.info(f"Using RedisPubSub backend") + return RedisPubSub(redis_url=redis_url) + else: + logger.info("Using InMemoryPubSub backend (set REDIS_URL for Redis)") + return InMemoryPubSub() diff --git a/api/requirements.txt b/api/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b7ac87370a76ddab4ef2b42536e4d5f7199f884e --- /dev/null +++ b/api/requirements.txt @@ -0,0 +1,6 @@ +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 +numpy>=1.20 +pydantic>=2.0 +websockets>=12.0 +redis[hiredis]>=5.0 diff --git a/api/streaming.py b/api/streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddcddddd6e162e118cc58bb4a3c8e0b79445420 --- /dev/null +++ b/api/streaming.py @@ -0,0 +1,553 @@ +""" +WayyDB Streaming Module - Real-time data ingestion and pub/sub + +Provides: +- WebSocket ingestion endpoint for real-time tick data +- Pub/Sub subscriptions via pluggable backend (in-memory or Redis) +- Efficient batching and append operations +- In-memory buffers with periodic flush to persistent storage +- Backpressure handling and sequence numbers + +Configuration via environment variables: +- FLUSH_INTERVAL: Seconds between flushes to disk (default: 1.0) +- MAX_BUFFER_SIZE: Max ticks in buffer before force flush (default: 10000) +- BROADCAST_INTERVAL: Seconds between subscriber broadcasts (default: 0.05) +- REDIS_URL: Optional Redis URL for distributed pub/sub +""" + +import asyncio +import logging +import os +import threading +import time +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Set + +import numpy as np +from fastapi import WebSocket + +from api.pubsub import PubSubBackend, create_pubsub + +logger = logging.getLogger(__name__) + +# Configuration from environment +DEFAULT_FLUSH_INTERVAL = float(os.getenv("FLUSH_INTERVAL", "1.0")) +DEFAULT_MAX_BUFFER_SIZE = int(os.getenv("MAX_BUFFER_SIZE", "10000")) +DEFAULT_BROADCAST_INTERVAL = float(os.getenv("BROADCAST_INTERVAL", "0.05")) + + +@dataclass +class TickBuffer: + """Buffer for incoming tick data before flush to table.""" + timestamps: List[int] = field(default_factory=list) + symbols: List[str] = field(default_factory=list) + prices: List[float] = field(default_factory=list) + volumes: List[float] = field(default_factory=list) + bids: List[float] = field(default_factory=list) + asks: List[float] = field(default_factory=list) + + def append(self, timestamp: int, symbol: str, price: float, + volume: float = 0.0, bid: float = 0.0, ask: float = 0.0): + self.timestamps.append(timestamp) + self.symbols.append(symbol) + self.prices.append(price) + self.volumes.append(volume) + self.bids.append(bid if bid else price) + self.asks.append(ask if ask else price) + + def __len__(self): + return len(self.timestamps) + + def clear(self): + self.timestamps.clear() + self.symbols.clear() + self.prices.clear() + self.volumes.clear() + self.bids.clear() + self.asks.clear() + + def to_columnar(self) -> Dict[str, np.ndarray]: + """Convert to columnar format for WayyDB.""" + return { + "timestamp": np.array(self.timestamps, dtype=np.int64), + "symbol": np.array([hash(s) % (2**32) for s in self.symbols], dtype=np.uint32), + "price": np.array(self.prices, dtype=np.float64), + "volume": np.array(self.volumes, dtype=np.float64), + "bid": np.array(self.bids, dtype=np.float64), + "ask": np.array(self.asks, dtype=np.float64), + } + + +@dataclass 
+class Subscriber: + """A WebSocket subscriber to data updates.""" + websocket: WebSocket + symbols: Set[str] = field(default_factory=set) # Empty = all symbols + subscriber_id: str = "" + created_at: float = field(default_factory=time.time) + messages_sent: int = 0 + + +class StreamingManager: + """ + Manages streaming data ingestion and pub/sub distribution. + + Features: + - Buffer incoming ticks in memory + - Publish to PubSub channels (in-memory or Redis) + - Broadcast to WebSocket subscribers via PubSub callbacks + - Periodic flush to WayyDB tables (atomic swap, no gap) + - Thread-safe operations via threading.Lock + """ + + def __init__( + self, + flush_interval: float = DEFAULT_FLUSH_INTERVAL, + max_buffer_size: int = DEFAULT_MAX_BUFFER_SIZE, + batch_broadcast_interval: float = DEFAULT_BROADCAST_INTERVAL, + pubsub: Optional[PubSubBackend] = None, + ): + self.flush_interval = flush_interval + self.max_buffer_size = max_buffer_size + self.batch_broadcast_interval = batch_broadcast_interval + + # PubSub backend (in-memory default, Redis optional) + self._pubsub = pubsub + + # Tick buffers - one per table + self._buffers: Dict[str, TickBuffer] = defaultdict(TickBuffer) + + # WebSocket subscribers - one list per table + self._subscribers: Dict[str, List[Subscriber]] = defaultdict(list) + + # Latest quotes cache (for new subscribers) + self._latest_quotes: Dict[str, Dict[str, Any]] = {} + + # Pending broadcasts (batched for efficiency) + self._pending_broadcasts: Dict[str, List[Dict]] = defaultdict(list) + + # Statistics + self._stats = { + "ticks_received": 0, + "ticks_flushed": 0, + "broadcasts_sent": 0, + "active_subscribers": 0, + "flush_count": 0, + "start_time": None, + } + + # Background tasks + self._running = False + self._flush_task: Optional[asyncio.Task] = None + self._broadcast_task: Optional[asyncio.Task] = None + + # Database reference (set by API) + self._db = None + + # FIX: Use threading.Lock for thread safety with ThreadPoolExecutor + self._lock = threading.Lock() + + def set_database(self, db): + """Set the database reference for flushing.""" + self._db = db + + def set_pubsub(self, pubsub: PubSubBackend): + """Set the pub/sub backend.""" + self._pubsub = pubsub + + async def start(self): + """Start background flush and broadcast tasks.""" + if self._running: + return + + self._running = True + self._stats["start_time"] = datetime.now(timezone.utc).isoformat() + + # Start PubSub backend if provided + if self._pubsub: + await self._pubsub.start() + + self._flush_task = asyncio.create_task(self._flush_loop()) + self._broadcast_task = asyncio.create_task(self._broadcast_loop()) + + logger.info("StreamingManager started") + + async def stop(self): + """Stop background tasks and flush remaining data.""" + if not self._running: + return + + self._running = False + + if self._flush_task: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + + if self._broadcast_task: + self._broadcast_task.cancel() + try: + await self._broadcast_task + except asyncio.CancelledError: + pass + + # Final flush + await self._flush_all() + + # Stop PubSub backend + if self._pubsub: + await self._pubsub.stop() + + logger.info("StreamingManager stopped") + + async def ingest_tick( + self, + table: str, + symbol: str, + price: float, + timestamp: Optional[int] = None, + volume: float = 0.0, + bid: float = 0.0, + ask: float = 0.0, + ): + """Ingest a single tick.""" + if timestamp is None: + timestamp = int(datetime.now(timezone.utc).timestamp() * 1e9) + 
+ # Add to buffer (thread-safe) + with self._lock: + self._buffers[table].append( + timestamp=timestamp, + symbol=symbol, + price=price, + volume=volume, + bid=bid, + ask=ask, + ) + self._stats["ticks_received"] += 1 + + # Build quote message + quote = { + "symbol": symbol, + "price": price, + "bid": bid or price, + "ask": ask or price, + "volume": volume, + "timestamp": timestamp, + "table": table, + } + self._latest_quotes[f"{table}:{symbol}"] = quote + + # Publish to PubSub channel + if self._pubsub: + channel = f"{table}:{symbol}" + await self._pubsub.publish(channel, quote) + + # Queue for WebSocket broadcast + self._pending_broadcasts[table].append(quote) + + # Force flush if buffer too large + if len(self._buffers[table]) >= self.max_buffer_size: + await self._flush_table(table) + + async def ingest_batch( + self, + table: str, + ticks: List[Dict[str, Any]], + ): + """Ingest a batch of ticks efficiently.""" + quotes_by_channel: Dict[str, List[dict]] = defaultdict(list) + + with self._lock: + buffer = self._buffers[table] + for tick in ticks: + timestamp = tick.get("timestamp") + if timestamp is None: + timestamp = int(datetime.now(timezone.utc).timestamp() * 1e9) + + buffer.append( + timestamp=timestamp, + symbol=tick["symbol"], + price=tick["price"], + volume=tick.get("volume", 0.0), + bid=tick.get("bid", tick["price"]), + ask=tick.get("ask", tick["price"]), + ) + + quote = { + "symbol": tick["symbol"], + "price": tick["price"], + "bid": tick.get("bid", tick["price"]), + "ask": tick.get("ask", tick["price"]), + "volume": tick.get("volume", 0.0), + "timestamp": timestamp, + "table": table, + } + self._latest_quotes[f"{table}:{tick['symbol']}"] = quote + self._pending_broadcasts[table].append(quote) + + channel = f"{table}:{tick['symbol']}" + quotes_by_channel[channel].append(quote) + + self._stats["ticks_received"] += len(ticks) + + # Batch publish to PubSub channels + if self._pubsub: + for channel, channel_quotes in quotes_by_channel.items(): + await self._pubsub.publish_batch(channel, channel_quotes) + + # Force flush if buffer too large + if len(self._buffers[table]) >= self.max_buffer_size: + await self._flush_table(table) + + async def subscribe(self, websocket: WebSocket, table: str, symbols: Optional[List[str]] = None): + """Add a WebSocket subscriber to a table's updates.""" + sub_id = f"ws_{id(websocket)}" + subscriber = Subscriber( + websocket=websocket, + symbols=set(symbols) if symbols else set(), + subscriber_id=sub_id, + ) + + self._subscribers[table].append(subscriber) + self._stats["active_subscribers"] = sum(len(s) for s in self._subscribers.values()) + + # Send current latest quotes to new subscriber + for key, quote in self._latest_quotes.items(): + if key.startswith(f"{table}:"): + symbol = key.split(":", 1)[1] + if not subscriber.symbols or symbol in subscriber.symbols: + try: + await websocket.send_json(quote) + except Exception: + pass + + logger.info(f"New subscriber for {table}, symbols={symbols or 'all'}") + return subscriber + + async def unsubscribe(self, websocket: WebSocket, table: str): + """Remove a subscriber.""" + self._subscribers[table] = [ + s for s in self._subscribers[table] + if s.websocket != websocket + ] + self._stats["active_subscribers"] = sum(len(s) for s in self._subscribers.values()) + + async def _flush_loop(self): + """Background task to periodically flush buffers.""" + while self._running: + try: + await asyncio.sleep(self.flush_interval) + await self._flush_all() + except asyncio.CancelledError: + raise + except Exception as e: + 
logger.error(f"Flush error: {e}") + + async def _flush_all(self): + """Flush all buffers to database.""" + with self._lock: + tables = list(self._buffers.keys()) + + for table in tables: + await self._flush_table(table) + + async def _flush_table(self, table: str): + """Flush a single table's buffer to database. + + FIX: Atomic table swap - build new table first, then replace. + The old table remains readable until the swap completes. + """ + if self._db is None: + return + + with self._lock: + buffer = self._buffers[table] + if len(buffer) == 0: + return + + # Get columnar data and clear buffer + data = buffer.to_columnar() + count = len(buffer) + buffer.clear() + + try: + import wayy_db as wdb + + if self._db.has_table(table): + existing = self._db[table] + + # Read existing data + existing_data = {} + for col_name in existing.column_names(): + existing_data[col_name] = existing[col_name].to_numpy() + + # Concatenate + combined = {} + for col_name, new_arr in data.items(): + if col_name in existing_data: + combined[col_name] = np.concatenate([existing_data[col_name], new_arr]) + else: + combined[col_name] = new_arr + + # FIX: Build new table FIRST, then atomic swap + new_table = wdb.from_dict(combined, name=table, sorted_by="timestamp") + self._db.drop_table(table) + self._db.add_table(new_table) + else: + new_table = wdb.from_dict(data, name=table, sorted_by="timestamp") + self._db.add_table(new_table) + + self._db.save() + + self._stats["ticks_flushed"] += count + self._stats["flush_count"] += 1 + + logger.debug(f"Flushed {count} ticks to {table}") + + except Exception as e: + logger.error(f"Failed to flush {table}: {e}") + # Re-add data to buffer on failure + with self._lock: + buf = self._buffers[table] + for i in range(len(data["timestamp"])): + buf.timestamps.append(int(data["timestamp"][i])) + buf.symbols.append(f"unknown") # Symbol hash lost, but data preserved + buf.prices.append(float(data["price"][i])) + buf.volumes.append(float(data["volume"][i])) + buf.bids.append(float(data["bid"][i])) + buf.asks.append(float(data["ask"][i])) + + async def _broadcast_loop(self): + """Background task to batch-broadcast updates to WebSocket subscribers.""" + while self._running: + try: + await asyncio.sleep(self.batch_broadcast_interval) + await self._broadcast_pending() + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"Broadcast error: {e}") + + async def _broadcast_pending(self): + """Broadcast pending updates to all subscribers. + + FIX: Uses asyncio.gather for concurrent WebSocket sends. + One slow subscriber no longer blocks all others. 
+ """ + # Swap out pending broadcasts atomically + pending = dict(self._pending_broadcasts) + self._pending_broadcasts = defaultdict(list) + + for table, quotes in pending.items(): + if not quotes: + continue + + subscribers = self._subscribers.get(table, []) + if not subscribers: + continue + + # Build send tasks for all subscribers concurrently + send_tasks = [] + sub_task_map: List[Subscriber] = [] + + for sub in subscribers: + if sub.symbols: + filtered = [q for q in quotes if q["symbol"] in sub.symbols] + else: + filtered = quotes + + if not filtered: + continue + + if len(filtered) == 1: + payload = filtered[0] + else: + payload = {"batch": filtered} + + send_tasks.append(self._safe_send(sub.websocket, payload)) + sub_task_map.append(sub) + + if not send_tasks: + continue + + # FIX: Concurrent sends via asyncio.gather + results = await asyncio.gather(*send_tasks, return_exceptions=True) + + dead_subs = [] + for sub, result in zip(sub_task_map, results): + if isinstance(result, Exception): + dead_subs.append(sub) + else: + count = len(quotes) if not sub.symbols else len( + [q for q in quotes if q["symbol"] in sub.symbols] + ) + sub.messages_sent += count + self._stats["broadcasts_sent"] += count + + # Remove dead subscribers + for sub in dead_subs: + if sub in self._subscribers[table]: + self._subscribers[table].remove(sub) + + @staticmethod + async def _safe_send(websocket: WebSocket, payload: Any) -> None: + """Send JSON to a WebSocket with timeout.""" + await asyncio.wait_for(websocket.send_json(payload), timeout=5.0) + + def get_stats(self) -> Dict[str, Any]: + """Get streaming statistics.""" + stats = { + **self._stats, + "buffer_sizes": {t: len(b) for t, b in self._buffers.items()}, + "subscriber_counts": {t: len(s) for t, s in self._subscribers.items()}, + "latest_quotes": len(self._latest_quotes), + "running": self._running, + } + if self._pubsub: + stats["pubsub"] = self._pubsub.get_stats() + return stats + + def get_latest_quote(self, table: str, symbol: str) -> Optional[Dict[str, Any]]: + """Get the latest quote for a symbol.""" + return self._latest_quotes.get(f"{table}:{symbol}") + + def get_all_quotes(self, table: str) -> Dict[str, Dict[str, Any]]: + """Get all latest quotes for a table.""" + prefix = f"{table}:" + return { + k.split(":", 1)[1]: v + for k, v in self._latest_quotes.items() + if k.startswith(prefix) + } + + +# Global streaming manager instance +_streaming_manager: Optional[StreamingManager] = None + + +def get_streaming_manager() -> StreamingManager: + """Get or create the global streaming manager.""" + global _streaming_manager + if _streaming_manager is None: + redis_url = os.getenv("REDIS_URL", "") + pubsub = create_pubsub(redis_url if redis_url else None) + _streaming_manager = StreamingManager(pubsub=pubsub) + return _streaming_manager + + +async def start_streaming(): + """Start the global streaming manager.""" + manager = get_streaming_manager() + await manager.start() + + +async def stop_streaming(): + """Stop the global streaming manager.""" + global _streaming_manager + if _streaming_manager: + await _streaming_manager.stop() diff --git a/build/_deps/googletest-src b/build/_deps/googletest-src new file mode 160000 index 0000000000000000000000000000000000000000..f8d7d77c06936315286eb55f8de22cd23c188571 --- /dev/null +++ b/build/_deps/googletest-src @@ -0,0 +1 @@ +Subproject commit f8d7d77c06936315286eb55f8de22cd23c188571 diff --git a/build/_deps/pybind11-src b/build/_deps/pybind11-src new file mode 160000 index 
diff --git a/build/_deps/googletest-src b/build/_deps/googletest-src
new file mode 160000
index 0000000000000000000000000000000000000000..f8d7d77c06936315286eb55f8de22cd23c188571
--- /dev/null
+++ b/build/_deps/googletest-src
@@ -0,0 +1 @@
+Subproject commit f8d7d77c06936315286eb55f8de22cd23c188571
diff --git a/build/_deps/pybind11-src b/build/_deps/pybind11-src
new file mode 160000
index 0000000000000000000000000000000000000000..a2e59f0e7065404b44dfe92a28aca47ba1378dc4
--- /dev/null
+++ b/build/_deps/pybind11-src
@@ -0,0 +1 @@
+Subproject commit a2e59f0e7065404b44dfe92a28aca47ba1378dc4
diff --git a/dist/wayy_db-0.1.0-cp310-cp310-linux_x86_64.whl b/dist/wayy_db-0.1.0-cp310-cp310-linux_x86_64.whl
new file mode 100644
index 0000000000000000000000000000000000000000..0d0508fd3a507a407c78884e6fba4567afcea06a
Binary files /dev/null and b/dist/wayy_db-0.1.0-cp310-cp310-linux_x86_64.whl differ
diff --git a/include/wayy_db/column.hpp b/include/wayy_db/column.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..5e0cf7236977ed114fea9228e02c88c051d13453
--- /dev/null
+++ b/include/wayy_db/column.hpp
@@ -0,0 +1,135 @@
+#pragma once
+
+#include "wayy_db/types.hpp"
+#include "wayy_db/column_view.hpp"
+
+#include