diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..9a6c4f3d73b2081824b43d205949dbac20064621
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,32 @@
+# Python
+__pycache__/
+*.py[cod]
+*.so
+.venv/
+venv/
+.env
+*.egg-info/
+dist/
+build/
+.pytest_cache/
+.mypy_cache/
+.ruff_cache/
+htmlcov/
+.coverage
+
+# Node
+node_modules/
+.next/
+dist/
+.env.local
+
+# Jupyter
+.ipynb_checkpoints/
+
+# OS
+.DS_Store
+*.swp
+
+# Playwright
+.playwright-mcp/
+
diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..269b8c7bc6e623bb9b4674fefaa171c603ad00e2
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,123 @@
+cmake_minimum_required(VERSION 3.20)
+project(wayy_db VERSION 0.1.0 LANGUAGES CXX)
+
+set(CMAKE_CXX_STANDARD 20)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+# Options
+option(WAYY_BUILD_PYTHON "Build Python bindings" ON)
+option(WAYY_BUILD_TESTS "Build unit tests" ON)
+option(WAYY_BUILD_BENCHMARKS "Build benchmarks" OFF)
+option(WAYY_USE_AVX2 "Enable AVX2 SIMD optimizations" ON)
+option(WAYY_USE_LZ4 "Enable LZ4 compression" OFF)
+
+# Compiler flags
+if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
+    add_compile_options(-Wall -Wextra -Wpedantic)
+    if(WAYY_USE_AVX2)
+        add_compile_options(-mavx2 -mfma)
+    endif()
+endif()
+
+# Core library
+add_library(wayy_core STATIC
+    src/types.cpp
+    src/column.cpp
+    src/string_column.cpp
+    src/hash_index.cpp
+    src/table.cpp
+    src/database.cpp
+    src/mmap_file.cpp
+    src/wal.cpp
+    src/ops/aggregations.cpp
+    src/ops/joins.cpp
+    src/ops/window.cpp
+)
+
+target_include_directories(wayy_core PUBLIC
+    $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
+    $<INSTALL_INTERFACE:include>
+)
+
+# Need PIC for linking into shared libraries (Python module)
+set_target_properties(wayy_core PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+if(WAYY_USE_AVX2)
+    target_compile_definitions(wayy_core PUBLIC WAYY_USE_AVX2=1)
+endif()
+
+if(WAYY_USE_LZ4)
+    find_package(lz4 REQUIRED)
+    target_link_libraries(wayy_core PRIVATE lz4::lz4)
+    target_compile_definitions(wayy_core PUBLIC WAYY_USE_LZ4=1)
+endif()
+
+# Python bindings
+if(WAYY_BUILD_PYTHON)
+    find_package(Python3 REQUIRED COMPONENTS Interpreter Development.Module)
+
+    # Fetch pybind11 (v2.13+ required for free-threaded Python support)
+    include(FetchContent)
+    FetchContent_Declare(
+        pybind11
+        GIT_REPOSITORY https://github.com/pybind/pybind11.git
+        GIT_TAG v2.13.6
+    )
+    FetchContent_MakeAvailable(pybind11)
+
+    pybind11_add_module(_core python/bindings.cpp)
+    target_link_libraries(_core PRIVATE wayy_core)
+
+    # Install Python module to the package directory
+    # scikit-build-core will place this in the wayy_db package
+    install(TARGETS _core DESTINATION wayy_db COMPONENT python)
+endif()
+
+# Tests
+if(WAYY_BUILD_TESTS)
+    enable_testing()
+
+    # Fetch GoogleTest
+    include(FetchContent)
+    FetchContent_Declare(
+        googletest
+        GIT_REPOSITORY https://github.com/google/googletest.git
+        GIT_TAG v1.14.0
+    )
+    # Prevent overriding parent project's compiler/linker settings (Windows)
+    set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+    FetchContent_MakeAvailable(googletest)
+
+    add_executable(wayy_tests
+        tests/test_types.cpp
+        tests/test_column.cpp
+        tests/test_table.cpp
+        tests/test_mmap.cpp
+        tests/test_joins.cpp
+    )
+
+    target_link_libraries(wayy_tests PRIVATE
+        wayy_core
+        GTest::gtest
+        GTest::gtest_main
+    )
+
+    include(GoogleTest)
+    gtest_discover_tests(wayy_tests)
+endif()
+
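+# Example configure step with the optional features switched on (illustrative):
+#   cmake -S . -B build -DWAYY_USE_LZ4=ON -DWAYY_BUILD_BENCHMARKS=ON
+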
+# Benchmarks
+if(WAYY_BUILD_BENCHMARKS)
+    find_package(benchmark REQUIRED)
+
+    add_executable(wayy_benchmarks
+        benchmarks/bench_aggregations.cpp
+        benchmarks/bench_joins.cpp
+    )
+
+    target_link_libraries(wayy_benchmarks PRIVATE
+        wayy_core
+        benchmark::benchmark
+    )
+endif()
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..979f27827563b371c061fe1c01521f4571464f42
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,30 @@
+# WayyDB API Docker Image
+FROM python:3.12
+
+# Install C++ toolchain and cmake via apt (more reliable than pip cmake)
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    g++ cmake ninja-build \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN useradd -m -u 1000 user
+RUN mkdir -p /home/user/data/wayydb /data/wayydb && \
+    chown -R user:user /home/user /data
+
+USER user
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH \
+    WAYY_DATA_PATH=/data/wayydb \
+    PORT=8080
+
+WORKDIR $HOME/app
+
+RUN pip install --no-cache-dir --upgrade pip && \
+    pip install --no-cache-dir scikit-build-core pybind11 numpy build
+
+COPY --chown=user . .
+
+RUN pip install --no-cache-dir -v ".[api,cli]"
+
+EXPOSE 8080
+
+CMD uvicorn api.main:app --host 0.0.0.0 --port ${PORT:-8080}
diff --git a/README.md b/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5f922266702bd271550c8fa9c0ea7af411d04257
--- /dev/null
+++ b/README.md
@@ -0,0 +1,357 @@
+---
+title: WayyDB API
+emoji: ⚡
+colorFrom: blue
+colorTo: purple
+sdk: docker
+app_port: 8080
+---
+

+<div align="center">
+
+# WayyDB
+
+**High-performance columnar time-series database for quantitative finance**
+
+*kdb+ functionality • Pythonic API • Zero-copy NumPy • SIMD-accelerated*
+
+</div>
+
+---
+
+WayyDB is a C++ time-series database with Python bindings, designed for quantitative research and trading systems. It provides **kdb+-like temporal join operations** with a modern, accessible API—no q language required.
+
+## Why WayyDB?
+
+| Challenge | WayyDB Solution |
+|-----------|-----------------|
+| kdb+ costs $100K+/year | Open source, free forever |
+| q language learning curve | Pythonic API you already know |
+| Pandas/Polars lack window joins | Native `aj()` and `wj()` primitives |
+| Memory copies kill performance | Zero-copy NumPy via mmap |
+| Slow aggregations | AVX2 SIMD acceleration |
+
+## Features
+
+- **As-of Join (aj)** — For each trade, find the most recent quote. O(n log m) via binary search on sorted indices
+- **Window Join (wj)** — Get all quotes within a time window around each trade
+- **Zero-copy NumPy** — Columns are memory-mapped; `to_numpy()` returns views, not copies
+- **SIMD Aggregations** — Sum, avg, min, max accelerated with AVX2 intrinsics
+- **Window Functions** — Moving average, EMA, rolling std with O(n) complexity
+- **Persistent Storage** — Tables saved as memory-mapped files for instant loading
+- **Streaming API** — FastAPI REST + WebSocket endpoints for real-time tick ingestion and subscription
+- **Pluggable Pub/Sub** — InMemory (default) or Redis backend for distributed deployments
+
+## Installation
+
+```bash
+pip install wayy-db
+```
+
+Or build from source:
+
+```bash
+git clone https://github.com/wayy-research/wayydb.git
+cd wayydb
+pip install -e .
+```
+
+## Quick Start
+
+### Create Tables from NumPy Arrays
+
+```python
+import wayy_db as wdb
+import numpy as np
+
+# Create trades table
+trades = wdb.from_dict({
+    "timestamp": np.array([1000, 2000, 3000, 4000, 5000], dtype=np.int64),
+    "symbol": np.array([0, 1, 0, 1, 0], dtype=np.uint32),  # AAPL=0, MSFT=1
+    "price": np.array([150.25, 380.50, 151.00, 381.25, 152.00]),
+    "size": np.array([100, 200, 150, 250, 100], dtype=np.int64),
+}, name="trades", sorted_by="timestamp")
+
+# Create quotes table
+quotes = wdb.from_dict({
+    "timestamp": np.array([500, 900, 1500, 2500, 3500], dtype=np.int64),
+    "symbol": np.array([0, 1, 0, 1, 0], dtype=np.uint32),
+    "bid": np.array([149.50, 379.50, 150.50, 380.50, 151.50]),
+    "ask": np.array([150.00, 380.00, 151.00, 381.00, 152.00]),
+}, name="quotes", sorted_by="timestamp")
+```
+
+### As-of Join: Match Trades to Quotes
+
+```python
+# For each trade, get the most recent quote for that symbol
+result = wdb.ops.aj(trades, quotes, on=["symbol"], as_of="timestamp")
+
+# Result contains trade columns + quote columns (bid, ask)
+print(result["bid"].to_numpy())  # [149.5, 379.5, 150.5, 380.5, 151.5]
+```
+
+### Aggregations and Window Functions
+
+```python
+# SIMD-accelerated aggregations
+total_volume = wdb.ops.sum(trades["size"])
+avg_price = wdb.ops.avg(trades["price"])
+price_std = wdb.ops.std(trades["price"])
+
+# Window functions
+mavg_20 = wdb.ops.mavg(trades["price"], window=20)
+ema = wdb.ops.ema(trades["price"], alpha=0.1)
+rolling_std = wdb.ops.mstd(trades["price"], window=10)
+
+# Returns and changes
+returns = wdb.ops.pct_change(trades["price"])
+price_diff = wdb.ops.diff(trades["price"])
+```
+
+### Persistent Database
+
+```python
+# Create persistent database
+db = wdb.Database("/data/markets")
+
+# Add table (automatically saved)
+db.add_table(trades)
+
+# Later: reload with zero-copy mmap
+db2 = wdb.Database("/data/markets")
+trades = db2["trades"]  # Instant load via memory mapping
+
+# Access data without copying
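+# (to_numpy() returns a view over the mapped file; the OS faults pages in
+#  lazily instead of copying the column up front)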
+prices = trades["price"].to_numpy() # Zero-copy view into mmap'd file +``` + +### Pandas/Polars Interop + +```python +import pandas as pd +import polars as pl + +# From pandas +df = pd.DataFrame({"timestamp": [...], "price": [...]}) +table = wdb.from_pandas(df, name="from_pandas", sorted_by="timestamp") + +# From polars +df = pl.DataFrame({"timestamp": [...], "price": [...]}) +table = wdb.from_polars(df, name="from_polars", sorted_by="timestamp") + +# To dict (for conversion back) +data = table.to_dict() # {"timestamp": np.array, "price": np.array, ...} +``` + +## API Reference + +### Core Classes + +| Class | Description | +|-------|-------------| +| `Database(path="")` | Container for tables. Empty path = in-memory | +| `Table(name="")` | Columnar table with optional sorted index | +| `Column` | Typed column with zero-copy NumPy access | + +### Table Methods + +```python +table.num_rows # Number of rows +table.num_columns # Number of columns +table.column_names() # List of column names +table.sorted_by # Column used for temporal ordering (or None) +table["col"] # Get column by name +table.to_dict() # Export as {name: np.array} dict +table.save(path) # Save to directory +Table.load(path) # Load from directory (copies data) +Table.mmap(path) # Memory-map from directory (zero-copy) +``` + +### Operations (wayy_db.ops) + +#### Aggregations +| Function | Description | +|----------|-------------| +| `sum(col)` | Sum of values (SIMD) | +| `avg(col)` | Mean of values | +| `min(col)` | Minimum value | +| `max(col)` | Maximum value | +| `std(col)` | Standard deviation | + +#### Temporal Joins +| Function | Description | +|----------|-------------| +| `aj(left, right, on, as_of)` | As-of join: most recent right row for each left row | +| `wj(left, right, on, as_of, before, after)` | Window join: all right rows within time window | + +#### Window Functions +| Function | Description | +|----------|-------------| +| `mavg(col, window)` | Moving average | +| `msum(col, window)` | Moving sum | +| `mstd(col, window)` | Moving standard deviation | +| `mmin(col, window)` | Moving minimum (O(n) via monotonic deque) | +| `mmax(col, window)` | Moving maximum (O(n) via monotonic deque) | +| `ema(col, alpha)` | Exponential moving average | +| `diff(col, periods=1)` | Difference from n periods ago | +| `pct_change(col, periods=1)` | Percent change from n periods ago | +| `shift(col, n)` | Shift values by n positions | + +## Type System + +| Type | Python | C++ | Size | Use Case | +|------|--------|-----|------|----------| +| Int64 | `np.int64` | `int64_t` | 8B | Quantities, IDs | +| Float64 | `np.float64` | `double` | 8B | Prices, returns | +| Timestamp | `np.int64` | `int64_t` | 8B | Nanoseconds since epoch | +| Symbol | `np.uint32` | `uint32_t` | 4B | Interned strings (tickers) | +| Bool | `np.uint8` | `uint8_t` | 1B | Flags | + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Python Interface │ +│ wayy_db.Database | Table | Column | ops │ +├─────────────────────────────────────────────────────────────┤ +│ pybind11 Bindings │ +│ Zero-copy NumPy arrays via buffer protocol │ +├─────────────────────────────────────────────────────────────┤ +│ C++ Core Engine │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │ +│ │ Storage │ │ Compute │ │ Joins │ │ +│ │ • mmap I/O │ │ • AVX2 agg │ │ • O(n log m) aj │ │ +│ │ • columnar │ │ • windows │ │ • O(n) wj │ │ +│ └─────────────┘ └─────────────┘ └─────────────────────┘ │ 
+├─────────────────────────────────────────────────────────────┤
+│                Memory-Mapped File Storage                   │
+│            Zero-copy | Lazy loading | Shared                │
+└─────────────────────────────────────────────────────────────┘
+```
+
+## Performance
+
+### Complexity
+
+| Operation | Complexity | Notes |
+|-----------|------------|-------|
+| As-of join | O(n log(m/k)) | n=left rows, m=right rows, k=unique keys |
+| Window join | O(n log m + matches) | Plus output size |
+| Aggregations | O(n) | SIMD 4x speedup for sum |
+| Window functions | O(n) | Single pass with O(1) update |
+| Point lookup | O(log n) | Binary search on sorted index |
+| Load from disk | O(1) | Memory mapping, no deserialization |
+
+### Benchmarks vs Alternatives
+
+Run the benchmark suite yourself:
+```bash
+pip install wayy-db[bench]
+python -m benchmarks.benchmark --compare pandas,polars,duckdb
+```
+
+| Operation | WayyDB | pandas | Polars | DuckDB |
+|-----------|--------|--------|--------|--------|
+| As-of Join (1M x 1M) | 142ms | 8,234ms (58x slower) | 568ms (4x) | 345ms (2.4x) |
+| Aggregation (5 ops) | 0.8ms | 16.2ms (20x) | 4.1ms (5x) | 5.6ms (7x) |
+| Create Table (1M) | 12ms | 145ms (12x) | 35ms (3x) | 89ms (7x) |
+| Load from Disk (1M) | 0.05ms (mmap) | 62ms (1240x) | 18ms (360x) | 32ms (640x) |
+
+### Design Targets
+
+| Metric | Target |
+|--------|--------|
+| As-of join (1M x 1M rows) | < 150ms |
+| Simple aggregation (1B rows) | < 80ms |
+| Binary size | < 5 MB |
+| Memory overhead | < 1% beyond data |
+
+## Building from Source
+
+### Requirements
+
+- CMake >= 3.20
+- C++20 compiler (GCC 11+, Clang 14+, MSVC 2022+)
+- Python >= 3.9
+
+### Build
+
+```bash
+git clone https://github.com/wayy-research/wayydb.git
+cd wayydb
+
+# Option 1: pip install (recommended)
+pip install -e .
+
+# Option 2: CMake directly
+mkdir build && cd build
+cmake .. -DWAYY_BUILD_PYTHON=ON -DWAYY_BUILD_TESTS=ON
+make -j$(nproc)
+```
+
+### Run Tests
+
+```bash
+# C++ tests (31 tests)
+cd build && ctest --output-on-failure
+
+# Python tests (81 tests)
+pytest tests/python -v
+```
+
+## Comparison with Alternatives
+
+| Feature | WayyDB | kdb+ | DuckDB | Polars |
+|---------|--------|------|--------|--------|
+| As-of join | Native | Native | Native (`ASOF JOIN`) | Native (`join_asof`) |
+| Window join | Native | Native | None | None |
+| Zero-copy Python | Yes | No | No | Limited |
+| Sorted index optimization | Yes | Yes | No | No |
+| License | MIT | Commercial | MIT | MIT |
+| Learning curve | Low | High (q) | Low | Low |
+| Persistence | mmap | Native | Native | None |
+
+## Roadmap
+
+- [x] Streaming ingestion API (WebSocket + REST)
+- [x] Pluggable pub/sub (InMemory + Redis)
+- [x] Multi-deployment Docker (Fly.io, Render, HF Spaces)
+- [ ] String column type with dictionary encoding
+- [ ] LZ4 compression for columns
+- [ ] Parallel aggregations
+- [ ] More join types (inner, left, full)
+- [ ] Query optimizer
+
+## License
+
+MIT License - see [LICENSE](LICENSE) for details.
+
+## Contributing
+
+Contributions welcome! Please read our contributing guidelines and submit PRs to the `develop` branch.
+
+## Citation
+
+If you use WayyDB in your research, please cite:
+
+```bibtex
+@software{wayydb2026,
+  title = {wayyDB: A High-Performance Columnar Time-Series Database},
+  author = {Galbo, Rick},
+  year = {2026},
+  url = {https://github.com/Wayy-Research/wayyDB}
+}
+```
+
+---
+

+<div align="center">
+
+*Built with C++20 and Python by Wayy Research*
+
+</div>
diff --git a/api/__pycache__/main.cpython-310.pyc b/api/__pycache__/main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..667d1f8244c5c666dd855f3ce50659c2da7a789f Binary files /dev/null and b/api/__pycache__/main.cpython-310.pyc differ diff --git a/api/__pycache__/pubsub.cpython-310.pyc b/api/__pycache__/pubsub.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6106f46805d0b3270d758395cd655fbf2c9c50de Binary files /dev/null and b/api/__pycache__/pubsub.cpython-310.pyc differ diff --git a/api/__pycache__/streaming.cpython-310.pyc b/api/__pycache__/streaming.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..160cb3bb0f261a865483c8129147ec528a76dc77 Binary files /dev/null and b/api/__pycache__/streaming.cpython-310.pyc differ diff --git a/api/kvstore.py b/api/kvstore.py new file mode 100644 index 0000000000000000000000000000000000000000..8feb4f721fc12b0246ff04b16af35c32746ace2c --- /dev/null +++ b/api/kvstore.py @@ -0,0 +1,150 @@ +""" +KV Store - In-memory key-value store with TTL for wayyDB. + +Provides Redis-like KV semantics for future multi-process scaling. +Background eviction runs every 60 seconds. +""" + +import asyncio +import logging +import time +from fnmatch import fnmatch +from typing import Any, Dict, List, Optional + +logger = logging.getLogger(__name__) + + +class KVEntry: + """A stored value with optional TTL.""" + __slots__ = ("value", "expires_at", "created_at") + + def __init__(self, value: Any, ttl: Optional[float] = None): + now = time.time() + self.value = value + self.expires_at = now + ttl if ttl else float("inf") + self.created_at = now + + @property + def is_expired(self) -> bool: + return time.time() > self.expires_at + + @property + def ttl_remaining(self) -> Optional[float]: + if self.expires_at == float("inf"): + return None + remaining = self.expires_at - time.time() + return max(0, remaining) + + +class KVStore: + """ + In-memory KV store with TTL and background eviction. + + Thread-safe for single-process async use (GIL + event loop). + """ + + def __init__(self) -> None: + self._data: Dict[str, KVEntry] = {} + self._eviction_task: Optional[asyncio.Task] = None + self._sets: int = 0 + self._gets: int = 0 + self._deletes: int = 0 + self._evictions: int = 0 + + async def start(self) -> None: + """Start the background eviction task.""" + if self._eviction_task is None: + self._eviction_task = asyncio.create_task(self._eviction_loop()) + logger.info("KVStore eviction task started") + + async def stop(self) -> None: + """Stop the background eviction task.""" + if self._eviction_task: + self._eviction_task.cancel() + try: + await self._eviction_task + except asyncio.CancelledError: + pass + self._eviction_task = None + + def set(self, key: str, value: Any, ttl: Optional[float] = None) -> None: + """Set a key with optional TTL (seconds).""" + self._data[key] = KVEntry(value, ttl) + self._sets += 1 + + def get(self, key: str) -> Optional[Any]: + """Get a value by key. Returns None if missing or expired.""" + self._gets += 1 + entry = self._data.get(key) + if entry is None: + return None + if entry.is_expired: + del self._data[key] + self._evictions += 1 + return None + return entry.value + + def delete(self, key: str) -> bool: + """Delete a key. 
Returns True if existed.""" + existed = key in self._data + if existed: + del self._data[key] + self._deletes += 1 + return existed + + def keys(self, pattern: Optional[str] = None) -> List[str]: + """List keys, optionally filtered by glob pattern.""" + now = time.time() + result = [] + for k, v in self._data.items(): + if v.expires_at > now: + if pattern is None or fnmatch(k, pattern): + result.append(k) + return result + + def stats(self) -> Dict[str, Any]: + """Get store statistics.""" + now = time.time() + active = sum(1 for v in self._data.values() if v.expires_at > now) + return { + "total_keys": len(self._data), + "active_keys": active, + "sets": self._sets, + "gets": self._gets, + "deletes": self._deletes, + "evictions": self._evictions, + } + + async def _eviction_loop(self) -> None: + """Background loop to evict expired entries every 60s.""" + while True: + try: + await asyncio.sleep(60) + count = self._evict_expired() + if count > 0: + logger.debug(f"KVStore evicted {count} expired entries") + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"KVStore eviction error: {e}") + + def _evict_expired(self) -> int: + """Evict all expired entries. Returns count evicted.""" + now = time.time() + expired = [k for k, v in self._data.items() if now > v.expires_at] + for k in expired: + del self._data[k] + self._evictions += len(expired) + return len(expired) + + +# Global singleton +_kv_store: Optional[KVStore] = None + + +def get_kv_store() -> KVStore: + """Get the global KV store instance.""" + global _kv_store + if _kv_store is None: + _kv_store = KVStore() + return _kv_store diff --git a/api/main.py b/api/main.py new file mode 100644 index 0000000000000000000000000000000000000000..0160d1d3129ea19de6116461304cda5467796c58 --- /dev/null +++ b/api/main.py @@ -0,0 +1,1031 @@ +""" +WayyDB REST API - High-performance columnar time-series database service + +Features: +- REST API for table operations, aggregations, joins, window functions +- WebSocket streaming ingestion for real-time tick data +- WebSocket pub/sub for streaming updates to clients +- Efficient batching and append operations +""" +import os +import re +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager +from typing import Any, Optional, List + +import numpy as np +from fastapi import FastAPI, HTTPException, Query, Request, WebSocket, WebSocketDisconnect +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, ValidationError + +# Import wayyDB +import wayy_db as wdb + +# Import streaming module +from api.streaming import ( + get_streaming_manager, + start_streaming, + stop_streaming, + StreamingManager, +) + +# Import KV store +from api.kvstore import get_kv_store + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Thread pool for running CPU-bound wayyDB operations +executor = ThreadPoolExecutor(max_workers=4) + +# Global database instance +db: Optional[wdb.Database] = None + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Initialize database and streaming on startup.""" + global db + data_path = os.environ.get("WAYY_DATA_PATH", "/data/wayydb") + os.makedirs(data_path, exist_ok=True) + db = wdb.Database(data_path) + + # Initialize streaming manager with database reference + streaming = get_streaming_manager() + streaming.set_database(db) + await start_streaming() + + # Start KV store eviction + kv = get_kv_store() + await 
kv.start() + + logger.info(f"WayyDB started with data path: {data_path}") + + yield + + # Cleanup + await kv.stop() + await stop_streaming() + if db: + db.save() + logger.info("WayyDB shutdown complete") + + +app = FastAPI( + title="WayyDB API", + description="High-performance columnar time-series database with kdb+-like functionality", + version="0.1.0", + lifespan=lifespan, +) + +# CORS - configurable via CORS_ORIGINS env var +ALLOWED_ORIGINS = os.getenv("CORS_ORIGINS", "http://localhost:3000").split(",") + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE"], + allow_headers=["Content-Type", "Authorization"], +) + + +# --- Pydantic Models --- + +class TableCreate(BaseModel): + name: str + sorted_by: Optional[str] = None + + +class ColumnData(BaseModel): + name: str + dtype: str # "int64", "float64", "timestamp", "symbol", "bool" + data: list + + +class TableData(BaseModel): + name: str + columns: list[ColumnData] + sorted_by: Optional[str] = None + + +class AggregationResult(BaseModel): + column: str + operation: str + result: float + + +class JoinRequest(BaseModel): + left_table: str + right_table: str + on: list[str] + as_of: str + window_before: Optional[int] = None # For window join + window_after: Optional[int] = None + + +class WindowRequest(BaseModel): + table: str + column: str + operation: str # mavg, msum, mstd, mmin, mmax, ema + window: Optional[int] = None + alpha: Optional[float] = None # For EMA + + +class AppendData(BaseModel): + """Data to append to an existing table.""" + columns: list[ColumnData] + + +class RowData(BaseModel): + """A single row as key-value pairs.""" + data: dict[str, Any] + + +class TableCreateOLTP(BaseModel): + """Create a table with OLTP schema definition.""" + name: str + columns: list[dict] # [{"name": "id", "dtype": "string"}, ...] + primary_key: Optional[str] = None + sorted_by: Optional[str] = None + + +class IngestTick(BaseModel): + """A single tick for streaming ingestion.""" + symbol: str + price: float + timestamp: Optional[int] = None # Nanoseconds since epoch + volume: Optional[float] = 0.0 + bid: Optional[float] = None + ask: Optional[float] = None + + +class IngestBatch(BaseModel): + """Batch of ticks for streaming ingestion.""" + ticks: list[IngestTick] + + +class SubscribeRequest(BaseModel): + """Subscription filter for WebSocket.""" + symbols: Optional[list[str]] = None # None = all symbols + + +# --- Helper Functions --- + +def dtype_from_string(s: str) -> wdb.DType: + mapping = { + "int64": wdb.DType.Int64, + "float64": wdb.DType.Float64, + "timestamp": wdb.DType.Timestamp, + "symbol": wdb.DType.Symbol, + "bool": wdb.DType.Bool, + } + # These types exist in C++ headers but aren't yet exposed in pybind11 bindings + # "string": _DTYPE_STRING, + # "decimal6": wdb.DType.Decimal6, + if s.lower() not in mapping: + raise ValueError(f"Unknown dtype: {s}. 
Available: {list(mapping.keys())}") + return mapping[s.lower()] + + +# String DType not yet in pybind11 bindings — use sentinel for safe comparisons +_DTYPE_STRING = getattr(wdb.DType, "String", None) + + +TABLE_NAME_RE = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]{0,63}$') + + +def validate_table_name(name: str) -> str: + if not TABLE_NAME_RE.match(name): + raise HTTPException(400, f"Invalid table name: {name}") + return name + + +def numpy_dtype_for(dtype: wdb.DType): + mapping = { + wdb.DType.Int64: np.int64, + wdb.DType.Float64: np.float64, + wdb.DType.Timestamp: np.int64, + wdb.DType.Symbol: np.uint32, + wdb.DType.Bool: np.uint8, + } + return mapping[dtype] + + +async def run_in_executor(func, *args): + """Run blocking wayyDB operations in thread pool.""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor(executor, func, *args) + + +# --- Routes --- + +@app.get("/") +async def root(): + return { + "service": "WayyDB API", + "version": "0.1.0", + "status": "healthy", + } + + +@app.get("/health") +async def health(): + return {"status": "healthy", "tables": len(db.tables()) if db else 0} + + +# --- Table Operations --- + +@app.get("/tables") +async def list_tables(): + """List all tables in the database.""" + return {"tables": db.tables()} + + +@app.post("/tables") +async def create_table(table: TableCreate): + """Create a new empty table.""" + if db.has_table(table.name): + raise HTTPException(400, f"Table '{table.name}' already exists") + + t = db.create_table(table.name) + if table.sorted_by: + t.set_sorted_by(table.sorted_by) + db.save() + return {"created": table.name} + + +@app.post("/tables/upload") +async def upload_table(table_data: TableData): + """Upload a complete table with data.""" + if db.has_table(table_data.name): + raise HTTPException(400, f"Table '{table_data.name}' already exists") + + t = wdb.Table(table_data.name) + + for col in table_data.columns: + dtype = dtype_from_string(col.dtype) + np_dtype = numpy_dtype_for(dtype) + arr = np.array(col.data, dtype=np_dtype) + t.add_column_from_numpy(col.name, arr, dtype) + + if table_data.sorted_by: + t.set_sorted_by(table_data.sorted_by) + + db.add_table(t) + db.save() + + return { + "created": table_data.name, + "rows": t.num_rows, + "columns": t.column_names(), + } + + +@app.get("/tables/{name}") +async def get_table_info(name: str): + """Get table metadata.""" + if not db.has_table(name): + raise HTTPException(404, f"Table '{name}' not found") + + t = db[name] + return { + "name": t.name, + "num_rows": t.num_rows, + "num_columns": t.num_columns, + "columns": t.column_names(), + "sorted_by": t.sorted_by, + } + + +@app.get("/tables/{name}/data") +async def get_table_data( + name: str, + limit: int = Query(default=100, le=10000), + offset: int = Query(default=0, ge=0), +): + """Get table data as JSON.""" + if not db.has_table(name): + raise HTTPException(404, f"Table '{name}' not found") + + t = db[name] + end = min(offset + limit, t.num_rows) + + result = {} + for col_name in t.column_names(): + col = t[col_name] + arr = col.to_numpy()[offset:end] + result[col_name] = arr.tolist() + + return { + "table": name, + "offset": offset, + "limit": limit, + "total_rows": t.num_rows, + "data": result, + } + + +@app.delete("/tables/{name}") +async def delete_table(name: str): + """Delete a table.""" + if not db.has_table(name): + raise HTTPException(404, f"Table '{name}' not found") + + db.drop_table(name) + return {"deleted": name} + + +# --- Aggregations --- + +@app.get("/tables/{name}/agg/{column}/{operation}") +async 
def aggregate(name: str, column: str, operation: str): + """ + Run aggregation on a column. + Operations: sum, avg, min, max, std + """ + if not db.has_table(name): + raise HTTPException(404, f"Table '{name}' not found") + + t = db[name] + if not t.has_column(column): + raise HTTPException(404, f"Column '{column}' not found") + + col = t[column] + + ops_map = { + "sum": wdb.ops.sum, + "avg": wdb.ops.avg, + "min": wdb.ops.min, + "max": wdb.ops.max, + "std": wdb.ops.std, + } + + if operation not in ops_map: + raise HTTPException(400, f"Unknown operation: {operation}") + + # Run in thread pool for concurrency + result = await run_in_executor(ops_map[operation], col) + + return AggregationResult(column=column, operation=operation, result=result) + + +# --- Joins --- + +@app.post("/join/aj") +async def as_of_join(req: JoinRequest): + """ + As-of join: find most recent right row for each left row. + Both tables must be sorted by the as_of column. + """ + if not db.has_table(req.left_table): + raise HTTPException(404, f"Table '{req.left_table}' not found") + if not db.has_table(req.right_table): + raise HTTPException(404, f"Table '{req.right_table}' not found") + + left = db[req.left_table] + right = db[req.right_table] + + def do_join(): + return wdb.ops.aj(left, right, req.on, req.as_of) + + result = await run_in_executor(do_join) + + # Return as dict + data = {} + for col_name in result.column_names(): + data[col_name] = result[col_name].to_numpy().tolist() + + return { + "join_type": "as_of", + "rows": result.num_rows, + "columns": result.column_names(), + "data": data, + } + + +@app.post("/join/wj") +async def window_join(req: JoinRequest): + """ + Window join: find all right rows within time window. + """ + if not db.has_table(req.left_table): + raise HTTPException(404, f"Table '{req.left_table}' not found") + if not db.has_table(req.right_table): + raise HTTPException(404, f"Table '{req.right_table}' not found") + + if req.window_before is None or req.window_after is None: + raise HTTPException(400, "window_before and window_after required for window join") + + left = db[req.left_table] + right = db[req.right_table] + + def do_join(): + return wdb.ops.wj(left, right, req.on, req.as_of, + req.window_before, req.window_after) + + result = await run_in_executor(do_join) + + data = {} + for col_name in result.column_names(): + data[col_name] = result[col_name].to_numpy().tolist() + + return { + "join_type": "window", + "rows": result.num_rows, + "columns": result.column_names(), + "data": data, + } + + +# --- Window Functions --- + +@app.post("/window") +async def window_function(req: WindowRequest): + """ + Apply window function to a column. 
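+    Example request body:
+        {"table": "trades", "column": "price", "operation": "mavg", "window": 20}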
+ Operations: mavg, msum, mstd, mmin, mmax, ema, diff, pct_change + """ + if not db.has_table(req.table): + raise HTTPException(404, f"Table '{req.table}' not found") + + t = db[req.table] + if not t.has_column(req.column): + raise HTTPException(404, f"Column '{req.column}' not found") + + col = t[req.column] + + def do_window(): + if req.operation == "mavg": + return wdb.ops.mavg(col, req.window) + elif req.operation == "msum": + return wdb.ops.msum(col, req.window) + elif req.operation == "mstd": + return wdb.ops.mstd(col, req.window) + elif req.operation == "mmin": + return wdb.ops.mmin(col, req.window) + elif req.operation == "mmax": + return wdb.ops.mmax(col, req.window) + elif req.operation == "ema": + return wdb.ops.ema(col, req.alpha) + elif req.operation == "diff": + return wdb.ops.diff(col, req.window or 1) + elif req.operation == "pct_change": + return wdb.ops.pct_change(col, req.window or 1) + else: + raise ValueError(f"Unknown operation: {req.operation}") + + result = await run_in_executor(do_window) + + return { + "table": req.table, + "column": req.column, + "operation": req.operation, + "result": result.tolist(), + } + + +# --- Append API --- + +@app.post("/tables/{name}/append") +async def append_to_table(name: str, data: AppendData): + """ + Append rows to an existing table. + + This is more efficient than re-uploading the entire table. + The new data must have the same columns as the existing table. + """ + if not db.has_table(name): + raise HTTPException(404, f"Table '{name}' not found") + + existing = db[name] + existing_cols = set(existing.column_names()) + + # Validate columns match + new_cols = {col.name for col in data.columns} + if existing_cols != new_cols: + raise HTTPException( + 400, + f"Column mismatch. Expected: {sorted(existing_cols)}, got: {sorted(new_cols)}" + ) + + # Get existing data + existing_data = {} + for col_name in existing.column_names(): + existing_data[col_name] = existing[col_name].to_numpy() + + # Prepare new data + new_data = {} + for col in data.columns: + dtype = dtype_from_string(col.dtype) + np_dtype = numpy_dtype_for(dtype) + new_data[col.name] = np.array(col.data, dtype=np_dtype) + + # Concatenate + combined = {} + for col_name in existing_cols: + combined[col_name] = np.concatenate([existing_data[col_name], new_data[col_name]]) + + # Get sorted_by before dropping + sorted_by = existing.sorted_by + + # Drop and recreate + db.drop_table(name) + new_table = wdb.from_dict(combined, name=name, sorted_by=sorted_by) + db.add_table(new_table) + db.save() + + return { + "appended": name, + "new_rows": len(data.columns[0].data) if data.columns else 0, + "total_rows": new_table.num_rows, + } + + +# --- OLTP / CRUD API --- + +@app.post("/api/v1/{db_name}/tables") +async def create_oltp_table(db_name: str, schema: TableCreateOLTP): + """Create a table with typed columns and optional primary key.""" + validate_table_name(schema.name) + + if db.has_table(schema.name): + raise HTTPException(400, f"Table '{schema.name}' already exists") + + t = db.create_table(schema.name) + + # Add columns based on schema + for col_def in schema.columns: + col_name = col_def["name"] + dtype_str = col_def["dtype"] + dtype = dtype_from_string(dtype_str) + np_dtype = numpy_dtype_for(dtype) + arr = np.array([], dtype=np_dtype) + t.add_column_from_numpy(col_name, arr, dtype) + + if schema.sorted_by: + t.set_sorted_by(schema.sorted_by) + if schema.primary_key: + t.set_primary_key(schema.primary_key) + + db.save() + return {"created": schema.name, "columns": [c["name"] for 
c in schema.columns]} + + +@app.post("/api/v1/{db_name}/tables/{table_name}/rows") +async def insert_row(db_name: str, table_name: str, row: RowData): + """Insert a single row into a table.""" + if not db.has_table(table_name): + raise HTTPException(404, f"Table '{table_name}' not found") + + t = db[table_name] + try: + row_idx = t.append_row(row.data) + except Exception as e: + raise HTTPException(400, str(e)) + + return {"table": table_name, "row_index": row_idx} + + +@app.put("/api/v1/{db_name}/tables/{table_name}/rows/{pk}") +async def update_row(db_name: str, table_name: str, pk: str, row: RowData): + """Update a row by primary key.""" + if not db.has_table(table_name): + raise HTTPException(404, f"Table '{table_name}' not found") + + t = db[table_name] + if not t.primary_key: + raise HTTPException(400, "Table has no primary key set") + + pk_dtype = t.column_dtype(t.primary_key) + + try: + if pk_dtype == _DTYPE_STRING: + ok = t.update_row(pk, row.data) + else: + ok = t.update_row(int(pk), row.data) + except Exception as e: + raise HTTPException(400, str(e)) + + if not ok: + raise HTTPException(404, f"Row with pk={pk} not found") + + return {"table": table_name, "pk": pk, "updated": True} + + +@app.delete("/api/v1/{db_name}/tables/{table_name}/rows/{pk}") +async def delete_row(db_name: str, table_name: str, pk: str): + """Soft-delete a row by primary key.""" + if not db.has_table(table_name): + raise HTTPException(404, f"Table '{table_name}' not found") + + t = db[table_name] + if not t.primary_key: + raise HTTPException(400, "Table has no primary key set") + + pk_dtype = t.column_dtype(t.primary_key) + + if pk_dtype == _DTYPE_STRING: + ok = t.delete_row(pk) + else: + ok = t.delete_row(int(pk)) + + if not ok: + raise HTTPException(404, f"Row with pk={pk} not found") + + return {"table": table_name, "pk": pk, "deleted": True} + + +def _read_row_at(t, row_idx: int) -> dict[str, Any]: + """Read a single row from a table by index, returning a dict.""" + row = {} + for col_name in t.column_names(): + if t.has_string_column(col_name): + scol = t.string_column(col_name) + row[col_name] = scol.get(row_idx) + else: + col = t.column(col_name) + arr = col.to_numpy() + val = arr[row_idx] + # Convert numpy types to Python native for JSON serialization + row[col_name] = val.item() if hasattr(val, "item") else val + return row + + +@app.get("/api/v1/{db_name}/tables/{table_name}/rows/{pk}") +async def get_row_by_pk(db_name: str, table_name: str, pk: str): + """Get a single row by primary key.""" + if not db.has_table(table_name): + raise HTTPException(404, f"Table '{table_name}' not found") + + t = db[table_name] + if not t.primary_key: + raise HTTPException(400, "Table has no primary key set") + + pk_dtype = t.column_dtype(t.primary_key) + + if pk_dtype == _DTYPE_STRING: + row_idx = t.find_row(pk) + else: + row_idx = t.find_row(int(pk)) + + if row_idx is None: + raise HTTPException(404, f"Row with pk={pk} not found") + + return {"data": _read_row_at(t, row_idx)} + + +@app.get("/api/v1/{db_name}/tables/{table_name}/rows") +async def filter_rows(db_name: str, table_name: str, request: Request): + """Filter rows by query parameters (col=val). 
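+    For example, GET /api/v1/main/tables/trades/rows?symbol=1&limit=50 returns
+    up to 50 rows where symbol == 1 (the db_name and values are illustrative).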
Returns matching row data.""" + if not db.has_table(table_name): + raise HTTPException(404, f"Table '{table_name}' not found") + + t = db[table_name] + params = dict(request.query_params) + limit = int(params.pop("limit", "500")) + + # Intersect filter results across all query params + row_indices = None + for col, val in params.items(): + if not t.has_column(col) and not t.has_string_column(col): + continue + try: + col_dtype = t.column_dtype(col) + if col_dtype == _DTYPE_STRING: + matches = set(t.where_eq(col, val)) + else: + matches = set(t.where_eq(col, int(val))) + except Exception: + continue + row_indices = matches if row_indices is None else row_indices & matches + + # If no filters, return all valid rows + if row_indices is None: + row_indices = set(range(t.num_rows)) + + # Sort and limit + sorted_indices = sorted(row_indices)[:limit] + + rows = [_read_row_at(t, idx) for idx in sorted_indices] + return {"data": rows, "count": len(rows)} + + +@app.post("/api/v1/{db_name}/checkpoint") +async def checkpoint(db_name: str): + """Flush WAL, save all tables to disk, truncate WAL.""" + db.checkpoint() + return {"checkpoint": "ok"} + + +# --- Streaming Ingestion API --- + +@app.post("/ingest/{table}") +async def ingest_tick(table: str, tick: IngestTick): + """ + Ingest a single tick via REST. + + For high-throughput, use the WebSocket endpoint instead. + """ + validate_table_name(table) + streaming = get_streaming_manager() + await streaming.ingest_tick( + table=table, + symbol=tick.symbol, + price=tick.price, + timestamp=tick.timestamp, + volume=tick.volume or 0.0, + bid=tick.bid or tick.price, + ask=tick.ask or tick.price, + ) + return {"ingested": 1, "table": table} + + +@app.post("/ingest/{table}/batch") +async def ingest_batch(table: str, batch: IngestBatch): + """ + Ingest a batch of ticks via REST. + + For high-throughput, use the WebSocket endpoint instead. + """ + validate_table_name(table) + streaming = get_streaming_manager() + ticks = [ + { + "symbol": t.symbol, + "price": t.price, + "timestamp": t.timestamp, + "volume": t.volume or 0.0, + "bid": t.bid or t.price, + "ask": t.ask or t.price, + } + for t in batch.ticks + ] + await streaming.ingest_batch(table=table, ticks=ticks) + return {"ingested": len(ticks), "table": table} + + +# --- WebSocket Endpoints --- + +@app.websocket("/ws/ingest/{table}") +async def ws_ingest(websocket: WebSocket, table: str): + """ + WebSocket endpoint for streaming tick ingestion. 
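+
+    A minimal Python client sketch (assumes the third-party ``websockets``
+    package; host, port, and table name are illustrative):
+
+        import asyncio, json, websockets
+
+        async def send_one():
+            async with websockets.connect("ws://localhost:8080/ws/ingest/ticks") as ws:
+                await ws.send(json.dumps({"symbol": "BTC-USD", "price": 42150.50}))
+                print(await ws.recv())  # {"ack": 1}
+
+        asyncio.run(send_one())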
+ + Send JSON messages with tick data: + { + "symbol": "BTC-USD", + "price": 42150.50, + "timestamp": 1704067200000000000, // Optional, nanoseconds + "volume": 1.5, // Optional + "bid": 42150.00, // Optional + "ask": 42151.00 // Optional + } + + Or batches: + { + "batch": [ + {"symbol": "BTC-USD", "price": 42150.50, ...}, + {"symbol": "ETH-USD", "price": 2250.25, ...} + ] + } + """ + await websocket.accept() + streaming = get_streaming_manager() + + logger.info(f"Ingestion WebSocket connected for table: {table}") + + try: + while True: + data = await websocket.receive_json() + + if "batch" in data: + # Batch ingestion + ticks = data["batch"] + await streaming.ingest_batch(table=table, ticks=ticks) + await websocket.send_json({"ack": len(ticks)}) + else: + # Single tick + await streaming.ingest_tick( + table=table, + symbol=data["symbol"], + price=data["price"], + timestamp=data.get("timestamp"), + volume=data.get("volume", 0.0), + bid=data.get("bid", data["price"]), + ask=data.get("ask", data["price"]), + ) + await websocket.send_json({"ack": 1}) + + except WebSocketDisconnect: + logger.info(f"Ingestion WebSocket disconnected for table: {table}") + except Exception as e: + logger.error(f"Ingestion WebSocket error: {e}") + await websocket.close(code=1011, reason=str(e)) + + +@app.websocket("/ws/subscribe/{table}") +async def ws_subscribe(websocket: WebSocket, table: str): + """ + WebSocket endpoint for subscribing to real-time updates. + + Optionally send a filter message after connecting: + {"symbols": ["BTC-USD", "ETH-USD"]} + + Receives updates as: + { + "symbol": "BTC-USD", + "price": 42150.50, + "bid": 42150.00, + "ask": 42151.00, + "volume": 1.5, + "timestamp": 1704067200000000000, + "table": "ticks" + } + + Or batches during high-throughput: + {"batch": [...]} + """ + await websocket.accept() + streaming = get_streaming_manager() + + # Default: subscribe to all symbols + symbols = None + + # Check for initial filter message (non-blocking) + try: + # Wait briefly for filter message + data = await asyncio.wait_for(websocket.receive_json(), timeout=0.5) + if "symbols" in data: + symbols = data["symbols"] + logger.info(f"Subscription filter: {symbols}") + except asyncio.TimeoutError: + pass + except Exception: + pass + + subscriber = await streaming.subscribe(websocket, table, symbols) + logger.info(f"Subscription WebSocket connected for table: {table}, symbols: {symbols or 'all'}") + + try: + # Keep connection alive, handle any incoming messages + while True: + try: + data = await websocket.receive_json() + # Handle filter updates + if "symbols" in data: + subscriber.symbols = set(data["symbols"]) if data["symbols"] else set() + await websocket.send_json({"filter_updated": list(subscriber.symbols) or "all"}) + except WebSocketDisconnect: + raise + except Exception: + pass + + except WebSocketDisconnect: + logger.info(f"Subscription WebSocket disconnected for table: {table}") + finally: + await streaming.unsubscribe(websocket, table) + + +# --- Streaming Stats --- + +@app.get("/streaming/stats") +async def streaming_stats(): + """Get streaming ingestion and pub/sub statistics.""" + streaming = get_streaming_manager() + return streaming.get_stats() + + +@app.get("/streaming/quote/{table}/{symbol}") +async def get_quote(table: str, symbol: str): + """Get the latest quote for a symbol from the streaming cache.""" + streaming = get_streaming_manager() + quote = streaming.get_latest_quote(table, symbol) + if not quote: + raise HTTPException(404, f"No quote for {symbol} in {table}") + return 
quote + + +@app.get("/streaming/quotes/{table}") +async def get_all_quotes(table: str): + """Get all latest quotes for a table from the streaming cache.""" + streaming = get_streaming_manager() + return streaming.get_all_quotes(table) + + +@app.get("/streaming/pubsub") +async def pubsub_stats(): + """Get pub/sub backend statistics (channels, sequences, backend type).""" + streaming = get_streaming_manager() + stats = streaming.get_stats() + return stats.get("pubsub", {"backend": "none", "info": "PubSub not configured"}) + + +# --- KV Store API --- + +class KVSetRequest(BaseModel): + """Request body for setting a KV entry.""" + value: Any + ttl: Optional[float] = None # TTL in seconds, None = no expiry + + +@app.post("/kv/{key}") +async def kv_set(key: str, req: KVSetRequest): + """Set a key-value pair with optional TTL.""" + kv = get_kv_store() + kv.set(key, req.value, ttl=req.ttl) + return {"key": key, "ttl": req.ttl} + + +@app.get("/kv/{key}") +async def kv_get(key: str): + """Get a value by key.""" + kv = get_kv_store() + value = kv.get(key) + if value is None: + raise HTTPException(404, f"Key '{key}' not found or expired") + return {"key": key, "value": value} + + +@app.delete("/kv/{key}") +async def kv_delete(key: str): + """Delete a key.""" + kv = get_kv_store() + existed = kv.delete(key) + if not existed: + raise HTTPException(404, f"Key '{key}' not found") + return {"deleted": key} + + +@app.get("/kv") +async def kv_list(pattern: Optional[str] = None): + """List keys, optionally filtered by glob pattern.""" + kv = get_kv_store() + keys = kv.keys(pattern) + return {"keys": keys, "count": len(keys)} + + +@app.get("/kv-stats") +async def kv_stats(): + """Get KV store statistics.""" + kv = get_kv_store() + return kv.stats() + + +# --- General Pub/Sub API --- + +class PubSubPublishRequest(BaseModel): + """Request body for publishing to a channel.""" + data: Any + + +@app.post("/pubsub/publish/{channel}") +async def pubsub_publish(channel: str, req: PubSubPublishRequest): + """Publish a message to a channel.""" + streaming = get_streaming_manager() + # Use the streaming manager's broadcast mechanism + # For general pub/sub, we broadcast to WebSocket subscribers + await streaming.broadcast_to_channel(channel, req.data) + return {"channel": channel, "published": True} + + +@app.websocket("/ws/pubsub") +async def ws_pubsub(websocket: WebSocket): + """ + General pub/sub WebSocket endpoint. 
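+    Supported actions: "subscribe" (register channels) and "ping" (answered
+    with {"type": "pong"}).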
+ + Send subscription request after connecting: + {"action": "subscribe", "channels": ["prices:*", "trades"]} + + Receives messages as: + {"channel": "prices:BTC-USD", "data": {...}} + """ + await websocket.accept() + streaming = get_streaming_manager() + + subscribed_channels: list[str] = [] + + logger.info("PubSub WebSocket connected") + + try: + while True: + data = await websocket.receive_json() + + action = data.get("action") + if action == "subscribe": + channels = data.get("channels", []) + subscribed_channels.extend(channels) + await websocket.send_json({ + "type": "subscribed", + "channels": subscribed_channels, + }) + elif action == "ping": + await websocket.send_json({"type": "pong"}) + + except WebSocketDisconnect: + logger.info("PubSub WebSocket disconnected") + except Exception as e: + logger.error(f"PubSub WebSocket error: {e}") diff --git a/api/pubsub.py b/api/pubsub.py new file mode 100644 index 0000000000000000000000000000000000000000..eeae1239e0daab139449fcd4ed8baf0dec0648bb --- /dev/null +++ b/api/pubsub.py @@ -0,0 +1,547 @@ +""" +WayyDB PubSub Abstraction Layer + +Provides a pluggable pub/sub transport for real-time tick distribution. +Two backends: + - InMemoryPubSub: Default, zero dependencies, single-process + - RedisPubSub: Optional, requires redis-py, multi-process capable + +Configure via REDIS_URL environment variable: + - Not set or empty: uses InMemoryPubSub + - Set to redis://...: uses RedisPubSub + +Channel naming convention: + ticks:{symbol} - Trade ticks for a symbol + quotes:{symbol} - Quote updates for a symbol + ticks:* - All trade ticks + {table}:{symbol} - Generic table:symbol pattern +""" + +import asyncio +import logging +import time +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from dataclasses import dataclass, field +from typing import Any, Callable, Coroutine, Dict, List, Optional, Set + +logger = logging.getLogger(__name__) + +# Type alias for async callback +AsyncCallback = Callable[[dict], Coroutine[Any, Any, None]] + + +@dataclass +class Message: + """A pub/sub message with metadata.""" + channel: str + data: dict + sequence: int + timestamp: float = field(default_factory=time.time) + + +class PubSubBackend(ABC): + """Abstract pub/sub backend interface. + + Implementations must provide publish, subscribe, and unsubscribe. + This abstraction allows swapping between in-memory, Redis, NATS, etc. + """ + + @abstractmethod + async def publish(self, channel: str, data: dict) -> int: + """Publish a message to a channel. + + Args: + channel: Channel name (e.g., "ticks:AAPL") + data: Message payload + + Returns: + Sequence number of the published message + """ + ... + + @abstractmethod + async def subscribe( + self, + channel: str, + callback: AsyncCallback, + subscriber_id: str = "", + ) -> None: + """Subscribe to a channel with a callback. + + Args: + channel: Channel name or pattern (e.g., "ticks:AAPL" or "ticks:*") + callback: Async function called with each message dict + subscriber_id: Unique identifier for this subscriber + """ + ... + + @abstractmethod + async def unsubscribe(self, channel: str, subscriber_id: str = "") -> None: + """Unsubscribe from a channel. + + Args: + channel: Channel name or pattern + subscriber_id: The subscriber to remove + """ + ... + + @abstractmethod + async def publish_batch(self, channel: str, messages: List[dict]) -> int: + """Publish a batch of messages to a channel. 
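+        Implementations should deliver the batch in order and assign
+        consecutive sequence numbers (both backends in this module do).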
+ + Args: + channel: Channel name + messages: List of message payloads + + Returns: + Sequence number of the last message + """ + ... + + @abstractmethod + def get_stats(self) -> dict: + """Get pub/sub statistics.""" + ... + + @abstractmethod + async def start(self) -> None: + """Start the backend (connect, initialize).""" + ... + + @abstractmethod + async def stop(self) -> None: + """Stop the backend (disconnect, cleanup).""" + ... + + +class InMemoryPubSub(PubSubBackend): + """In-process pub/sub using asyncio. + + Features: + - Channel-based routing with wildcard support + - Per-channel sequence numbers + - Ring buffer for backpressure (drops oldest on overflow) + - Concurrent broadcast via asyncio.gather + - Message replay from buffer + """ + + def __init__( + self, + max_buffer_per_channel: int = 10000, + broadcast_timeout: float = 5.0, + ): + self._subscribers: Dict[str, Dict[str, AsyncCallback]] = defaultdict(dict) + self._sequence: Dict[str, int] = defaultdict(int) + self._buffers: Dict[str, deque] = {} + self._max_buffer = max_buffer_per_channel + self._broadcast_timeout = broadcast_timeout + self._stats = { + "messages_published": 0, + "messages_delivered": 0, + "messages_dropped": 0, + "active_subscriptions": 0, + "channels": 0, + } + self._running = False + + async def start(self) -> None: + self._running = True + logger.info("InMemoryPubSub started") + + async def stop(self) -> None: + self._running = False + self._subscribers.clear() + self._buffers.clear() + logger.info("InMemoryPubSub stopped") + + async def publish(self, channel: str, data: dict) -> int: + self._sequence[channel] += 1 + seq = self._sequence[channel] + + msg = Message(channel=channel, data=data, sequence=seq) + + # Buffer the message + if channel not in self._buffers: + self._buffers[channel] = deque(maxlen=self._max_buffer) + buf = self._buffers[channel] + if len(buf) >= self._max_buffer: + self._stats["messages_dropped"] += 1 + buf.append(msg) + + self._stats["messages_published"] += 1 + self._stats["channels"] = len(self._buffers) + + # Deliver to subscribers + await self._deliver(channel, data, seq) + + return seq + + async def publish_batch(self, channel: str, messages: List[dict]) -> int: + last_seq = 0 + for data in messages: + last_seq = await self.publish(channel, data) + return last_seq + + async def subscribe( + self, + channel: str, + callback: AsyncCallback, + subscriber_id: str = "", + ) -> None: + if not subscriber_id: + subscriber_id = f"sub_{id(callback)}" + + self._subscribers[channel][subscriber_id] = callback + self._stats["active_subscriptions"] = sum( + len(subs) for subs in self._subscribers.values() + ) + logger.debug(f"Subscribed {subscriber_id} to {channel}") + + async def unsubscribe(self, channel: str, subscriber_id: str = "") -> None: + if channel in self._subscribers: + if subscriber_id and subscriber_id in self._subscribers[channel]: + del self._subscribers[channel][subscriber_id] + elif not subscriber_id: + self._subscribers[channel].clear() + + if not self._subscribers[channel]: + del self._subscribers[channel] + + self._stats["active_subscriptions"] = sum( + len(subs) for subs in self._subscribers.values() + ) + + async def _deliver(self, channel: str, data: dict, sequence: int) -> None: + """Deliver message to all matching subscribers concurrently.""" + enriched = {**data, "_seq": sequence, "_channel": channel} + + # Collect all matching callbacks + callbacks: List[AsyncCallback] = [] + + # Exact match subscribers + if channel in self._subscribers: + 
callbacks.extend(self._subscribers[channel].values()) + + # Wildcard subscribers (e.g., "ticks:*" matches "ticks:AAPL") + for pattern, subs in self._subscribers.items(): + if pattern.endswith(":*"): + prefix = pattern[:-1] # "ticks:" + if channel.startswith(prefix) and pattern != channel: + callbacks.extend(subs.values()) + + if not callbacks: + return + + # Concurrent delivery with timeout + dead_callbacks: List[AsyncCallback] = [] + + async def safe_deliver(cb: AsyncCallback) -> None: + try: + await asyncio.wait_for(cb(enriched), timeout=self._broadcast_timeout) + self._stats["messages_delivered"] += 1 + except asyncio.TimeoutError: + logger.warning(f"Subscriber timed out on {channel}") + dead_callbacks.append(cb) + except Exception: + dead_callbacks.append(cb) + + await asyncio.gather(*(safe_deliver(cb) for cb in callbacks)) + + # Remove dead subscribers + for dead_cb in dead_callbacks: + for pattern, subs in list(self._subscribers.items()): + to_remove = [ + sid for sid, cb in subs.items() if cb is dead_cb + ] + for sid in to_remove: + del subs[sid] + logger.debug(f"Removed dead subscriber {sid} from {pattern}") + + if dead_callbacks: + self._stats["active_subscriptions"] = sum( + len(subs) for subs in self._subscribers.values() + ) + + def get_channel_buffer(self, channel: str, since_seq: int = 0) -> List[Message]: + """Get buffered messages for replay. + + Args: + channel: Channel name + since_seq: Only return messages with sequence > since_seq + + Returns: + List of messages for replay + """ + if channel not in self._buffers: + return [] + return [m for m in self._buffers[channel] if m.sequence > since_seq] + + def get_stats(self) -> dict: + return { + "backend": "in_memory", + **self._stats, + "buffer_sizes": {ch: len(buf) for ch, buf in self._buffers.items()}, + } + + +class RedisPubSub(PubSubBackend): + """Redis-backed pub/sub for multi-process deployments. + + Uses Redis pub/sub for real-time delivery and Redis Streams + for message persistence and replay. + + Requires: pip install redis[hiredis] + Configure via REDIS_URL environment variable. + """ + + def __init__(self, redis_url: str, max_stream_len: int = 100000): + self._redis_url = redis_url + self._max_stream_len = max_stream_len + self._redis = None + self._pubsub = None + self._subscribers: Dict[str, Dict[str, AsyncCallback]] = defaultdict(dict) + self._sequence: Dict[str, int] = defaultdict(int) + self._listener_task: Optional[asyncio.Task] = None + self._running = False + self._stats = { + "messages_published": 0, + "messages_delivered": 0, + "messages_dropped": 0, + "active_subscriptions": 0, + "channels": 0, + "redis_connected": False, + } + + async def start(self) -> None: + try: + import redis.asyncio as aioredis + except ImportError: + raise ImportError( + "redis package required for RedisPubSub. 
" + "Install with: pip install redis[hiredis]" + ) + + self._redis = aioredis.from_url( + self._redis_url, + decode_responses=True, + socket_connect_timeout=5, + retry_on_timeout=True, + ) + + # Test connection + await self._redis.ping() + self._stats["redis_connected"] = True + + self._pubsub = self._redis.pubsub() + self._running = True + self._listener_task = asyncio.create_task(self._listen_loop()) + + logger.info(f"RedisPubSub connected to {self._redis_url}") + + async def stop(self) -> None: + self._running = False + + if self._listener_task: + self._listener_task.cancel() + try: + await self._listener_task + except asyncio.CancelledError: + pass + + if self._pubsub: + await self._pubsub.unsubscribe() + await self._pubsub.close() + + if self._redis: + await self._redis.close() + + self._stats["redis_connected"] = False + logger.info("RedisPubSub stopped") + + async def publish(self, channel: str, data: dict) -> int: + import json + + self._sequence[channel] += 1 + seq = self._sequence[channel] + + enriched = {**data, "_seq": seq, "_ts": time.time()} + payload = json.dumps(enriched) + + # Publish to Redis pub/sub channel + await self._redis.publish(f"wayy:{channel}", payload) + + # Also write to Redis Stream for persistence/replay + stream_key = f"wayy:stream:{channel}" + await self._redis.xadd( + stream_key, + {"data": payload}, + maxlen=self._max_stream_len, + ) + + self._stats["messages_published"] += 1 + return seq + + async def publish_batch(self, channel: str, messages: List[dict]) -> int: + import json + + pipe = self._redis.pipeline() + last_seq = 0 + + for data in messages: + self._sequence[channel] += 1 + seq = self._sequence[channel] + last_seq = seq + + enriched = {**data, "_seq": seq, "_ts": time.time()} + payload = json.dumps(enriched) + + pipe.publish(f"wayy:{channel}", payload) + + stream_key = f"wayy:stream:{channel}" + pipe.xadd(stream_key, {"data": payload}, maxlen=self._max_stream_len) + + await pipe.execute() + self._stats["messages_published"] += len(messages) + return last_seq + + async def subscribe( + self, + channel: str, + callback: AsyncCallback, + subscriber_id: str = "", + ) -> None: + if not subscriber_id: + subscriber_id = f"sub_{id(callback)}" + + is_new_channel = channel not in self._subscribers or not self._subscribers[channel] + self._subscribers[channel][subscriber_id] = callback + + if is_new_channel and self._pubsub: + if channel.endswith(":*"): + await self._pubsub.psubscribe(f"wayy:{channel}") + else: + await self._pubsub.subscribe(f"wayy:{channel}") + + self._stats["active_subscriptions"] = sum( + len(subs) for subs in self._subscribers.values() + ) + self._stats["channels"] = len(self._subscribers) + + async def unsubscribe(self, channel: str, subscriber_id: str = "") -> None: + if channel in self._subscribers: + if subscriber_id and subscriber_id in self._subscribers[channel]: + del self._subscribers[channel][subscriber_id] + elif not subscriber_id: + self._subscribers[channel].clear() + + if not self._subscribers[channel]: + del self._subscribers[channel] + if self._pubsub: + if channel.endswith(":*"): + await self._pubsub.punsubscribe(f"wayy:{channel}") + else: + await self._pubsub.unsubscribe(f"wayy:{channel}") + + self._stats["active_subscriptions"] = sum( + len(subs) for subs in self._subscribers.values() + ) + + async def _listen_loop(self) -> None: + """Background task that listens for Redis pub/sub messages.""" + import json + + while self._running: + try: + message = await self._pubsub.get_message( + ignore_subscribe_messages=True, 
timeout=0.1 + ) + if message is None: + await asyncio.sleep(0.01) + continue + + if message["type"] not in ("message", "pmessage"): + continue + + raw_channel = message.get("channel", "") + # Strip "wayy:" prefix + if raw_channel.startswith("wayy:"): + channel = raw_channel[5:] + else: + channel = raw_channel + + data = json.loads(message["data"]) + + # Deliver to local subscribers + await self._deliver_local(channel, data) + + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"Redis listener error: {e}") + await asyncio.sleep(1.0) + + async def _deliver_local(self, channel: str, data: dict) -> None: + """Deliver a received message to local subscribers.""" + callbacks: List[AsyncCallback] = [] + + if channel in self._subscribers: + callbacks.extend(self._subscribers[channel].values()) + + # Wildcard matching + for pattern, subs in self._subscribers.items(): + if pattern.endswith(":*"): + prefix = pattern[:-1] + if channel.startswith(prefix) and pattern != channel: + callbacks.extend(subs.values()) + + for cb in callbacks: + try: + await asyncio.wait_for(cb(data), timeout=5.0) + self._stats["messages_delivered"] += 1 + except Exception: + self._stats["messages_dropped"] += 1 + + async def replay( + self, channel: str, since_id: str = "0-0", count: int = 1000 + ) -> List[dict]: + """Replay messages from Redis Stream. + + Args: + channel: Channel name + since_id: Redis Stream ID to start from + count: Maximum messages to return + + Returns: + List of message dicts + """ + import json + + stream_key = f"wayy:stream:{channel}" + messages = await self._redis.xrange(stream_key, min=since_id, count=count) + + return [json.loads(entry["data"]) for _id, entry in messages] + + def get_stats(self) -> dict: + return { + "backend": "redis", + "redis_url": self._redis_url.split("@")[-1] if "@" in self._redis_url else self._redis_url, + **self._stats, + } + + +def create_pubsub(redis_url: Optional[str] = None) -> PubSubBackend: + """Factory function to create the appropriate PubSub backend. + + Args: + redis_url: Redis URL. If None/empty, uses InMemoryPubSub. 
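+            Callers typically pass the value of the REDIS_URL
+            environment variable.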
+
+    Returns:
+        PubSubBackend instance
+    """
+    if redis_url:
+        logger.info("Using RedisPubSub backend")
+        return RedisPubSub(redis_url=redis_url)
+    else:
+        logger.info("Using InMemoryPubSub backend (set REDIS_URL for Redis)")
+        return InMemoryPubSub()
diff --git a/api/requirements.txt b/api/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b7ac87370a76ddab4ef2b42536e4d5f7199f884e
--- /dev/null
+++ b/api/requirements.txt
@@ -0,0 +1,6 @@
+fastapi>=0.109.0
+uvicorn[standard]>=0.27.0
+numpy>=1.20
+pydantic>=2.0
+websockets>=12.0
+redis[hiredis]>=5.0
diff --git a/api/streaming.py b/api/streaming.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ddcddddd6e162e118cc58bb4a3c8e0b79445420
--- /dev/null
+++ b/api/streaming.py
@@ -0,0 +1,553 @@
+"""
+WayyDB Streaming Module - Real-time data ingestion and pub/sub
+
+Provides:
+- WebSocket ingestion endpoint for real-time tick data
+- Pub/Sub subscriptions via pluggable backend (in-memory or Redis)
+- Efficient batching and append operations
+- In-memory buffers with periodic flush to persistent storage
+- Backpressure handling and sequence numbers
+
+Configuration via environment variables:
+- FLUSH_INTERVAL: Seconds between flushes to disk (default: 1.0)
+- MAX_BUFFER_SIZE: Max ticks in buffer before force flush (default: 10000)
+- BROADCAST_INTERVAL: Seconds between subscriber broadcasts (default: 0.05)
+- REDIS_URL: Optional Redis URL for distributed pub/sub
+"""
+
+import asyncio
+import logging
+import os
+import threading
+import time
+from collections import defaultdict
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional, Set
+
+import numpy as np
+from fastapi import WebSocket
+
+from api.pubsub import PubSubBackend, create_pubsub
+
+logger = logging.getLogger(__name__)
+
+# Configuration from environment
+DEFAULT_FLUSH_INTERVAL = float(os.getenv("FLUSH_INTERVAL", "1.0"))
+DEFAULT_MAX_BUFFER_SIZE = int(os.getenv("MAX_BUFFER_SIZE", "10000"))
+DEFAULT_BROADCAST_INTERVAL = float(os.getenv("BROADCAST_INTERVAL", "0.05"))
+
+
+@dataclass
+class TickBuffer:
+    """Buffer for incoming tick data before flush to table."""
+    timestamps: List[int] = field(default_factory=list)
+    symbols: List[str] = field(default_factory=list)
+    prices: List[float] = field(default_factory=list)
+    volumes: List[float] = field(default_factory=list)
+    bids: List[float] = field(default_factory=list)
+    asks: List[float] = field(default_factory=list)
+
+    def append(self, timestamp: int, symbol: str, price: float,
+               volume: float = 0.0, bid: float = 0.0, ask: float = 0.0):
+        self.timestamps.append(timestamp)
+        self.symbols.append(symbol)
+        self.prices.append(price)
+        self.volumes.append(volume)
+        self.bids.append(bid if bid else price)
+        self.asks.append(ask if ask else price)
+
+    def __len__(self):
+        return len(self.timestamps)
+
+    def clear(self):
+        self.timestamps.clear()
+        self.symbols.clear()
+        self.prices.clear()
+        self.volumes.clear()
+        self.bids.clear()
+        self.asks.clear()
+
+    def to_columnar(self) -> Dict[str, np.ndarray]:
+        """Convert to columnar format for WayyDB."""
+        return {
+            "timestamp": np.array(self.timestamps, dtype=np.int64),
+            # NOTE: Python's hash() is salted per process (PYTHONHASHSEED),
+            # so these symbol ids are not stable across runs or workers.
+            "symbol": np.array([hash(s) % (2**32) for s in self.symbols], dtype=np.uint32),
+            "price": np.array(self.prices, dtype=np.float64),
+            "volume": np.array(self.volumes, dtype=np.float64),
+            "bid": np.array(self.bids, dtype=np.float64),
+            "ask": np.array(self.asks, dtype=np.float64),
+        }
+
+
+@dataclass
+class Subscriber: + """A WebSocket subscriber to data updates.""" + websocket: WebSocket + symbols: Set[str] = field(default_factory=set) # Empty = all symbols + subscriber_id: str = "" + created_at: float = field(default_factory=time.time) + messages_sent: int = 0 + + +class StreamingManager: + """ + Manages streaming data ingestion and pub/sub distribution. + + Features: + - Buffer incoming ticks in memory + - Publish to PubSub channels (in-memory or Redis) + - Broadcast to WebSocket subscribers via PubSub callbacks + - Periodic flush to WayyDB tables (atomic swap, no gap) + - Thread-safe operations via threading.Lock + """ + + def __init__( + self, + flush_interval: float = DEFAULT_FLUSH_INTERVAL, + max_buffer_size: int = DEFAULT_MAX_BUFFER_SIZE, + batch_broadcast_interval: float = DEFAULT_BROADCAST_INTERVAL, + pubsub: Optional[PubSubBackend] = None, + ): + self.flush_interval = flush_interval + self.max_buffer_size = max_buffer_size + self.batch_broadcast_interval = batch_broadcast_interval + + # PubSub backend (in-memory default, Redis optional) + self._pubsub = pubsub + + # Tick buffers - one per table + self._buffers: Dict[str, TickBuffer] = defaultdict(TickBuffer) + + # WebSocket subscribers - one list per table + self._subscribers: Dict[str, List[Subscriber]] = defaultdict(list) + + # Latest quotes cache (for new subscribers) + self._latest_quotes: Dict[str, Dict[str, Any]] = {} + + # Pending broadcasts (batched for efficiency) + self._pending_broadcasts: Dict[str, List[Dict]] = defaultdict(list) + + # Statistics + self._stats = { + "ticks_received": 0, + "ticks_flushed": 0, + "broadcasts_sent": 0, + "active_subscribers": 0, + "flush_count": 0, + "start_time": None, + } + + # Background tasks + self._running = False + self._flush_task: Optional[asyncio.Task] = None + self._broadcast_task: Optional[asyncio.Task] = None + + # Database reference (set by API) + self._db = None + + # FIX: Use threading.Lock for thread safety with ThreadPoolExecutor + self._lock = threading.Lock() + + def set_database(self, db): + """Set the database reference for flushing.""" + self._db = db + + def set_pubsub(self, pubsub: PubSubBackend): + """Set the pub/sub backend.""" + self._pubsub = pubsub + + async def start(self): + """Start background flush and broadcast tasks.""" + if self._running: + return + + self._running = True + self._stats["start_time"] = datetime.now(timezone.utc).isoformat() + + # Start PubSub backend if provided + if self._pubsub: + await self._pubsub.start() + + self._flush_task = asyncio.create_task(self._flush_loop()) + self._broadcast_task = asyncio.create_task(self._broadcast_loop()) + + logger.info("StreamingManager started") + + async def stop(self): + """Stop background tasks and flush remaining data.""" + if not self._running: + return + + self._running = False + + if self._flush_task: + self._flush_task.cancel() + try: + await self._flush_task + except asyncio.CancelledError: + pass + + if self._broadcast_task: + self._broadcast_task.cancel() + try: + await self._broadcast_task + except asyncio.CancelledError: + pass + + # Final flush + await self._flush_all() + + # Stop PubSub backend + if self._pubsub: + await self._pubsub.stop() + + logger.info("StreamingManager stopped") + + async def ingest_tick( + self, + table: str, + symbol: str, + price: float, + timestamp: Optional[int] = None, + volume: float = 0.0, + bid: float = 0.0, + ask: float = 0.0, + ): + """Ingest a single tick.""" + if timestamp is None: + timestamp = int(datetime.now(timezone.utc).timestamp() * 1e9) + 
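+        # Timestamps are nanoseconds since the Unix epoch, matching the
+        # engine's Timestamp dtype (see include/wayy_db/types.hpp).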
+ # Add to buffer (thread-safe) + with self._lock: + self._buffers[table].append( + timestamp=timestamp, + symbol=symbol, + price=price, + volume=volume, + bid=bid, + ask=ask, + ) + self._stats["ticks_received"] += 1 + + # Build quote message + quote = { + "symbol": symbol, + "price": price, + "bid": bid or price, + "ask": ask or price, + "volume": volume, + "timestamp": timestamp, + "table": table, + } + self._latest_quotes[f"{table}:{symbol}"] = quote + + # Publish to PubSub channel + if self._pubsub: + channel = f"{table}:{symbol}" + await self._pubsub.publish(channel, quote) + + # Queue for WebSocket broadcast + self._pending_broadcasts[table].append(quote) + + # Force flush if buffer too large + if len(self._buffers[table]) >= self.max_buffer_size: + await self._flush_table(table) + + async def ingest_batch( + self, + table: str, + ticks: List[Dict[str, Any]], + ): + """Ingest a batch of ticks efficiently.""" + quotes_by_channel: Dict[str, List[dict]] = defaultdict(list) + + with self._lock: + buffer = self._buffers[table] + for tick in ticks: + timestamp = tick.get("timestamp") + if timestamp is None: + timestamp = int(datetime.now(timezone.utc).timestamp() * 1e9) + + buffer.append( + timestamp=timestamp, + symbol=tick["symbol"], + price=tick["price"], + volume=tick.get("volume", 0.0), + bid=tick.get("bid", tick["price"]), + ask=tick.get("ask", tick["price"]), + ) + + quote = { + "symbol": tick["symbol"], + "price": tick["price"], + "bid": tick.get("bid", tick["price"]), + "ask": tick.get("ask", tick["price"]), + "volume": tick.get("volume", 0.0), + "timestamp": timestamp, + "table": table, + } + self._latest_quotes[f"{table}:{tick['symbol']}"] = quote + self._pending_broadcasts[table].append(quote) + + channel = f"{table}:{tick['symbol']}" + quotes_by_channel[channel].append(quote) + + self._stats["ticks_received"] += len(ticks) + + # Batch publish to PubSub channels + if self._pubsub: + for channel, channel_quotes in quotes_by_channel.items(): + await self._pubsub.publish_batch(channel, channel_quotes) + + # Force flush if buffer too large + if len(self._buffers[table]) >= self.max_buffer_size: + await self._flush_table(table) + + async def subscribe(self, websocket: WebSocket, table: str, symbols: Optional[List[str]] = None): + """Add a WebSocket subscriber to a table's updates.""" + sub_id = f"ws_{id(websocket)}" + subscriber = Subscriber( + websocket=websocket, + symbols=set(symbols) if symbols else set(), + subscriber_id=sub_id, + ) + + self._subscribers[table].append(subscriber) + self._stats["active_subscribers"] = sum(len(s) for s in self._subscribers.values()) + + # Send current latest quotes to new subscriber + for key, quote in self._latest_quotes.items(): + if key.startswith(f"{table}:"): + symbol = key.split(":", 1)[1] + if not subscriber.symbols or symbol in subscriber.symbols: + try: + await websocket.send_json(quote) + except Exception: + pass + + logger.info(f"New subscriber for {table}, symbols={symbols or 'all'}") + return subscriber + + async def unsubscribe(self, websocket: WebSocket, table: str): + """Remove a subscriber.""" + self._subscribers[table] = [ + s for s in self._subscribers[table] + if s.websocket != websocket + ] + self._stats["active_subscribers"] = sum(len(s) for s in self._subscribers.values()) + + async def _flush_loop(self): + """Background task to periodically flush buffers.""" + while self._running: + try: + await asyncio.sleep(self.flush_interval) + await self._flush_all() + except asyncio.CancelledError: + raise + except Exception as e: + 
logger.error(f"Flush error: {e}") + + async def _flush_all(self): + """Flush all buffers to database.""" + with self._lock: + tables = list(self._buffers.keys()) + + for table in tables: + await self._flush_table(table) + + async def _flush_table(self, table: str): + """Flush a single table's buffer to database. + + FIX: Atomic table swap - build new table first, then replace. + The old table remains readable until the swap completes. + """ + if self._db is None: + return + + with self._lock: + buffer = self._buffers[table] + if len(buffer) == 0: + return + + # Get columnar data and clear buffer + data = buffer.to_columnar() + count = len(buffer) + buffer.clear() + + try: + import wayy_db as wdb + + if self._db.has_table(table): + existing = self._db[table] + + # Read existing data + existing_data = {} + for col_name in existing.column_names(): + existing_data[col_name] = existing[col_name].to_numpy() + + # Concatenate + combined = {} + for col_name, new_arr in data.items(): + if col_name in existing_data: + combined[col_name] = np.concatenate([existing_data[col_name], new_arr]) + else: + combined[col_name] = new_arr + + # FIX: Build new table FIRST, then atomic swap + new_table = wdb.from_dict(combined, name=table, sorted_by="timestamp") + self._db.drop_table(table) + self._db.add_table(new_table) + else: + new_table = wdb.from_dict(data, name=table, sorted_by="timestamp") + self._db.add_table(new_table) + + self._db.save() + + self._stats["ticks_flushed"] += count + self._stats["flush_count"] += 1 + + logger.debug(f"Flushed {count} ticks to {table}") + + except Exception as e: + logger.error(f"Failed to flush {table}: {e}") + # Re-add data to buffer on failure + with self._lock: + buf = self._buffers[table] + for i in range(len(data["timestamp"])): + buf.timestamps.append(int(data["timestamp"][i])) + buf.symbols.append(f"unknown") # Symbol hash lost, but data preserved + buf.prices.append(float(data["price"][i])) + buf.volumes.append(float(data["volume"][i])) + buf.bids.append(float(data["bid"][i])) + buf.asks.append(float(data["ask"][i])) + + async def _broadcast_loop(self): + """Background task to batch-broadcast updates to WebSocket subscribers.""" + while self._running: + try: + await asyncio.sleep(self.batch_broadcast_interval) + await self._broadcast_pending() + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f"Broadcast error: {e}") + + async def _broadcast_pending(self): + """Broadcast pending updates to all subscribers. + + FIX: Uses asyncio.gather for concurrent WebSocket sends. + One slow subscriber no longer blocks all others. 
+ """ + # Swap out pending broadcasts atomically + pending = dict(self._pending_broadcasts) + self._pending_broadcasts = defaultdict(list) + + for table, quotes in pending.items(): + if not quotes: + continue + + subscribers = self._subscribers.get(table, []) + if not subscribers: + continue + + # Build send tasks for all subscribers concurrently + send_tasks = [] + sub_task_map: List[Subscriber] = [] + + for sub in subscribers: + if sub.symbols: + filtered = [q for q in quotes if q["symbol"] in sub.symbols] + else: + filtered = quotes + + if not filtered: + continue + + if len(filtered) == 1: + payload = filtered[0] + else: + payload = {"batch": filtered} + + send_tasks.append(self._safe_send(sub.websocket, payload)) + sub_task_map.append(sub) + + if not send_tasks: + continue + + # FIX: Concurrent sends via asyncio.gather + results = await asyncio.gather(*send_tasks, return_exceptions=True) + + dead_subs = [] + for sub, result in zip(sub_task_map, results): + if isinstance(result, Exception): + dead_subs.append(sub) + else: + count = len(quotes) if not sub.symbols else len( + [q for q in quotes if q["symbol"] in sub.symbols] + ) + sub.messages_sent += count + self._stats["broadcasts_sent"] += count + + # Remove dead subscribers + for sub in dead_subs: + if sub in self._subscribers[table]: + self._subscribers[table].remove(sub) + + @staticmethod + async def _safe_send(websocket: WebSocket, payload: Any) -> None: + """Send JSON to a WebSocket with timeout.""" + await asyncio.wait_for(websocket.send_json(payload), timeout=5.0) + + def get_stats(self) -> Dict[str, Any]: + """Get streaming statistics.""" + stats = { + **self._stats, + "buffer_sizes": {t: len(b) for t, b in self._buffers.items()}, + "subscriber_counts": {t: len(s) for t, s in self._subscribers.items()}, + "latest_quotes": len(self._latest_quotes), + "running": self._running, + } + if self._pubsub: + stats["pubsub"] = self._pubsub.get_stats() + return stats + + def get_latest_quote(self, table: str, symbol: str) -> Optional[Dict[str, Any]]: + """Get the latest quote for a symbol.""" + return self._latest_quotes.get(f"{table}:{symbol}") + + def get_all_quotes(self, table: str) -> Dict[str, Dict[str, Any]]: + """Get all latest quotes for a table.""" + prefix = f"{table}:" + return { + k.split(":", 1)[1]: v + for k, v in self._latest_quotes.items() + if k.startswith(prefix) + } + + +# Global streaming manager instance +_streaming_manager: Optional[StreamingManager] = None + + +def get_streaming_manager() -> StreamingManager: + """Get or create the global streaming manager.""" + global _streaming_manager + if _streaming_manager is None: + redis_url = os.getenv("REDIS_URL", "") + pubsub = create_pubsub(redis_url if redis_url else None) + _streaming_manager = StreamingManager(pubsub=pubsub) + return _streaming_manager + + +async def start_streaming(): + """Start the global streaming manager.""" + manager = get_streaming_manager() + await manager.start() + + +async def stop_streaming(): + """Stop the global streaming manager.""" + global _streaming_manager + if _streaming_manager: + await _streaming_manager.stop() diff --git a/build/_deps/googletest-src b/build/_deps/googletest-src new file mode 160000 index 0000000000000000000000000000000000000000..f8d7d77c06936315286eb55f8de22cd23c188571 --- /dev/null +++ b/build/_deps/googletest-src @@ -0,0 +1 @@ +Subproject commit f8d7d77c06936315286eb55f8de22cd23c188571 diff --git a/build/_deps/pybind11-src b/build/_deps/pybind11-src new file mode 160000 index 
0000000000000000000000000000000000000000..a2e59f0e7065404b44dfe92a28aca47ba1378dc4
--- /dev/null
+++ b/build/_deps/pybind11-src
@@ -0,0 +1 @@
+Subproject commit a2e59f0e7065404b44dfe92a28aca47ba1378dc4
diff --git a/dist/wayy_db-0.1.0-cp310-cp310-linux_x86_64.whl b/dist/wayy_db-0.1.0-cp310-cp310-linux_x86_64.whl
new file mode 100644
index 0000000000000000000000000000000000000000..0d0508fd3a507a407c78884e6fba4567afcea06a
Binary files /dev/null and b/dist/wayy_db-0.1.0-cp310-cp310-linux_x86_64.whl differ
diff --git a/include/wayy_db/column.hpp b/include/wayy_db/column.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..5e0cf7236977ed114fea9228e02c88c051d13453
--- /dev/null
+++ b/include/wayy_db/column.hpp
@@ -0,0 +1,135 @@
+#pragma once
+
+#include "wayy_db/types.hpp"
+#include "wayy_db/column_view.hpp"
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+namespace wayy_db {
+
+/// Type-erased column that owns its data or references mmap'd memory
+class Column {
+public:
+    /// Construct an empty column
+    Column() = default;
+
+    /// Construct a column with owned data
+    Column(std::string name, DType dtype, std::vector<uint8_t> data);
+
+    /// Construct a column referencing external memory (e.g., mmap)
+    Column(std::string name, DType dtype, void* data, size_t size, bool owns_data = false);
+
+    /// Move-only semantics
+    Column(Column&&) = default;
+    Column& operator=(Column&&) = default;
+    Column(const Column&) = delete;
+    Column& operator=(const Column&) = delete;
+
+    /// Column metadata
+    const std::string& name() const { return name_; }
+    DType dtype() const { return dtype_; }
+    size_t size() const { return size_; }
+    size_t byte_size() const { return size_ * dtype_size(dtype_); }
+
+    /// Raw data access
+    void* data() { return data_; }
+    const void* data() const { return data_; }
+
+    /// Typed view access (throws TypeMismatch if wrong type)
+    template <typename T>
+    ColumnView<T> as();
+
+    template <typename T>
+    ColumnView<const T> as() const;
+
+    /// Convenience accessors
+    Int64View as_int64() { return as<int64_t>(); }
+    Float64View as_float64() { return as<double>(); }
+    TimestampView as_timestamp() { return as<int64_t>(); }
+    SymbolView as_symbol() { return as<uint32_t>(); }
+    BoolView as_bool() { return as<uint8_t>(); }
+
+    /// Decimal6 accessor (underlying int64, but tagged as Decimal6)
+    Int64View as_decimal6() {
+        if (dtype_ != DType::Decimal6) throw TypeMismatch(DType::Decimal6, dtype_);
+        return ColumnView<int64_t>(static_cast<int64_t*>(data_), size_);
+    }
+
+    /// Validity bitmap (null/deleted tracking)
+    bool has_validity() const { return has_validity_; }
+    void ensure_validity();  // Allocate bitmap, mark all valid
+    bool is_valid(size_t row) const;
+    void set_valid(size_t row, bool valid);
+    size_t count_valid() const;  // popcount over bitmap
+
+    /// Direct access to validity bitmap bytes (for persistence)
+    const std::vector<uint8_t>& validity_bitmap() const { return validity_; }
+    void set_validity_bitmap(std::vector<uint8_t> bitmap);
+
+    /// Append a single element (column must own its data)
+    void append(const void* value, size_t value_size);
+
+    /// Overwrite element at row index (column must own its data)
+    void set(size_t row, const void* value, size_t value_size);
+
+private:
+    std::string name_;
+    DType dtype_ = DType::Int64;
+    void* data_ = nullptr;
+    size_t size_ = 0;
+    bool owns_data_ = false;
+    std::vector<uint8_t> owned_data_;  // Storage when we own the data
+
+    // Validity bitmap: 1 bit per row (bit=1 means valid, bit=0 means null/deleted)
+    std::vector<uint8_t> validity_;
+    bool has_validity_ = false;
+
+    /// Check that the requested type matches the column's dtype
+    template <typename T>
+    void check_type() const;
+};
+
+// Template implementations
+
+template <typename T>
+ColumnView<T> Column::as() {
+    check_type<T>();
+    return ColumnView<T>(static_cast<T*>(data_), size_);
+}
+
+template <typename T>
+ColumnView<const T> Column::as() const {
+    check_type<T>();
+    return ColumnView<const T>(static_cast<const T*>(data_), size_);
+}
+
+template <typename T>
+void Column::check_type() const {
+    using U = std::remove_cv_t<T>;
+    DType expected;
+    if constexpr (std::is_same_v<U, int64_t>) {
+        // Could be Int64, Timestamp, or Decimal6 (all stored as int64_t)
+        if (dtype_ != DType::Int64 && dtype_ != DType::Timestamp && dtype_ != DType::Decimal6) {
+            throw TypeMismatch(DType::Int64, dtype_);
+        }
+        return;
+    } else if constexpr (std::is_same_v<U, double>) {
+        expected = DType::Float64;
+    } else if constexpr (std::is_same_v<U, uint32_t>) {
+        expected = DType::Symbol;
+    } else if constexpr (std::is_same_v<U, uint8_t>) {
+        expected = DType::Bool;
+    } else {
+        static_assert(sizeof(U) == 0, "Unsupported column type");
+    }
+
+    if (dtype_ != expected) {
+        throw TypeMismatch(expected, dtype_);
+    }
+}
+
+} // namespace wayy_db
diff --git a/include/wayy_db/column_view.hpp b/include/wayy_db/column_view.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..eedc57a3af1c31f83ab038911795bda8a71bd809
--- /dev/null
+++ b/include/wayy_db/column_view.hpp
@@ -0,0 +1,93 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <span>
+#include <stdexcept>
+
+namespace wayy_db {
+
+/// Non-owning typed view over contiguous column data
+/// Provides zero-copy access for SIMD operations and Python bindings
+template <typename T>
+class ColumnView {
+public:
+    using value_type = T;
+    using size_type = size_t;
+    using difference_type = ptrdiff_t;
+    using pointer = T*;
+    using const_pointer = const T*;
+    using reference = T&;
+    using const_reference = const T&;
+    using iterator = T*;
+    using const_iterator = const T*;
+
+    /// Construct an empty view
+    ColumnView() : data_(nullptr), size_(0) {}
+
+    /// Construct a view over existing data
+    ColumnView(T* data, size_t size) : data_(data), size_(size) {}
+
+    /// Construct from std::span
+    explicit ColumnView(std::span<T> span) : data_(span.data()), size_(span.size()) {}
+
+    // Element access
+    reference operator[](size_t i) { return data_[i]; }
+    const_reference operator[](size_t i) const { return data_[i]; }
+
+    reference at(size_t i) {
+        if (i >= size_) throw std::out_of_range("ColumnView index out of range");
+        return data_[i];
+    }
+    const_reference at(size_t i) const {
+        if (i >= size_) throw std::out_of_range("ColumnView index out of range");
+        return data_[i];
+    }
+
+    reference front() { return data_[0]; }
+    const_reference front() const { return data_[0]; }
+
+    reference back() { return data_[size_ - 1]; }
+    const_reference back() const { return data_[size_ - 1]; }
+
+    // Iterators
+    iterator begin() { return data_; }
+    iterator end() { return data_ + size_; }
+    const_iterator begin() const { return data_; }
+    const_iterator end() const { return data_ + size_; }
+    const_iterator cbegin() const { return data_; }
+    const_iterator cend() const { return data_ + size_; }
+
+    // Capacity
+    bool empty() const { return size_ == 0; }
+    size_t size() const { return size_; }
+
+    // Data access (for Python buffer protocol and SIMD)
+    T* data() { return data_; }
+    const T* data() const { return data_; }
+
+    /// Get as std::span for modern C++ APIs
+    std::span<T> span() { return {data_, size_}; }
+    std::span<const T> span() const { return {data_, size_}; }
+
+    /// Create a subview
+    ColumnView subview(size_t offset, size_t count) const {
+        if (offset + count > size_) {
+            throw std::out_of_range("ColumnView subview out of range");
+        }
+        return ColumnView(const_cast<T*>(data_) + offset, count);
+    }
+
+private:
+    T* data_;
+    size_t size_;
+};
+
+// Common type aliases
+using Int64View = ColumnView<int64_t>;
+using Float64View = ColumnView<double>;
+using TimestampView = ColumnView<int64_t>;
+using SymbolView = ColumnView<uint32_t>;
+using BoolView = ColumnView<uint8_t>;
+
+} // namespace wayy_db
diff --git a/include/wayy_db/database.hpp b/include/wayy_db/database.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..f2eb4348def25e455b5f8a47c0ed6a34d6063582
--- /dev/null
+++ b/include/wayy_db/database.hpp
@@ -0,0 +1,87 @@
+#pragma once
+
+#include "wayy_db/table.hpp"
+#include "wayy_db/wal.hpp"
+
+#include <memory>
+#include <shared_mutex>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace wayy_db {
+
+/// High-level database interface managing multiple tables
+class Database {
+public:
+    /// Create an in-memory database
+    Database();
+
+    /// Create or open a persistent database at the given path
+    explicit Database(const std::string& path);
+
+    /// Move-only semantics
+    Database(Database&&) = default;
+    Database& operator=(Database&&) = default;
+    Database(const Database&) = delete;
+    Database& operator=(const Database&) = delete;
+
+    ~Database() = default;
+
+    /// Database path (empty for in-memory)
+    const std::string& path() const { return path_; }
+
+    /// Check if database is persistent
+    bool is_persistent() const { return !path_.empty(); }
+
+    /// List all table names
+    std::vector<std::string> tables() const;
+
+    /// Check if a table exists
+    bool has_table(const std::string& name) const;
+
+    /// Get a table by name (loads from disk if persistent and not cached)
+    Table& table(const std::string& name);
+    Table& operator[](const std::string& name) { return table(name); }
+
+    /// Create a new table
+    Table& create_table(const std::string& name);
+
+    /// Add an existing table to the database
+    void add_table(Table table);
+
+    /// Drop a table (removes from disk if persistent)
+    void drop_table(const std::string& name);
+
+    /// Save all modified tables to disk (no-op for in-memory)
+    void save();
+
+    /// Reload table list from disk
+    void refresh();
+
+    /// WAL: checkpoint (flush WAL, save tables, truncate WAL)
+    void checkpoint();
+
+    /// WAL: get access to WAL for logging (may be null for in-memory DB)
+    WriteAheadLog* wal() { return wal_.get(); }
+
+private:
+    std::string path_;
+    std::unordered_map<std::string, Table> tables_;
+    std::unordered_map<std::string, bool> loaded_;  // Track which tables are loaded
+
+    // Write-ahead log (persistent databases only)
+    std::unique_ptr<WriteAheadLog> wal_;
+
+    // Mutex for thread-safe access (mutable allows const methods to lock)
+    // Uses shared_mutex for concurrent reads, exclusive writes
+    mutable std::shared_mutex mutex_;
+
+    /// Get the directory path for a table
+    std::string table_path(const std::string& name) const;
+
+    /// Scan directory for existing tables
+    void scan_tables();
+};
+
+} // namespace wayy_db
diff --git a/include/wayy_db/hash_index.hpp b/include/wayy_db/hash_index.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..2bcb03f7e2701a64ff70d4924fc17342d0fd4e43
--- /dev/null
+++ b/include/wayy_db/hash_index.hpp
@@ -0,0 +1,46 @@
+#pragma once
+
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+
+namespace wayy_db {
+
+// Forward declarations
+class Table;
+
+/// Hash-based primary key index supporting both int64 and string keys.
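+///
+/// Illustrative usage (sketch; "order_id" is a hypothetical column name):
+///   HashIndex idx;
+///   idx.build_int(table, "order_id");
+///   if (auto row = idx.find_int(42)) { /* *row is the matching row index */ }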
+class HashIndex {
+public:
+    HashIndex() = default;
+
+    /// Build index from table column
+    void build_int(const Table& table, const std::string& col_name);
+    void build_str(const Table& table, const std::string& col_name);
+
+    /// Lookup
+    std::optional<size_t> find_int(int64_t key) const;
+    std::optional<size_t> find_str(std::string_view key) const;
+
+    /// Insert
+    void insert_int(int64_t key, size_t row);
+    void insert_str(std::string_view key, size_t row);
+
+    /// Remove
+    void remove_int(int64_t key);
+    void remove_str(std::string_view key);
+
+    /// Clear
+    void clear();
+
+    /// Size
+    size_t size() const { return int_map_.size() + str_map_.size(); }
+
+private:
+    std::unordered_map<int64_t, size_t> int_map_;
+    std::unordered_map<std::string, size_t> str_map_;
+};
+
+} // namespace wayy_db
diff --git a/include/wayy_db/mmap_file.hpp b/include/wayy_db/mmap_file.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..bd7175b1ea3c019f9d5b6b50a10b095df5cb7fe2
--- /dev/null
+++ b/include/wayy_db/mmap_file.hpp
@@ -0,0 +1,67 @@
+#pragma once
+
+#include <cstddef>
+#include <string>
+
+namespace wayy_db {
+
+/// Memory-mapped file abstraction
+/// Provides platform-independent mmap operations for zero-copy I/O
+class MmapFile {
+public:
+    enum class Mode {
+        ReadOnly,
+        ReadWrite,
+        Create,  // Create or truncate
+    };
+
+    /// Construct without opening
+    MmapFile() = default;
+
+    /// Open and map a file
+    explicit MmapFile(const std::string& path, Mode mode = Mode::ReadOnly,
+                      size_t size = 0);
+
+    /// Move-only semantics
+    MmapFile(MmapFile&& other) noexcept;
+    MmapFile& operator=(MmapFile&& other) noexcept;
+    MmapFile(const MmapFile&) = delete;
+    MmapFile& operator=(const MmapFile&) = delete;
+
+    ~MmapFile();
+
+    /// Open a file for mapping
+    void open(const std::string& path, Mode mode = Mode::ReadOnly,
+              size_t size = 0);
+
+    /// Close and unmap the file
+    void close();
+
+    /// Check if file is open
+    bool is_open() const { return data_ != nullptr; }
+
+    /// Get mapped memory
+    void* data() { return data_; }
+    const void* data() const { return data_; }
+
+    /// Get mapped size
+    size_t size() const { return size_; }
+
+    /// Get file path
+    const std::string& path() const { return path_; }
+
+    /// Sync changes to disk (for ReadWrite/Create modes)
+    void sync();
+
+    /// Resize the mapping (only for Create mode, extends file)
+    void resize(size_t new_size);
+
+private:
+    std::string path_;
+    void* data_ = nullptr;
+    size_t size_ = 0;
+    Mode mode_ = Mode::ReadOnly;
+    int fd_ = -1;  // File descriptor (POSIX)
+};
+
+} // namespace wayy_db
diff --git a/include/wayy_db/ops/aggregations.hpp b/include/wayy_db/ops/aggregations.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..f8af7863ada3a8f330b5117d0517b44631408222
--- /dev/null
+++ b/include/wayy_db/ops/aggregations.hpp
@@ -0,0 +1,69 @@
+#pragma once
+
+#include "wayy_db/column_view.hpp"
+#include "wayy_db/column.hpp"
+
+#include <cstddef>
+#include <limits>
+
+namespace wayy_db::ops {
+
+/// Sum of all values in a column
+template <typename T>
+T sum(const ColumnView<T>& col);
+
+/// SIMD-optimized sum for float64
+double sum_simd(const ColumnView<double>& col);
+int64_t sum_simd(const ColumnView<int64_t>& col);
+
+/// Mean (average) of all values
+template <typename T>
+double avg(const ColumnView<T>& col) {
+    if (col.empty()) return std::numeric_limits<double>::quiet_NaN();
+    return static_cast<double>(sum(col)) / static_cast<double>(col.size());
+}
+
+/// Minimum value
+template <typename T>
+T min(const ColumnView<T>& col);
+
+/// Maximum value
+template <typename T>
+T max(const ColumnView<T>& col);
+
+/// Standard deviation (population)
+template <typename T>
+double std_dev(const ColumnView<T>& col);
+
+/// Variance (population)
+template <typename T>
+double variance(const ColumnView<T>& col);
+
+/// Count non-null values (for future nullable support)
+template <typename T>
+size_t count(const ColumnView<T>& col) {
+    return col.size();
+}
+
+/// First value
+template <typename T>
+T first(const ColumnView<T>& col) {
+    if (col.empty()) throw InvalidOperation("first() on empty column");
+    return col.front();
+}
+
+/// Last value
+template <typename T>
+T last(const ColumnView<T>& col) {
+    if (col.empty()) throw InvalidOperation("last() on empty column");
+    return col.back();
+}
+
+// Type-erased aggregations on Column objects
+double sum(const Column& col);
+double avg(const Column& col);
+double min_val(const Column& col);
+double max_val(const Column& col);
+double std_dev(const Column& col);
+
+} // namespace wayy_db::ops
diff --git a/include/wayy_db/ops/joins.hpp b/include/wayy_db/ops/joins.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..08ec3a28e1ce6f6d93131ef548f75c959ec223a5
--- /dev/null
+++ b/include/wayy_db/ops/joins.hpp
@@ -0,0 +1,48 @@
+#pragma once
+
+#include "wayy_db/table.hpp"
+
+#include <string>
+#include <vector>
+
+namespace wayy_db::ops {
+
+/// As-of join: for each row in left, find the most recent row in right
+/// where right.as_of <= left.as_of and join keys match
+///
+/// Both tables must be sorted by the as_of column
+///
+/// @param left   Left table (e.g., trades)
+/// @param right  Right table (e.g., quotes)
+/// @param on     Join key columns (e.g., ["symbol"])
+/// @param as_of  Temporal column name (e.g., "timestamp")
+/// @return Joined table with columns from both tables
+Table aj(const Table& left, const Table& right,
+         const std::vector<std::string>& on,
+         const std::string& as_of);
+
+/// Window join: for each row in left, find all rows in right
+/// within the specified time window
+///
+/// @param left           Left table
+/// @param right          Right table
+/// @param on             Join key columns
+/// @param as_of          Temporal column name
+/// @param window_before  Nanoseconds before left.as_of to include
+/// @param window_after   Nanoseconds after left.as_of to include
+/// @return Joined table (may have more rows than left due to multiple matches)
+Table wj(const Table& left, const Table& right,
+         const std::vector<std::string>& on,
+         const std::string& as_of,
+         int64_t window_before,
+         int64_t window_after);
+
+/// Inner join on specified columns
+Table inner_join(const Table& left, const Table& right,
+                 const std::vector<std::string>& on);
+
+/// Left join on specified columns
+Table left_join(const Table& left, const Table& right,
+                const std::vector<std::string>& on);
+
+} // namespace wayy_db::ops
diff --git a/include/wayy_db/ops/window.hpp b/include/wayy_db/ops/window.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..e541dbc32ae5d3664d410ae9b9fccce2296dfa62
--- /dev/null
+++ b/include/wayy_db/ops/window.hpp
@@ -0,0 +1,54 @@
+#pragma once
+
+#include "wayy_db/column_view.hpp"
+
+#include <vector>
+
+namespace wayy_db::ops {
+
+/// Moving average over a sliding window
+/// @param col Input column
+/// @param window Window size
+/// @return Vector of moving averages (first window-1 values are partial averages)
+std::vector<double> mavg(const ColumnView<double>& col, size_t window);
+std::vector<double> mavg(const ColumnView<int64_t>& col, size_t window);
+
+/// Moving sum over a sliding window
+std::vector<double> msum(const ColumnView<double>& col, size_t window);
+std::vector<int64_t> msum(const ColumnView<int64_t>& col, size_t window);
+
+/// Moving standard deviation over a sliding window
+std::vector<double> mstd(const ColumnView<double>& col, size_t window);
+std::vector<double> mstd(const ColumnView<int64_t>& col, size_t window);
+
+/// Moving minimum over a sliding window (O(n) using monotonic deque)
+std::vector<double> mmin(const ColumnView<double>& col, size_t window);
+std::vector<int64_t> mmin(const ColumnView<int64_t>& col, size_t window);
+
+/// Moving maximum over a sliding window (O(n) using monotonic deque)
+std::vector<double> mmax(const ColumnView<double>& col, size_t window);
+std::vector<int64_t> mmax(const ColumnView<int64_t>& col, size_t window);
+
+/// Exponential moving average
+/// @param col Input column
+/// @param alpha Smoothing factor (0 < alpha <= 1)
+/// @return Vector of EMA values
+std::vector<double> ema(const ColumnView<double>& col, double alpha);
+std::vector<double> ema(const ColumnView<int64_t>& col, double alpha);
+
+/// Exponential moving average with span
+/// alpha = 2 / (span + 1)
+std::vector<double> ema_span(const ColumnView<double>& col, size_t span);
+
+/// Diff: difference between consecutive values
+std::vector<double> diff(const ColumnView<double>& col, size_t periods = 1);
+std::vector<int64_t> diff(const ColumnView<int64_t>& col, size_t periods = 1);
+
+/// Percent change between consecutive values
+std::vector<double> pct_change(const ColumnView<double>& col, size_t periods = 1);
+
+/// Shift values by n positions (positive = forward, negative = backward)
+std::vector<double> shift(const ColumnView<double>& col, int64_t n);
+std::vector<int64_t> shift(const ColumnView<int64_t>& col, int64_t n);
+
+} // namespace wayy_db::ops
diff --git a/include/wayy_db/string_column.hpp b/include/wayy_db/string_column.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..bc0e69d1950d3108f7156d1e07d694cc98692530
--- /dev/null
+++ b/include/wayy_db/string_column.hpp
@@ -0,0 +1,79 @@
+#pragma once
+
+#include "wayy_db/types.hpp"
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <string_view>
+#include <vector>
+
+namespace wayy_db {
+
+/// Arrow-style variable-length string column.
+/// Storage layout:
+///   offsets_:  int64_t[N+1] — byte offsets into data_
+///   data_:     uint8_t[]    — concatenated UTF-8 bytes
+///   validity_: uint8_t[]    — 1 bit per row (bit=1 valid, bit=0 null)
+///
+/// String at row i = data_[offsets_[i] .. offsets_[i+1]]
+class StringColumn {
+public:
+    /// Construct an empty string column
+    explicit StringColumn(std::string name = "");
+
+    /// Move-only semantics
+    StringColumn(StringColumn&&) = default;
+    StringColumn& operator=(StringColumn&&) = default;
+    StringColumn(const StringColumn&) = delete;
+    StringColumn& operator=(const StringColumn&) = delete;
+
+    /// Column metadata
+    const std::string& name() const { return name_; }
+    DType dtype() const { return DType::String; }
+    size_t size() const { return offsets_.empty() ? 0 : offsets_.size() - 1; }
+    size_t data_bytes() const { return data_.size(); }
+
+    /// Read a string at the given row
+    std::string_view get(size_t row) const;
+
+    /// Append a new string
+    void append(std::string_view val);
+
+    /// Append a null value
+    void append_null();
+
+    /// Overwrite the string at a given row.
+    /// If the new string fits in the existing slot, it's written in-place.
+    /// Otherwise, old slot is wasted and the new value is appended to data_.
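+    ///
+    /// Illustrative sketch:
+    ///   StringColumn col("symbol");
+    ///   col.append("AAPL");
+    ///   col.set(0, "MSFT");  // same byte length fits the slot, written in-place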
+    void set(size_t row, std::string_view val);
+
+    /// Validity bitmap
+    bool has_validity() const { return has_validity_; }
+    bool is_valid(size_t row) const;
+    void set_valid(size_t row, bool valid);
+    size_t count_valid() const;
+
+    /// Persistence
+    void save(const std::string& dir_path, const std::string& col_name) const;
+    static StringColumn load(const std::string& dir_path, const std::string& col_name);
+
+    /// Direct access for bulk operations
+    const std::vector<int64_t>& offsets() const { return offsets_; }
+    const std::vector<uint8_t>& data_buf() const { return data_; }
+    const std::vector<uint8_t>& validity_bitmap() const { return validity_; }
+
+    /// Collect all strings as a vector (copy)
+    std::vector<std::string> to_vector() const;
+
+private:
+    std::string name_;
+    std::vector<int64_t> offsets_;   // N+1 offsets
+    std::vector<uint8_t> data_;      // Concatenated UTF-8 bytes
+    std::vector<uint8_t> validity_;  // Null bitmap
+    bool has_validity_ = false;
+
+    void ensure_validity();
+};
+
+} // namespace wayy_db
diff --git a/include/wayy_db/table.hpp b/include/wayy_db/table.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..c7b242861e820716c683ea2a4f542e8981549a43
--- /dev/null
+++ b/include/wayy_db/table.hpp
@@ -0,0 +1,133 @@
+#pragma once
+
+#include "wayy_db/types.hpp"
+#include "wayy_db/column.hpp"
+#include "wayy_db/string_column.hpp"
+#include "wayy_db/mmap_file.hpp"
+
+#include <any>
+#include <memory>
+#include <optional>
+#include <shared_mutex>
+#include <string>
+#include <tuple>
+#include <unordered_map>
+#include <vector>
+
+namespace wayy_db {
+
+// Forward declarations
+class HashIndex;
+
+/// Columnar table with optional sorted index, OLTP capabilities,
+/// and per-table reader-writer locking.
+class Table {
+public:
+    /// Construct an empty table
+    explicit Table(std::string name = "");
+
+    /// Move-only semantics (shared_mutex is non-movable, so custom move ctor)
+    Table(Table&& other) noexcept;
+    Table& operator=(Table&& other) noexcept;
+    Table(const Table&) = delete;
+    Table& operator=(const Table&) = delete;
+    ~Table();
+
+    /// Table metadata
+    const std::string& name() const { return name_; }
+    size_t num_rows() const { return num_rows_; }
+    size_t num_columns() const { return columns_.size() + string_columns_.size(); }
+
+    /// Per-table reader-writer lock
+    auto read_lock() const { return std::shared_lock(mu_); }
+    auto write_lock() { return std::unique_lock(mu_); }
+
+    /// Column management (fixed-width columns)
+    void add_column(Column column);
+    void add_column(const std::string& name, DType dtype, void* data, size_t size);
+
+    /// String column management
+    void add_string_column(StringColumn col);
+    bool has_string_column(const std::string& name) const;
+    StringColumn& string_column(const std::string& name);
+    const StringColumn& string_column(const std::string& name) const;
+
+    bool has_column(const std::string& name) const;
+    Column& column(const std::string& name);
+    const Column& column(const std::string& name) const;
+    Column& operator[](const std::string& name) { return column(name); }
+    const Column& operator[](const std::string& name) const { return column(name); }
+
+    /// Get the DType of any column (fixed or string)
+    DType column_dtype(const std::string& name) const;
+
+    std::vector<std::string> column_names() const;
+
+    /// Sorted index (critical for temporal joins)
+    void set_sorted_by(const std::string& col);
+    std::optional<std::string> sorted_by() const { return sorted_by_; }
+    bool is_sorted() const { return sorted_by_.has_value(); }
+
+    /// Primary key + hash index
+    void set_primary_key(const std::string& col_name);
+    const std::optional<std::string>& primary_key() const { return primary_key_; }
+    std::optional<size_t> find_row(int64_t key) const;
+    std::optional<size_t> find_row(std::string_view key) const;
+    void rebuild_index();
+
+    /// CRUD operations
+    size_t append_row(const std::unordered_map<std::string, std::any>& values);
+    bool update_row(int64_t pk, const std::unordered_map<std::string, std::any>& values);
+    bool update_row(std::string_view pk, const std::unordered_map<std::string, std::any>& values);
+    bool delete_row(int64_t pk);
+    bool delete_row(std::string_view pk);
+
+    /// Filter: returns vector of row indices matching predicate
+    std::vector<size_t> where_eq(const std::string& col, int64_t val) const;
+    std::vector<size_t> where_eq(const std::string& col, std::string_view val) const;
+
+    /// Compaction: physically remove deleted rows, rebuild index
+    void compact();
+
+    /// Persistence
+    void save(const std::string& dir_path) const;
+    static Table load(const std::string& dir_path);
+
+    /// Create from memory-mapped directory (zero-copy)
+    static Table mmap(const std::string& dir_path);
+
+private:
+    std::string name_;
+    size_t num_rows_ = 0;
+    std::vector<Column> columns_;
+    std::unordered_map<std::string, size_t> column_index_;
+    std::optional<std::string> sorted_by_;
+
+    // String columns (separate storage)
+    std::vector<StringColumn> string_columns_;
+    std::unordered_map<std::string, size_t> string_column_index_;
+
+    // Primary key + hash index
+    std::optional<std::string> primary_key_;
+    std::unique_ptr<HashIndex> pk_index_;
+
+    // Per-table reader-writer lock
+    mutable std::shared_mutex mu_;
+
+    // For mmap'd tables, keep file handles alive
+    std::vector<MmapFile> mmap_files_;
+
+    /// Write metadata JSON
+    void write_metadata(const std::string& dir_path) const;
+
+    /// Read metadata JSON and return column info
+    static std::tuple<std::string,
+                      std::optional<std::string>,
+                      std::vector<std::pair<std::string, DType>>>
+    read_metadata(const std::string& dir_path);
+
+    /// Internal row update by row index (no PK lookup)
+    bool update_row_at(size_t row_idx, const std::unordered_map<std::string, std::any>& values);
+};
+
+} // namespace wayy_db
diff --git a/include/wayy_db/types.hpp b/include/wayy_db/types.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..4e040e211bc361084d7a23d6bc3511e0cf131505
--- /dev/null
+++ b/include/wayy_db/types.hpp
@@ -0,0 +1,100 @@
+#pragma once
+
+#include <cstdint>
+#include <stdexcept>
+#include <string>
+#include <string_view>
+
+namespace wayy_db {
+
+/// Supported data types for columns
+enum class DType : uint8_t {
+    Int64 = 0,
+    Float64 = 1,
+    Timestamp = 2,   // Nanoseconds since Unix epoch
+    Symbol = 3,      // Interned string index
+    Bool = 4,
+    String = 5,      // Arrow-style variable-length UTF-8 (offsets + data)
+    Decimal6 = 6,    // Int64 with implied 6 decimal places (max ±9.2T)
+};
+
+/// Get the size in bytes for a given type (0 for variable-length types)
+constexpr size_t dtype_size(DType dtype) {
+    switch (dtype) {
+        case DType::Int64: return sizeof(int64_t);
+        case DType::Float64: return sizeof(double);
+        case DType::Timestamp: return sizeof(int64_t);
+        case DType::Symbol: return sizeof(uint32_t);
+        case DType::Bool: return sizeof(uint8_t);
+        case DType::String: return 0;  // Variable-length, use StringColumn
+        case DType::Decimal6: return sizeof(int64_t);  // Stored as int64
+    }
+    return 0;  // Unreachable
+}
+
+/// Check if a dtype is fixed-width
+constexpr bool dtype_is_fixed(DType dtype) {
+    return dtype != DType::String;
+}
+
+/// Convert DType to string representation
+constexpr std::string_view dtype_to_string(DType dtype) {
+    switch (dtype) {
+        case DType::Int64: return "int64";
+        case DType::Float64: return "float64";
+        case DType::Timestamp: return "timestamp";
+        case DType::Symbol: return "symbol";
+        case DType::Bool: return "bool";
+        case DType::String: return "string";
+        case DType::Decimal6: return "decimal6";
+    }
+    return "unknown";
+}
+
+/// Parse DType from string
+DType dtype_from_string(std::string_view s);
+
+/// Magic number for WayyDB files: "WAYYDB\x00\x01"
+constexpr uint64_t WAYY_MAGIC = 0x57415959'44420001ULL;
+
+/// Current file format version
+constexpr uint32_t WAYY_VERSION = 1;
+
+/// Column file header (64 bytes)
+struct ColumnHeader {
+    uint64_t magic;        // WAYY_MAGIC
+    uint32_t version;      // WAYY_VERSION
+    DType dtype;           // Data type
+    uint8_t reserved1[3];  // Padding
+    uint64_t row_count;    // Number of rows
+    uint64_t compression;  // 0 = none, 1 = LZ4
+    uint8_t reserved2[24]; // Reserved for future use
+    uint64_t data_offset;  // Offset to data (typically 64)
+};
+
+static_assert(sizeof(ColumnHeader) == 64, "ColumnHeader must be 64 bytes");
+
+/// Exception types
+class WayyException : public std::runtime_error {
+    using std::runtime_error::runtime_error;
+};
+
+class ColumnNotFound : public WayyException {
+public:
+    explicit ColumnNotFound(const std::string& name)
+        : WayyException("Column not found: " + name) {}
+};
+
+class TypeMismatch : public WayyException {
+public:
+    TypeMismatch(DType expected, DType actual)
+        : WayyException("Type mismatch: expected " +
+                        std::string(dtype_to_string(expected)) +
+                        ", got " + std::string(dtype_to_string(actual))) {}
+};
+
+class InvalidOperation : public WayyException {
+    using WayyException::WayyException;
+};
+
+} // namespace wayy_db
diff --git a/include/wayy_db/wal.hpp b/include/wayy_db/wal.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..60f2cd26e9a5417a774fcd7df3c2d4e78effd619
--- /dev/null
+++ b/include/wayy_db/wal.hpp
@@ -0,0 +1,78 @@
+#pragma once
+
+#include <cstdint>
+#include <fstream>
+#include <mutex>
+#include <string>
+#include <vector>
+
+namespace wayy_db {
+
+// Forward declaration
+class Database;
+
+/// WAL operation types
+enum class WalOp : uint8_t {
+    Insert = 1,
+    Update = 2,
+    Delete = 3,
+};
+
+/// WAL magic number
+constexpr uint32_t WAL_MAGIC = 0x57414C01;  // "WAL\x01"
+
+/// Binary WAL entry format:
+///   [4B magic][1B op_type][4B table_name_len][table_name]
+///   [8B row_id][4B payload_len][payload][4B CRC32]
+///
+/// For Insert: payload = serialized row (col_name:type:data pairs)
+/// For Update: payload = serialized partial row (only changed columns)
+/// For Delete: payload = empty
+
+class WriteAheadLog {
+public:
+    /// Create or open a WAL at the given directory
+    explicit WriteAheadLog(const std::string& db_path);
+
+    ~WriteAheadLog();
+
+    /// Log an insert operation
+    void log_insert(const std::string& table, size_t row,
+                    const std::vector<uint8_t>& data);
+
+    /// Log an update operation
+    void log_update(const std::string& table, size_t row,
+                    const std::string& col, const std::vector<uint8_t>& data);
+
+    /// Log a delete operation
+    void log_delete(const std::string& table, size_t row);
+
+    /// Checkpoint: flush WAL, save all tables, truncate WAL
+    void checkpoint(Database& db);
+
+    /// Replay WAL entries to recover state after crash
+    void replay(Database& db);
+
+    /// Check if WAL has unprocessed entries
+    bool has_entries() const;
+
+    /// Get WAL file path
+    const std::string& path() const { return path_; }
+
+private:
+    std::string path_;
+    std::ofstream file_;
+    mutable std::mutex mu_;
+
+    /// Write a raw entry to the WAL file
+    void write_entry(WalOp op, const std::string& table, size_t row,
+                     const std::vector<uint8_t>& payload);
+
+    /// Compute CRC32 over buffer
+    static uint32_t crc32(const uint8_t* data, size_t len);
+
+    /// Open WAL file for appending
+    void open_for_append();
+};
+
+} // namespace wayy_db
diff --git a/include/wayy_db/wayy_db.hpp
b/include/wayy_db/wayy_db.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8cdddac8ab4a90a4067bf55a336e15db40d6e28b --- /dev/null +++ b/include/wayy_db/wayy_db.hpp @@ -0,0 +1,16 @@ +#pragma once + +/// Main header that includes all WayyDB components + +#include "wayy_db/types.hpp" +#include "wayy_db/column_view.hpp" +#include "wayy_db/column.hpp" +#include "wayy_db/string_column.hpp" +#include "wayy_db/hash_index.hpp" +#include "wayy_db/table.hpp" +#include "wayy_db/wal.hpp" +#include "wayy_db/database.hpp" +#include "wayy_db/mmap_file.hpp" +#include "wayy_db/ops/aggregations.hpp" +#include "wayy_db/ops/joins.hpp" +#include "wayy_db/ops/window.hpp" diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..d86457ca45bd87c12e3d4397ca1503b696084975 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,127 @@ +[build-system] +requires = ["scikit-build-core>=0.5", "pybind11>=2.13"] +build-backend = "scikit_build_core.build" + +[project] +name = "wayy-db" +version = "0.1.0" +description = "High-performance columnar time-series database with kdb+-like functionality" +readme = "README.md" +requires-python = ">=3.9" +license = {text = "MIT"} +authors = [ + {name = "Wayy Research", email = "dev@wayy.io"} +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Financial and Insurance Industry", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: C++", + "Topic :: Database", + "Topic :: Scientific/Engineering :: Information Analysis", +] +keywords = [ + "database", + "time-series", + "columnar", + "kdb", + "as-of-join", + "quantitative-finance", + "trading", + "numpy", + "high-performance", +] +dependencies = [ + "numpy>=1.20", +] + +[project.optional-dependencies] +cli = [ + "typer>=0.9", + "httpx>=0.25", + "websockets>=12.0", + "rich>=13.0", +] +api = [ + "fastapi>=0.109.0", + "uvicorn[standard]>=0.27.0", + "pydantic>=2.0", + "websockets>=12.0", + "redis[hiredis]>=5.0", +] +dev = [ + "pytest>=7.0", + "pytest-cov", + "pytest-asyncio>=0.21", + "httpx>=0.25", + "pandas>=2.0", + "polars>=0.20", + "hypothesis>=6.0", + "mypy>=1.0", + "ruff>=0.1", +] +bench = [ + "pandas>=2.0", + "polars>=0.20", + "duckdb>=0.9", + "psutil>=5.0", + "pytest-benchmark", + "memory-profiler", +] +docs = [ + "sphinx>=7.0", + "sphinx-rtd-theme", + "myst-parser", +] + +[project.scripts] +wayy = "wayy_db.cli.main:app" +wayy-db-bench = "benchmarks.benchmark:main" + +[project.urls] +Homepage = "https://github.com/wayy-research/wayydb" +Documentation = "https://wayydb.readthedocs.io" + +[tool.scikit-build] +cmake.args = ["-DWAYY_BUILD_PYTHON=ON", "-DWAYY_BUILD_TESTS=OFF"] +wheel.packages = ["python/wayy_db"] + +[tool.cibuildwheel] +build-verbosity = 1 +# Build for Python 3.9-3.13, including free-threaded 3.13 +build = "cp39-* cp310-* cp311-* cp312-* cp313-* cp313t-*" +skip = "*-musllinux_* *-win32 *-manylinux_i686" + +# Free-threaded Python 3.13 (no-GIL) configuration +[tool.cibuildwheel.free-threaded] +# Enable free-threaded builds on all platforms +build = "cp313t-*" + +[[tool.cibuildwheel.overrides]] +# For free-threaded builds, ensure we're using the right Python +select = "cp313t-*" +inherit.environment = "append" + +[tool.pytest.ini_options] +testpaths = 
["tests/python"] +python_files = ["test_*.py"] +addopts = "-v --tb=short" +asyncio_mode = "strict" + +[tool.ruff] +target-version = "py39" +line-length = 100 + +[tool.ruff.lint] +select = ["E", "F", "W", "I", "UP", "B", "C4", "SIM"] + +[tool.mypy] +python_version = "3.9" +warn_return_any = true +warn_unused_configs = true diff --git a/python/bindings.cpp b/python/bindings.cpp new file mode 100644 index 0000000000000000000000000000000000000000..de327192ad0840749dd1d52d5ff31247b865a674 --- /dev/null +++ b/python/bindings.cpp @@ -0,0 +1,377 @@ +#include +#include +#include + +#include "wayy_db/wayy_db.hpp" + +#include + +namespace py = pybind11; + +// GIL release guard for concurrent read operations +using release_gil = py::call_guard; + +using namespace wayy_db; + +// Namespace alias to avoid collision with local variable +namespace wdb_ops = wayy_db::ops; + +// Helper to convert numpy dtype to WayyDB DType +DType numpy_dtype_to_wayy(py::dtype dt) { + if (dt.is(py::dtype::of())) return DType::Int64; + if (dt.is(py::dtype::of())) return DType::Float64; + if (dt.is(py::dtype::of())) return DType::Symbol; + if (dt.is(py::dtype::of())) return DType::Bool; + throw std::runtime_error("Unsupported numpy dtype"); +} + +// Helper to get numpy dtype from WayyDB DType +py::dtype wayy_dtype_to_numpy(DType dt) { + switch (dt) { + case DType::Int64: + case DType::Timestamp: + case DType::Decimal6: + return py::dtype::of(); + case DType::Float64: + return py::dtype::of(); + case DType::Symbol: + return py::dtype::of(); + case DType::Bool: + return py::dtype::of(); + case DType::String: + throw std::runtime_error("String columns use StringColumn, not numpy"); + } + throw std::runtime_error("Unknown dtype"); +} + +// Helper: convert Python dict to std::unordered_map +std::unordered_map py_dict_to_any_map( + py::dict d, Table& table) { + std::unordered_map result; + for (auto& [key, val] : d) { + std::string col_name = py::str(key); + DType dt = table.column_dtype(col_name); + + if (dt == DType::String) { + result[col_name] = std::string(py::str(val)); + } else if (dt == DType::Int64 || dt == DType::Timestamp || dt == DType::Decimal6) { + result[col_name] = py::cast(val); + } else if (dt == DType::Float64) { + result[col_name] = py::cast(val); + } else if (dt == DType::Symbol) { + result[col_name] = py::cast(val); + } else if (dt == DType::Bool) { + result[col_name] = py::cast(val); + } + } + return result; +} + +PYBIND11_MODULE(_core, m, py::mod_gil_not_used()) { + m.doc() = "WayyDB: High-performance columnar time-series database (free-threading safe)"; + + // DType enum + py::enum_(m, "DType") + .value("Int64", DType::Int64) + .value("Float64", DType::Float64) + .value("Timestamp", DType::Timestamp) + .value("Symbol", DType::Symbol) + .value("Bool", DType::Bool) + .value("String", DType::String) + .value("Decimal6", DType::Decimal6) + .export_values(); + + // Exceptions + py::register_exception(m, "WayyException"); + py::register_exception(m, "ColumnNotFound"); + py::register_exception(m, "TypeMismatch"); + py::register_exception(m, "InvalidOperation"); + + // Column class + py::class_(m, "Column") + .def_property_readonly("name", &Column::name) + .def_property_readonly("dtype", &Column::dtype) + .def_property_readonly("size", &Column::size) + .def("__len__", &Column::size) + .def("to_numpy", [](Column& self) -> py::array { + py::dtype dt = wayy_dtype_to_numpy(self.dtype()); + return py::array(dt, {self.size()}, {dtype_size(self.dtype())}, + self.data(), py::cast(self)); + }, 
+             py::return_value_policy::reference_internal,
+             "Zero-copy view as numpy array")
+        .def("is_valid", &Column::is_valid, py::arg("row"),
+             "Check if row is valid (not null/deleted)")
+        .def("count_valid", &Column::count_valid,
+             "Count non-null/non-deleted rows");
+
+    // StringColumn class
+    py::class_<StringColumn>(m, "StringColumn")
+        .def(py::init<const std::string&>(), py::arg("name") = "")
+        .def_property_readonly("name", &StringColumn::name)
+        .def_property_readonly("dtype", &StringColumn::dtype)
+        .def_property_readonly("size", &StringColumn::size)
+        .def("__len__", &StringColumn::size)
+        .def("get", &StringColumn::get, py::arg("row"),
+             "Get string at row index")
+        .def("append", &StringColumn::append, py::arg("val"),
+             "Append a string value")
+        .def("set", &StringColumn::set, py::arg("row"), py::arg("val"),
+             "Set string at row index")
+        .def("is_valid", &StringColumn::is_valid, py::arg("row"))
+        .def("count_valid", &StringColumn::count_valid)
+        .def("to_list", &StringColumn::to_vector,
+             "Get all strings as a Python list");
+
+    // Table class
+    py::class_<Table>(m, "Table")
+        .def(py::init<const std::string&>(), py::arg("name") = "")
+        .def_property_readonly("name", &Table::name)
+        .def_property_readonly("num_rows", &Table::num_rows)
+        .def_property_readonly("num_columns", &Table::num_columns)
+        .def_property_readonly("sorted_by", [](const Table& t) -> py::object {
+            if (t.sorted_by()) return py::cast(*t.sorted_by());
+            return py::none();
+        })
+        .def_property_readonly("primary_key", [](const Table& t) -> py::object {
+            if (t.primary_key()) return py::cast(*t.primary_key());
+            return py::none();
+        })
+        .def("__len__", &Table::num_rows)
+        .def("has_column", &Table::has_column)
+        .def("column", py::overload_cast<const std::string&>(&Table::column),
+             py::return_value_policy::reference_internal)
+        .def("__getitem__", py::overload_cast<const std::string&>(&Table::column),
+             py::return_value_policy::reference_internal)
+        .def("has_string_column", &Table::has_string_column)
+        .def("string_column", py::overload_cast<const std::string&>(&Table::string_column),
+             py::return_value_policy::reference_internal)
+        .def("column_dtype", &Table::column_dtype, py::arg("name"),
+             "Get the DType of any column (fixed or string)")
+        .def("column_names", &Table::column_names)
+        .def("set_sorted_by", &Table::set_sorted_by)
+        .def("set_primary_key", &Table::set_primary_key, py::arg("col_name"),
+             "Set the primary key column and build hash index")
+        .def("rebuild_index", &Table::rebuild_index,
+             "Rebuild the primary key hash index")
+        // CRUD operations
+        .def("append_row", [](Table& self, py::dict values) -> size_t {
+            auto map = py_dict_to_any_map(values, self);
+            return self.append_row(map);
+        }, py::arg("values"), "Append a row from a dict, returns row index")
+        .def("update_row", [](Table& self, py::object pk, py::dict values) -> bool {
+            auto map = py_dict_to_any_map(values, self);
+            if (py::isinstance<py::int_>(pk)) {
+                return self.update_row(py::cast<int64_t>(pk), map);
+            } else {
+                return self.update_row(std::string(py::str(pk)), map);
+            }
+        }, py::arg("pk"), py::arg("values"), "Update row by primary key")
+        .def("delete_row", [](Table& self, py::object pk) -> bool {
+            if (py::isinstance<py::int_>(pk)) {
+                return self.delete_row(py::cast<int64_t>(pk));
+            } else {
+                return self.delete_row(std::string(py::str(pk)));
+            }
+        }, py::arg("pk"), "Soft-delete row by primary key")
+        .def("find_row", [](const Table& self, py::object pk) -> py::object {
+            std::optional<size_t> row;
+            if (py::isinstance<py::int_>(pk)) {
+                row = self.find_row(py::cast<int64_t>(pk));
+            } else {
+                row = self.find_row(std::string(py::str(pk)));
+            }
+            if (row) return py::cast(*row);
+            return py::none();
+        },
py::arg("pk"), "Find row index by primary key") + .def("where_eq", [](const Table& self, const std::string& col, py::object val) -> py::list { + std::vector rows; + DType dt = self.column_dtype(col); + if (dt == DType::String) { + rows = self.where_eq(col, std::string(py::str(val))); + } else { + rows = self.where_eq(col, py::cast(val)); + } + py::list result; + for (auto r : rows) result.append(r); + return result; + }, py::arg("col"), py::arg("val"), "Filter rows where col == val") + .def("compact", &Table::compact, + "Physically remove deleted rows and rebuild index") + .def("save", &Table::save) + .def_static("load", &Table::load) + .def_static("mmap", &Table::mmap) + .def("add_column_from_numpy", [](Table& self, const std::string& name, + py::array arr, DType dtype) { + py::buffer_info buf = arr.request(); + if (buf.ndim != 1) { + throw std::runtime_error("Array must be 1-dimensional"); + } + // Copy data into owned buffer + size_t elem_size = dtype_size(dtype); + std::vector data(buf.size * elem_size); + std::memcpy(data.data(), buf.ptr, data.size()); + self.add_column(Column(name, dtype, std::move(data))); + }, py::arg("name"), py::arg("array"), py::arg("dtype")) + .def("add_string_column_from_list", [](Table& self, const std::string& name, + py::list strings) { + StringColumn sc(name); + for (auto& item : strings) { + if (item.is_none()) { + sc.append_null(); + } else { + sc.append(std::string(py::str(item))); + } + } + self.add_string_column(std::move(sc)); + }, py::arg("name"), py::arg("strings"), + "Add a string column from a Python list") + .def("to_dict", [](Table& self) -> py::dict { + py::dict result; + for (const auto& col_name : self.column_names()) { + if (self.has_string_column(col_name)) { + auto& scol = self.string_column(col_name); + result[py::cast(col_name)] = py::cast(scol.to_vector()); + } else { + Column& col = self.column(col_name); + py::dtype dt = wayy_dtype_to_numpy(col.dtype()); + // Make a copy for the dict + py::array arr(dt, {col.size()}, {dtype_size(col.dtype())}, col.data()); + result[py::cast(col_name)] = arr.attr("copy")(); + } + } + return result; + }); + + // Database class + py::class_(m, "Database") + .def(py::init<>()) + .def(py::init(), py::arg("path")) + .def_property_readonly("path", &Database::path) + .def_property_readonly("is_persistent", &Database::is_persistent) + .def("tables", &Database::tables) + .def("has_table", &Database::has_table) + .def("table", &Database::table, py::return_value_policy::reference_internal) + .def("__getitem__", &Database::table, py::return_value_policy::reference_internal) + .def("create_table", &Database::create_table, py::return_value_policy::reference_internal) + .def("add_table", [](Database& db, Table& table) { + db.add_table(std::move(table)); + }) + .def("drop_table", &Database::drop_table) + .def("save", &Database::save) + .def("refresh", &Database::refresh) + .def("checkpoint", &Database::checkpoint, + "Flush WAL, save all tables, truncate WAL"); + + // Operations submodule + py::module_ ops_mod = m.def_submodule("ops", "WayyDB operations"); + + // Aggregations - use lambdas to avoid overload issues + // All aggregations release the GIL for concurrent execution + ops_mod.def("sum", [](const Column& col) { return wdb_ops::sum(col); }, + py::arg("col"), release_gil(), "Sum of column values"); + ops_mod.def("avg", [](const Column& col) { return wdb_ops::avg(col); }, + py::arg("col"), release_gil(), "Average of column values"); + ops_mod.def("min", [](const Column& col) { return wdb_ops::min_val(col); }, + 
py::arg("col"), release_gil(), "Minimum value"); + ops_mod.def("max", [](const Column& col) { return wdb_ops::max_val(col); }, + py::arg("col"), release_gil(), "Maximum value"); + ops_mod.def("std", [](const Column& col) { return wdb_ops::std_dev(col); }, + py::arg("col"), release_gil(), "Standard deviation"); + + // Joins - release GIL for concurrent execution + ops_mod.def("aj", &wdb_ops::aj, + py::arg("left"), py::arg("right"), py::arg("on"), py::arg("as_of"), + release_gil(), + "As-of join: find most recent right row for each left row"); + ops_mod.def("wj", &wdb_ops::wj, + py::arg("left"), py::arg("right"), py::arg("on"), py::arg("as_of"), + py::arg("window_before"), py::arg("window_after"), + release_gil(), + "Window join: find all right rows within time window"); + + // Window functions (returning numpy arrays) + // These compute with GIL released, then briefly reacquire to create numpy array + ops_mod.def("mavg", [](Column& col, size_t window) -> py::array_t { + std::vector result; + { + py::gil_scoped_release release; + result = wdb_ops::mavg(col.as_float64(), window); + } + return py::array_t(result.size(), result.data()); + }, py::arg("col"), py::arg("window"), "Moving average"); + + ops_mod.def("msum", [](Column& col, size_t window) -> py::array_t { + std::vector result; + { + py::gil_scoped_release release; + result = wdb_ops::msum(col.as_float64(), window); + } + return py::array_t(result.size(), result.data()); + }, py::arg("col"), py::arg("window"), "Moving sum"); + + ops_mod.def("mstd", [](Column& col, size_t window) -> py::array_t { + std::vector result; + { + py::gil_scoped_release release; + result = wdb_ops::mstd(col.as_float64(), window); + } + return py::array_t(result.size(), result.data()); + }, py::arg("col"), py::arg("window"), "Moving standard deviation"); + + ops_mod.def("mmin", [](Column& col, size_t window) -> py::array_t { + std::vector result; + { + py::gil_scoped_release release; + result = wdb_ops::mmin(col.as_float64(), window); + } + return py::array_t(result.size(), result.data()); + }, py::arg("col"), py::arg("window"), "Moving minimum"); + + ops_mod.def("mmax", [](Column& col, size_t window) -> py::array_t { + std::vector result; + { + py::gil_scoped_release release; + result = wdb_ops::mmax(col.as_float64(), window); + } + return py::array_t(result.size(), result.data()); + }, py::arg("col"), py::arg("window"), "Moving maximum"); + + ops_mod.def("ema", [](Column& col, double alpha) -> py::array_t { + std::vector result; + { + py::gil_scoped_release release; + result = wdb_ops::ema(col.as_float64(), alpha); + } + return py::array_t(result.size(), result.data()); + }, py::arg("col"), py::arg("alpha"), "Exponential moving average"); + + ops_mod.def("diff", [](Column& col, size_t periods) -> py::array_t { + std::vector result; + { + py::gil_scoped_release release; + result = wdb_ops::diff(col.as_float64(), periods); + } + return py::array_t(result.size(), result.data()); + }, py::arg("col"), py::arg("periods") = 1, "Difference between consecutive values"); + + ops_mod.def("pct_change", [](Column& col, size_t periods) -> py::array_t { + std::vector result; + { + py::gil_scoped_release release; + result = wdb_ops::pct_change(col.as_float64(), periods); + } + return py::array_t(result.size(), result.data()); + }, py::arg("col"), py::arg("periods") = 1, "Percent change"); + + ops_mod.def("shift", [](Column& col, int64_t n) -> py::array_t { + std::vector result; + { + py::gil_scoped_release release; + result = wdb_ops::shift(col.as_float64(), n); + } + return 
+
+    // Version info (kept in sync with pyproject.toml and CMakeLists.txt)
+    m.attr("__version__") = "0.1.0";
+}
diff --git a/python/wayy_db/__init__.py b/python/wayy_db/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..359c5f2d5bdc6bb0c86426ac2baee87a0954da57
--- /dev/null
+++ b/python/wayy_db/__init__.py
@@ -0,0 +1,122 @@
+"""
+WayyDB: High-performance columnar time-series database
+
+A kdb+-like database with Python-first API, featuring:
+- As-of joins (aj) and window joins (wj)
+- Zero-copy numpy interop via memory mapping
+- SIMD-accelerated aggregations
+- Columnar storage with sorted indices
+"""
+
+from __future__ import annotations
+
+from wayy_db._core import (
+    # Core classes
+    Database,
+    Table,
+    Column,
+    StringColumn,
+    # Types
+    DType,
+    # Exceptions
+    WayyException,
+    ColumnNotFound,
+    TypeMismatch,
+    InvalidOperation,
+    # Version
+    __version__,
+)
+
+# Operations module
+from wayy_db import ops
+
+__all__ = [
+    # Core classes
+    "Database",
+    "Table",
+    "Column",
+    "StringColumn",
+    # Types
+    "DType",
+    # Exceptions
+    "WayyException",
+    "ColumnNotFound",
+    "TypeMismatch",
+    "InvalidOperation",
+    # Submodules
+    "ops",
+    # Version
+    "__version__",
+]
+
+
+def from_dict(data: dict, name: str = "", sorted_by: str | None = None) -> Table:
+    """Create a Table from a dictionary of numpy arrays.
+
+    Args:
+        data: Dictionary mapping column names to numpy arrays
+        name: Optional table name
+        sorted_by: Optional column name to mark as sorted index
+
+    Returns:
+        Table with the provided data
+    """
+    import numpy as np
+
+    table = Table(name)
+
+    dtype_map = {
+        np.dtype("int64"): DType.Int64,
+        np.dtype("float64"): DType.Float64,
+        np.dtype("uint32"): DType.Symbol,
+        np.dtype("uint8"): DType.Bool,
+    }
+
+    for col_name, arr in data.items():
+        arr = np.asarray(arr)
+        if arr.dtype not in dtype_map:
+            # Try to convert
+            if arr.dtype == np.bool_:
+                arr = arr.astype(np.uint8)  # Bool columns are stored as uint8
+            elif np.issubdtype(arr.dtype, np.integer):
+                arr = arr.astype(np.int64)
+            elif np.issubdtype(arr.dtype, np.floating):
+                arr = arr.astype(np.float64)
+            else:
+                raise ValueError(f"Unsupported dtype {arr.dtype} for column {col_name}")
+
+        dtype = dtype_map[arr.dtype]
+        table.add_column_from_numpy(col_name, arr, dtype)
+
+    if sorted_by is not None:
+        table.set_sorted_by(sorted_by)
+
+    return table
+
+
+def from_pandas(df, name: str = "", sorted_by: str | None = None) -> Table:
+    """Create a Table from a pandas DataFrame.
+
+    Args:
+        df: pandas DataFrame
+        name: Optional table name
+        sorted_by: Optional column name to mark as sorted index
+
+    Returns:
+        Table with the DataFrame data
+    """
+    data = {col: df[col].values for col in df.columns}
+    return from_dict(data, name=name, sorted_by=sorted_by)
+
+
+def from_polars(df, name: str = "", sorted_by: str | None = None) -> Table:
+    """Create a Table from a polars DataFrame.
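+
+    Example (sketch; assumes polars is installed):
+
+        >>> import polars as pl
+        >>> df = pl.DataFrame({"ts": [1, 2, 3], "px": [10.0, 10.5, 10.2]})
+        >>> quotes = from_polars(df, name="quotes", sorted_by="ts")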
+ + Args: + df: polars DataFrame + name: Optional table name + sorted_by: Optional column name to mark as sorted index + + Returns: + Table with the DataFrame data + """ + data = {col: df[col].to_numpy() for col in df.columns} + return from_dict(data, name=name, sorted_by=sorted_by) diff --git a/python/wayy_db/_core.pyi b/python/wayy_db/_core.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6dd8635e742f3f8f057404058b9ea898987c05e7 --- /dev/null +++ b/python/wayy_db/_core.pyi @@ -0,0 +1,113 @@ +"""Type stubs for wayy_db._core C++ extension module.""" + +from typing import Optional, Sequence +import numpy as np +import numpy.typing as npt + +__version__: str + +class DType: + Int64: DType + Float64: DType + Timestamp: DType + Symbol: DType + Bool: DType + +class WayyException(Exception): ... +class ColumnNotFound(WayyException): ... +class TypeMismatch(WayyException): ... +class InvalidOperation(WayyException): ... + +class Column: + @property + def name(self) -> str: ... + @property + def dtype(self) -> DType: ... + @property + def size(self) -> int: ... + def __len__(self) -> int: ... + def to_numpy(self) -> npt.NDArray: ... + +class Table: + def __init__(self, name: str = "") -> None: ... + @property + def name(self) -> str: ... + @property + def num_rows(self) -> int: ... + @property + def num_columns(self) -> int: ... + @property + def sorted_by(self) -> Optional[str]: ... + def __len__(self) -> int: ... + def has_column(self, name: str) -> bool: ... + def column(self, name: str) -> Column: ... + def __getitem__(self, name: str) -> Column: ... + def column_names(self) -> list[str]: ... + def set_sorted_by(self, col: str) -> None: ... + def save(self, path: str) -> None: ... + @staticmethod + def load(path: str) -> Table: ... + @staticmethod + def mmap(path: str) -> Table: ... + def add_column_from_numpy( + self, name: str, array: npt.NDArray, dtype: DType + ) -> None: ... + def to_dict(self) -> dict[str, npt.NDArray]: ... + +class Database: + def __init__(self, path: str = "") -> None: ... + @property + def path(self) -> str: ... + @property + def is_persistent(self) -> bool: ... + def tables(self) -> list[str]: ... + def has_table(self, name: str) -> bool: ... + def table(self, name: str) -> Table: ... + def __getitem__(self, name: str) -> Table: ... + def create_table(self, name: str) -> Table: ... + def drop_table(self, name: str) -> None: ... + def save(self) -> None: ... + def refresh(self) -> None: ... + +class ops: + @staticmethod + def sum(col: Column) -> float: ... + @staticmethod + def avg(col: Column) -> float: ... + @staticmethod + def min(col: Column) -> float: ... + @staticmethod + def max(col: Column) -> float: ... + @staticmethod + def std(col: Column) -> float: ... + @staticmethod + def aj( + left: Table, right: Table, on: Sequence[str], as_of: str + ) -> Table: ... + @staticmethod + def wj( + left: Table, + right: Table, + on: Sequence[str], + as_of: str, + window_before: int, + window_after: int, + ) -> Table: ... + @staticmethod + def mavg(col: Column, window: int) -> npt.NDArray[np.float64]: ... + @staticmethod + def msum(col: Column, window: int) -> npt.NDArray[np.float64]: ... + @staticmethod + def mstd(col: Column, window: int) -> npt.NDArray[np.float64]: ... + @staticmethod + def mmin(col: Column, window: int) -> npt.NDArray[np.float64]: ... + @staticmethod + def mmax(col: Column, window: int) -> npt.NDArray[np.float64]: ... + @staticmethod + def ema(col: Column, alpha: float) -> npt.NDArray[np.float64]: ... 
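+    # Note (from bindings.cpp): the window/EMA helpers compute over a float64
+    # view of the column and return newly allocated arrays, not zero-copy views.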
+ @staticmethod + def diff(col: Column, periods: int = 1) -> npt.NDArray[np.float64]: ... + @staticmethod + def pct_change(col: Column, periods: int = 1) -> npt.NDArray[np.float64]: ... + @staticmethod + def shift(col: Column, n: int) -> npt.NDArray[np.float64]: ... diff --git a/python/wayy_db/cli/__init__.py b/python/wayy_db/cli/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21abebadffad1d96838f75289da00220ab704d89 --- /dev/null +++ b/python/wayy_db/cli/__init__.py @@ -0,0 +1 @@ +"""WayyDB CLI - command-line interface for the WayyDB service.""" diff --git a/python/wayy_db/cli/client.py b/python/wayy_db/cli/client.py new file mode 100644 index 0000000000000000000000000000000000000000..0ba88ebfd51c01d3c71723eb6cbd1deab0cfbd57 --- /dev/null +++ b/python/wayy_db/cli/client.py @@ -0,0 +1,300 @@ +"""HTTP client for the WayyDB service.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any, NoReturn, Optional + +import httpx + +from wayy_db.cli.config import get_server_url + +# The API uses /api/v1/{db_name}/... for OLTP routes but db_name is unused +# server-side (single global db). We hardcode "default" for forward compat. +_DB_NAME = "default" + + +class WayyClientError(Exception): + """Raised when the WayyDB service returns an error.""" + + def __init__(self, status_code: int, detail: str) -> None: + self.status_code = status_code + self.detail = detail + super().__init__(f"HTTP {status_code}: {detail}") + + +class WayyClient: + """Synchronous HTTP client for the WayyDB REST API.""" + + def __init__(self, base_url: Optional[str] = None, timeout: float = 30.0) -> None: + self.base_url = (base_url or get_server_url()).rstrip("/") + self._client = httpx.Client(base_url=self.base_url, timeout=timeout) + + def _request(self, method: str, path: str, **kwargs: Any) -> Any: + """Make an HTTP request and return JSON response.""" + try: + resp = self._client.request(method, path, **kwargs) + except httpx.ConnectError: + raise WayyClientError(0, f"Cannot connect to {self.base_url}") + if resp.status_code >= 400: + try: + detail = resp.json().get("detail", resp.text) + except Exception: + detail = resp.text + raise WayyClientError(resp.status_code, detail) + if resp.status_code == 204 or not resp.content: + return {} + return resp.json() + + # --- Health --- + + def health(self) -> dict[str, Any]: + return self._request("GET", "/health") + + def info(self) -> dict[str, Any]: + return self._request("GET", "/") + + # --- Tables --- + + def list_tables(self) -> list[str]: + data = self._request("GET", "/tables") + return data.get("tables", []) + + def get_table_info(self, name: str) -> dict[str, Any]: + return self._request("GET", f"/tables/{name}") + + def get_table_data( + self, name: str, limit: int = 100, offset: int = 0 + ) -> dict[str, Any]: + return self._request( + "GET", f"/tables/{name}/data", params={"limit": limit, "offset": offset} + ) + + def create_table( + self, + name: str, + columns: list[dict[str, str]], + primary_key: Optional[str] = None, + sorted_by: Optional[str] = None, + ) -> dict[str, Any]: + payload = { + "name": name, + "columns": columns, + "primary_key": primary_key, + "sorted_by": sorted_by, + } + return self._request("POST", f"/api/v1/{_DB_NAME}/tables", json=payload) + + def drop_table(self, name: str) -> dict[str, Any]: + return self._request("DELETE", f"/tables/{name}") + + def upload_table(self, table_data: dict[str, Any]) -> dict[str, Any]: + return self._request("POST", 
"/tables/upload", json=table_data) + + def append_rows(self, name: str, columns: list[dict[str, Any]]) -> dict[str, Any]: + return self._request("POST", f"/tables/{name}/append", json={"columns": columns}) + + # --- OLTP --- + + def insert_row(self, table: str, data: dict[str, Any]) -> dict[str, Any]: + return self._request( + "POST", f"/api/v1/{_DB_NAME}/tables/{table}/rows", json={"data": data} + ) + + def get_row(self, table: str, pk: str) -> dict[str, Any]: + return self._request("GET", f"/api/v1/{_DB_NAME}/tables/{table}/rows/{pk}") + + def update_row(self, table: str, pk: str, data: dict[str, Any]) -> dict[str, Any]: + return self._request( + "PUT", f"/api/v1/{_DB_NAME}/tables/{table}/rows/{pk}", json={"data": data} + ) + + def delete_row(self, table: str, pk: str) -> dict[str, Any]: + return self._request("DELETE", f"/api/v1/{_DB_NAME}/tables/{table}/rows/{pk}") + + def filter_rows( + self, table: str, filters: Optional[dict[str, str]] = None, limit: int = 500 + ) -> dict[str, Any]: + params = dict(filters or {}) + params["limit"] = str(limit) + return self._request( + "GET", f"/api/v1/{_DB_NAME}/tables/{table}/rows", params=params + ) + + # --- Aggregations --- + + def aggregate(self, table: str, column: str, op: str) -> dict[str, Any]: + return self._request("GET", f"/tables/{table}/agg/{column}/{op}") + + # --- Joins --- + + def as_of_join( + self, left: str, right: str, on: list[str], as_of: str + ) -> dict[str, Any]: + payload = {"left_table": left, "right_table": right, "on": on, "as_of": as_of} + return self._request("POST", "/join/aj", json=payload) + + def window_join( + self, + left: str, + right: str, + on: list[str], + as_of: str, + before: int, + after: int, + ) -> dict[str, Any]: + payload = { + "left_table": left, + "right_table": right, + "on": on, + "as_of": as_of, + "window_before": before, + "window_after": after, + } + return self._request("POST", "/join/wj", json=payload) + + # --- Window functions --- + + def window_function( + self, + table: str, + column: str, + operation: str, + window: Optional[int] = None, + alpha: Optional[float] = None, + ) -> dict[str, Any]: + payload: dict[str, Any] = { + "table": table, + "column": column, + "operation": operation, + } + if window is not None: + payload["window"] = window + if alpha is not None: + payload["alpha"] = alpha + return self._request("POST", "/window", json=payload) + + # --- Streaming --- + + def ingest_tick(self, table: str, tick: dict[str, Any]) -> dict[str, Any]: + return self._request("POST", f"/ingest/{table}", json=tick) + + def ingest_batch(self, table: str, ticks: list[dict[str, Any]]) -> dict[str, Any]: + return self._request("POST", f"/ingest/{table}/batch", json={"ticks": ticks}) + + def get_streaming_stats(self) -> dict[str, Any]: + return self._request("GET", "/streaming/stats") + + def get_quote(self, table: str, symbol: str) -> dict[str, Any]: + return self._request("GET", f"/streaming/quote/{table}/{symbol}") + + def get_all_quotes(self, table: str) -> dict[str, Any]: + return self._request("GET", f"/streaming/quotes/{table}") + + # --- KV Store --- + + def kv_get(self, key: str) -> Any: + data = self._request("GET", f"/kv/{key}") + return data.get("value") + + def kv_set(self, key: str, value: Any, ttl: Optional[float] = None) -> dict[str, Any]: + payload: dict[str, Any] = {"value": value} + if ttl is not None: + payload["ttl"] = ttl + return self._request("POST", f"/kv/{key}", json=payload) + + def kv_delete(self, key: str) -> dict[str, Any]: + return self._request("DELETE", f"/kv/{key}") 
+ + def kv_list(self, pattern: Optional[str] = None) -> list[str]: + params = {} + if pattern: + params["pattern"] = pattern + data = self._request("GET", "/kv", params=params) + return data.get("keys", []) + + # --- Checkpoint --- + + def checkpoint(self) -> dict[str, Any]: + return self._request("POST", f"/api/v1/{_DB_NAME}/checkpoint") + + def close(self) -> None: + self._client.close() + + def __enter__(self) -> "WayyClient": + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + +def upload_csv( + client: WayyClient, name: str, file_path: Path, sorted_by: Optional[str] = None +) -> dict[str, Any]: + """Read a CSV file and upload it as a table. + + Uses stdlib csv to avoid requiring pandas in CLI. + """ + import csv + + with open(file_path, newline="") as f: + reader = csv.reader(f) + headers = next(reader) + rows = list(reader) + + if not rows: + raise ValueError("CSV file is empty (no data rows)") + + columns: list[dict[str, Any]] = [] + for i, header in enumerate(headers): + raw_values = [row[i] for row in rows] + dtype, data = _infer_column(raw_values) + columns.append({"name": header, "dtype": dtype, "data": data}) + + payload = {"name": name, "columns": columns, "sorted_by": sorted_by} + return client.upload_table(payload) + + +def _infer_column(values: list[str]) -> tuple[str, list[Any]]: + """Infer column dtype from string values. Returns (dtype_name, typed_data).""" + non_empty = [v for v in values if v.strip()] + if not non_empty: + return ("float64", [0.0] * len(values)) + + # Try int64 + try: + data = [int(v) if v.strip() else 0 for v in values] + return ("int64", data) + except (ValueError, OverflowError): + pass + + # Try float64 (handles empty cells as NaN) + try: + data = [float(v) if v.strip() else float("nan") for v in values] + return ("float64", data) + except (ValueError, OverflowError): + pass + + raise ValueError( + f"Non-numeric column detected. Values: {values[:3]}... " + "CSV upload currently supports numeric columns only. " + "Use the Python API with from_pandas() for string/symbol columns." 
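+        " (Type inference tries int64 first, then float64; blank cells become"
+        " 0 for ints and NaN for floats.)"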
+ ) + + +def upload_json_ticks( + client: WayyClient, table: str, file_path: Path +) -> dict[str, Any]: + """Read a JSON file of ticks and batch-ingest them.""" + with open(file_path) as f: + data = json.load(f) + + if isinstance(data, list): + ticks = data + elif isinstance(data, dict) and "ticks" in data: + ticks = data["ticks"] + else: + raise ValueError("JSON must be a list of ticks or {\"ticks\": [...]}") + + return client.ingest_batch(table, ticks) diff --git a/python/wayy_db/cli/config.py b/python/wayy_db/cli/config.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4fbbef06ebac718efa4d9b4db7459aa9cfed83 --- /dev/null +++ b/python/wayy_db/cli/config.py @@ -0,0 +1,42 @@ +"""Configuration management for the WayyDB CLI.""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + + +CONFIG_DIR = Path.home() / ".wayy" +CONFIG_FILE = CONFIG_DIR / "config.json" + +DEFAULTS: dict[str, Any] = { + "server_url": "http://localhost:8080", + "format": "table", + "db_name": "default", +} + + +def load_config() -> dict[str, Any]: + """Load config from ~/.wayy/config.json, creating defaults if missing.""" + if CONFIG_FILE.exists(): + with open(CONFIG_FILE) as f: + return {**DEFAULTS, **json.load(f)} + return dict(DEFAULTS) + + +def save_config(config: dict[str, Any]) -> None: + """Save config to ~/.wayy/config.json.""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + with open(CONFIG_FILE, "w") as f: + json.dump(config, f, indent=2) + + +def get_server_url() -> str: + """Get the configured server URL.""" + return load_config()["server_url"] + + +def get_db_name() -> str: + """Get the configured database name.""" + return load_config()["db_name"] diff --git a/python/wayy_db/cli/deploy.py b/python/wayy_db/cli/deploy.py new file mode 100644 index 0000000000000000000000000000000000000000..585635440779804d0cf52bdda89c21bbabe439b0 --- /dev/null +++ b/python/wayy_db/cli/deploy.py @@ -0,0 +1,284 @@ +"""Deployment commands for the WayyDB CLI. + +Supports: +- Local: start uvicorn directly or via Docker +- HuggingFace Spaces: push to HF Docker space +- Docker: build and run container +""" + +from __future__ import annotations + +import os +import shutil +import subprocess +import sys +from pathlib import Path +from typing import Optional + +import typer + +from wayy_db.cli.config import load_config, save_config +from wayy_db.cli.output import console, print_error, print_info, print_success + +deploy_app = typer.Typer( + name="deploy", + help="Deploy WayyDB service", + no_args_is_help=True, +) + + +def _find_project_root() -> Path: + """Walk up from cwd looking for pyproject.toml with wayy-db.""" + cwd = Path.cwd() + for parent in [cwd, *cwd.parents]: + toml = parent / "pyproject.toml" + if toml.exists() and "wayy-db" in toml.read_text(): + return parent + raise FileNotFoundError( + "Cannot find wayyDB project root (no pyproject.toml with wayy-db found). " + "Run this command from within the wayyDB repo." 
+ ) + + +def _run(cmd: list[str], cwd: Optional[Path] = None, check: bool = True) -> subprocess.CompletedProcess[str]: + """Run a subprocess with live output.""" + console.print(f"[dim]$ {' '.join(cmd)}[/dim]") + return subprocess.run(cmd, cwd=cwd, check=check, text=True) + + +# --- Local serve --- + + +@deploy_app.command("local") +def deploy_local( + port: int = typer.Option(8080, "--port", "-p", help="Port to serve on"), + host: str = typer.Option("0.0.0.0", "--host", help="Host to bind to"), + data_path: str = typer.Option("./data/wayydb", "--data-path", "-d", help="Data directory"), + workers: int = typer.Option(1, "--workers", "-w", help="Number of uvicorn workers"), +) -> None: + """Start WayyDB server locally with uvicorn.""" + os.makedirs(data_path, exist_ok=True) + os.environ["WAYY_DATA_PATH"] = str(Path(data_path).resolve()) + os.environ["PORT"] = str(port) + os.environ["CORS_ORIGINS"] = "*" + + print_info("Data path", os.environ["WAYY_DATA_PATH"]) + print_info("Serving on", f"http://{host}:{port}") + console.print("[dim]Press Ctrl+C to stop[/dim]\n") + + try: + _find_project_root() + # Running from source — use api.main:app + api_module = "api.main:app" + except FileNotFoundError: + # Installed package — api module should be importable + api_module = "api.main:app" + + cmd = [ + sys.executable, "-m", "uvicorn", api_module, + "--host", host, + "--port", str(port), + "--workers", str(workers), + ] + + try: + _run(cmd) + except KeyboardInterrupt: + console.print("\n[dim]Server stopped.[/dim]") + except subprocess.CalledProcessError: + print_error("Failed to start server. Is uvicorn installed? (pip install wayy-db[api])") + raise typer.Exit(1) + + +# --- Docker --- + + +@deploy_app.command("docker") +def deploy_docker( + port: int = typer.Option(8080, "--port", "-p", help="Host port to expose"), + tag: str = typer.Option("wayydb:latest", "--tag", "-t", help="Docker image tag"), + data_volume: str = typer.Option("wayydb-data", "--volume", "-v", help="Docker volume for data persistence"), + build: bool = typer.Option(True, "--build/--no-build", help="Build image before running"), + detach: bool = typer.Option(True, "--detach/--foreground", help="Run in background"), +) -> None: + """Build and run WayyDB in Docker.""" + if not shutil.which("docker"): + print_error("Docker not found. Install Docker: https://docs.docker.com/get-docker/") + raise typer.Exit(1) + + try: + root = _find_project_root() + except FileNotFoundError as e: + print_error(str(e)) + raise typer.Exit(1) + + if build: + console.print("[bold]Building Docker image...[/bold]") + _run(["docker", "build", "-t", tag, "."], cwd=root) + print_success(f"Built {tag}") + + # Create volume if needed + _run(["docker", "volume", "create", data_volume], check=False) + + # Stop existing container if running + _run(["docker", "rm", "-f", "wayydb"], check=False) + + cmd = [ + "docker", "run", + "--name", "wayydb", + "-p", f"{port}:8080", + "-v", f"{data_volume}:/data/wayydb", + "-e", "CORS_ORIGINS=*", + ] + + if detach: + cmd.append("-d") + + cmd.append(tag) + + _run(cmd) + + if detach: + print_success(f"WayyDB running at http://localhost:{port}") + print_info("Container", "wayydb") + print_info("Volume", data_volume) + console.print("[dim]Stop with: docker stop wayydb[/dim]") + else: + console.print("\n[dim]Container stopped.[/dim]") + + +# --- HuggingFace Spaces --- + + +@deploy_app.command("hf") +def deploy_hf( + repo: str = typer.Option("", "--repo", "-r", help="HF Space repo (user/name). 
+    token: Optional[str] = typer.Option(None, "--token", help="HuggingFace token (or set HF_TOKEN env var)"),
+) -> None:
+    """Deploy WayyDB to HuggingFace Spaces (Docker).
+
+    Pushes the current repo state to a HuggingFace Space configured as a Docker space.
+    The Space must already exist. Create one at: https://huggingface.co/new-space?sdk=docker
+    """
+    if not shutil.which("git"):
+        print_error("git not found")
+        raise typer.Exit(1)
+
+    try:
+        root = _find_project_root()
+    except FileNotFoundError as e:
+        print_error(str(e))
+        raise typer.Exit(1)
+
+    # Check if hf remote exists
+    result = subprocess.run(
+        ["git", "remote", "get-url", "hf"], capture_output=True, text=True, cwd=root
+    )
+    hf_remote_exists = result.returncode == 0
+    existing_url = result.stdout.strip() if hf_remote_exists else ""
+
+    if repo:
+        hf_token = token or os.environ.get("HF_TOKEN", "")
+        if hf_token:
+            remote_url = f"https://user:{hf_token}@huggingface.co/spaces/{repo}"
+        else:
+            remote_url = f"https://huggingface.co/spaces/{repo}"
+
+        if hf_remote_exists:
+            _run(["git", "remote", "set-url", "hf", remote_url], cwd=root)
+        else:
+            _run(["git", "remote", "add", "hf", remote_url], cwd=root)
+    elif not hf_remote_exists:
+        print_error(
+            "No 'hf' git remote found. Either:\n"
+            "  1. Run: wayy deploy hf --repo <user>/<space>\n"
+            "  2. Add manually: git remote add hf https://huggingface.co/spaces/<user>/<space>"
+        )
+        raise typer.Exit(1)
+
+    console.print("[bold]Pushing to HuggingFace Spaces...[/bold]")
+
+    # HF Spaces rejects pushes containing large files in history (even deleted ones).
+    # Create a clean orphan commit with only the current tree to avoid this.
+    try:
+        # Create a temporary orphan branch with just the current working tree
+        _run(["git", "checkout", "--orphan", "_hf_deploy"], cwd=root)
+        _run(["git", "add", "-A"], cwd=root)
+        _run(
+            ["git", "commit", "-m", "Deploy wayyDB to HuggingFace Spaces", "--allow-empty"],
+            cwd=root,
+        )
+        _run(["git", "push", "hf", "_hf_deploy:main", "--force"], cwd=root)
+    except subprocess.CalledProcessError:
+        # Clean up temp branch before erroring
+        subprocess.run(["git", "checkout", "main"], cwd=root, capture_output=True)
+        subprocess.run(["git", "branch", "-D", "_hf_deploy"], cwd=root, capture_output=True)
+        print_error("Push failed. Check your HF token and Space configuration.")
+        raise typer.Exit(1)
+    finally:
+        # Always return to main branch and clean up
+        subprocess.run(["git", "checkout", "main"], cwd=root, capture_output=True)
+        subprocess.run(["git", "branch", "-D", "_hf_deploy"], cwd=root, capture_output=True)
+
+    # Extract space URL from remote
+    result = subprocess.run(
+        ["git", "remote", "get-url", "hf"], capture_output=True, text=True, cwd=root
+    )
+    remote_url = result.stdout.strip()
+
+    # Parse space name from URL
+    space_name = ""
+    if "huggingface.co/spaces/" in remote_url:
+        # removesuffix, not rstrip: rstrip(".git") would strip any trailing
+        # '.', 'g', 'i', or 't' characters from the space name
+        space_name = remote_url.split("huggingface.co/spaces/")[-1].removesuffix(".git")
+    elif repo:
+        space_name = repo
+
+    if space_name:
+        space_url = f"https://huggingface.co/spaces/{space_name}"
+        # HF Spaces with Docker get a direct URL
+        space_direct = f"https://{space_name.replace('/', '-')}.hf.space"
+        print_success("Deployed to HuggingFace Spaces")
+        print_info("Space", space_url)
+        print_info("API", space_direct)
+        console.print(f"\n[dim]Connect with: wayy connect {space_direct}[/dim]")
+    else:
+        print_success("Pushed to HuggingFace Spaces")
+
+
+# --- Status / logs ---
+
+
+@deploy_app.command("stop")
+def deploy_stop(
+    name: str = typer.Option("wayydb", "--name", "-n", help="Container name"),
+) -> None:
+    """Stop a running WayyDB Docker container."""
+    if not shutil.which("docker"):
+        print_error("Docker not found")
+        raise typer.Exit(1)
+
+    _run(["docker", "stop", name], check=False)
+    _run(["docker", "rm", name], check=False)
+    print_success(f"Stopped {name}")
+
+
+@deploy_app.command("logs")
+def deploy_logs(
+    name: str = typer.Option("wayydb", "--name", "-n", help="Container name"),
+    follow: bool = typer.Option(False, "--follow", "-f", help="Follow log output"),
+    tail: int = typer.Option(100, "--tail", help="Number of lines to show"),
+) -> None:
+    """View logs from a running WayyDB Docker container."""
+    if not shutil.which("docker"):
+        print_error("Docker not found")
+        raise typer.Exit(1)
+
+    cmd = ["docker", "logs", "--tail", str(tail)]
+    if follow:
+        cmd.append("-f")
+    cmd.append(name)
+
+    try:
+        _run(cmd, check=False)
+    except KeyboardInterrupt:
+        pass
diff --git a/python/wayy_db/cli/main.py b/python/wayy_db/cli/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..54cc519772207b32114cf8c2da04f00436f2e1ad
--- /dev/null
+++ b/python/wayy_db/cli/main.py
@@ -0,0 +1,522 @@
+"""WayyDB CLI — command-line interface for the WayyDB service.
+
+Usage:
+    wayy status                            Check server health
+    wayy connect <url>                     Set server URL
+    wayy tables                            List all tables
+    wayy create <name> --schema '{}'       Create a table with schema
+    wayy query <table>                     Query table data
+    wayy upload <name> --file data.csv     Upload CSV as a table
+    wayy agg <table> <column> <op>         Run aggregation
+    wayy stream <table>                    Subscribe to live updates
+    wayy ingest <table> --file ticks.json  Batch ingest ticks
+    wayy kv get/set/del <key>              Key-value operations
+"""
+
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, NoReturn, Optional
+
+import typer
+
+from wayy_db.cli.client import WayyClient, WayyClientError, upload_csv, upload_json_ticks
+from wayy_db.cli.config import get_server_url, load_config, save_config
+from wayy_db.cli.deploy import deploy_app
+from wayy_db.cli.output import (
+    console,
+    print_error,
+    print_info,
+    print_json_data,
+    print_kv,
+    print_rows,
+    print_success,
+    print_table_data,
+)
+
+app = typer.Typer(
+    name="wayy",
+    help="WayyDB CLI — high-performance columnar time-series database",
+    no_args_is_help=True,
+    add_completion=False,
+)
+
+
+def _handle_error(e: WayyClientError) -> NoReturn:
+    if e.status_code == 0:
+        print_error(f"Connection failed: {e.detail}")
+    else:
+        print_error(f"Error {e.status_code}: {e.detail}")
+    raise typer.Exit(1)
+
+
+# --- Connection ---
+
+
+@app.command()
+def connect(url: str = typer.Argument(..., help="WayyDB server URL")) -> None:
+    """Set the WayyDB server URL."""
+    url = url.rstrip("/")
+    if not url.startswith(("http://", "https://")):
+        url = f"http://{url}"
+
+    try:
+        with WayyClient(base_url=url) as client:
+            info = client.health()
+    except WayyClientError as e:
+        print_error(f"Cannot reach {url}: {e.detail}")
+        raise typer.Exit(1)
+
+    config = load_config()
+    config["server_url"] = url
+    save_config(config)
+    print_success(f"Connected to {url}")
+    print_info("Tables", info.get("tables", 0))
+
+
+@app.command()
+def status() -> None:
+    """Check server health and connection info."""
+    url = get_server_url()
+    print_info("Server", url)
+
+    try:
+        with WayyClient() as client:
+            info = client.info()
+            health = client.health()
+    except WayyClientError as e:
+        _handle_error(e)
+
+    print_info("Service", info.get("service", "?"))
+    print_info("Version", info.get("version", "?"))
+    print_info("Status", health.get("status", "?"))
+    print_info("Tables", health.get("tables", 0))
+
+
+# --- Tables ---
+
+
+@app.command()
+def tables() -> None:
+    """List all tables in the database."""
+    try:
+        with WayyClient() as client:
+            table_list = client.list_tables()
+    except WayyClientError as e:
+        _handle_error(e)
+
+    if not table_list:
+        console.print("[dim]No tables[/dim]")
+        return
+
+    for t in table_list:
+        console.print(f"  {t}")
+
+
+@app.command()
+def create(
+    name: str = typer.Argument(..., help="Table name"),
+    schema: str = typer.Option(
+        ..., "--schema", "-s",
+        help='Column schema as JSON: \'{"ts": "timestamp", "price": "float64"}\'',
+    ),
+    primary_key: Optional[str] = typer.Option(None, "--pk", help="Primary key column"),
+    sorted_by: Optional[str] = typer.Option(None, "--sorted-by", help="Sorted index column"),
+) -> None:
+    """Create a new table with a typed schema."""
+    try:
+        schema_dict = json.loads(schema)
+    except json.JSONDecodeError as e:
+        print_error(f"Invalid JSON schema: {e}")
+        raise typer.Exit(1)
+
+    columns = [{"name": k, "dtype": v} for k, v in schema_dict.items()]
+
+    try:
+        with WayyClient() as client:
+            result = client.create_table(name, columns, primary_key=primary_key, sorted_by=sorted_by)
+    except WayyClientError as e:
+        _handle_error(e)
+
+    print_success(f"Created table '{name}' with columns: {result.get('columns', [])}")
+
+
+@app.command()
+def drop(name: str = typer.Argument(..., help="Table name to delete")) -> None:
+    """Drop a table."""
+    try:
+        with WayyClient() as client:
+            client.drop_table(name)
+    except
WayyClientError as e: + _handle_error(e) + + print_success(f"Dropped table '{name}'") + + +@app.command() +def info(name: str = typer.Argument(..., help="Table name")) -> None: + """Get table metadata.""" + try: + with WayyClient() as client: + data = client.get_table_info(name) + except WayyClientError as e: + _handle_error(e) + + print_info("Name", data.get("name")) + print_info("Rows", data.get("num_rows")) + print_info("Columns", data.get("num_columns")) + print_info("Column names", ", ".join(data.get("columns", []))) + print_info("Sorted by", data.get("sorted_by") or "none") + + +@app.command() +def query( + table: str = typer.Argument(..., help="Table name"), + limit: int = typer.Option(100, "--limit", "-n", help="Max rows to return"), + offset: int = typer.Option(0, "--offset", help="Row offset"), + where: Optional[list[str]] = typer.Option(None, "--where", "-w", help="Filter as col=val"), + output_json: bool = typer.Option(False, "--json", "-j", help="Output as JSON"), +) -> None: + """Query table data.""" + try: + with WayyClient() as client: + if where: + filters = {} + for w in where: + if "=" not in w: + print_error(f"Invalid filter: {w} (expected col=val)") + raise typer.Exit(1) + k, v = w.split("=", 1) + filters[k] = v + + result = client.filter_rows(table, filters=filters, limit=limit) + + if output_json: + print_json_data(result) + else: + print_rows(result.get("data", []), title=f"{table} ({result.get('count', 0)} rows)") + else: + result = client.get_table_data(table, limit=limit, offset=offset) + + if output_json: + print_json_data(result) + else: + data = result.get("data", {}) + total = result.get("total_rows", 0) + shown = len(next(iter(data.values()))) if data else 0 + print_table_data(data, title=f"{table} ({shown}/{total} rows)") + + except WayyClientError as e: + _handle_error(e) + + +@app.command() +def upload( + name: str = typer.Argument(..., help="Table name"), + file: Path = typer.Option(..., "--file", "-f", help="CSV file to upload"), + sorted_by: Optional[str] = typer.Option(None, "--sorted-by", help="Sorted index column"), +) -> None: + """Upload a CSV file as a new table.""" + if not file.exists(): + print_error(f"File not found: {file}") + raise typer.Exit(1) + + try: + with WayyClient() as client: + result = upload_csv(client, name, file, sorted_by=sorted_by) + except WayyClientError as e: + _handle_error(e) + except ValueError as e: + print_error(str(e)) + raise typer.Exit(1) + + print_success(f"Uploaded '{name}': {result.get('rows', 0)} rows, columns: {result.get('columns', [])}") + + +# --- Aggregations --- + + +@app.command() +def agg( + table: str = typer.Argument(..., help="Table name"), + column: str = typer.Argument(..., help="Column name"), + op: str = typer.Argument(..., help="Operation: sum, avg, min, max, std"), +) -> None: + """Run an aggregation on a table column.""" + try: + with WayyClient() as client: + result = client.aggregate(table, column, op) + except WayyClientError as e: + _handle_error(e) + + console.print(f"[bold]{op}[/bold]({table}.{column}) = [cyan]{result.get('result')}[/cyan]") + + +# --- Streaming --- + + +@app.command() +def stream( + table: str = typer.Argument(..., help="Table name to subscribe to"), + symbols: Optional[str] = typer.Option(None, "--symbols", "-s", help="Comma-separated symbol filter"), + output_json: bool = typer.Option(False, "--json", "-j", help="Output raw JSON"), +) -> None: + """Subscribe to real-time streaming updates via WebSocket.""" + import asyncio + + async def _stream() -> None: + import 
websockets + + url = get_server_url().replace("http://", "ws://").replace("https://", "wss://") + ws_url = f"{url}/ws/subscribe/{table}" + + console.print(f"[dim]Connecting to {ws_url}...[/dim]") + + async with websockets.connect(ws_url) as ws: + if symbols: + symbol_list = [s.strip() for s in symbols.split(",")] + await ws.send(json.dumps({"symbols": symbol_list})) + console.print(f"[dim]Filtering: {symbol_list}[/dim]") + + console.print("[green]Connected.[/green] Press Ctrl+C to disconnect.\n") + + try: + async for message in ws: + data = json.loads(message) + if output_json: + print_json_data(data) + else: + if "batch" in data: + for tick in data["batch"]: + _print_tick(tick) + else: + _print_tick(data) + except asyncio.CancelledError: + pass + + try: + asyncio.run(_stream()) + except KeyboardInterrupt: + console.print("\n[dim]Disconnected.[/dim]") + + +def _print_tick(tick: dict[str, Any]) -> None: + """Format a single tick for display.""" + sym = tick.get("symbol", "?") + price = tick.get("price", "?") + vol = tick.get("volume", "") + bid = tick.get("bid", "") + ask = tick.get("ask", "") + + parts = [f"[bold]{sym}[/bold]", f"[cyan]{price}[/cyan]"] + if bid and ask: + parts.append(f"[dim]{bid}/{ask}[/dim]") + if vol: + parts.append(f"vol={vol}") + + console.print(" ".join(parts)) + + +# --- Ingestion --- + + +@app.command() +def ingest( + table: str = typer.Argument(..., help="Table name"), + file: Path = typer.Option(..., "--file", "-f", help="JSON file with ticks"), +) -> None: + """Batch ingest ticks from a JSON file.""" + if not file.exists(): + print_error(f"File not found: {file}") + raise typer.Exit(1) + + try: + with WayyClient() as client: + result = upload_json_ticks(client, table, file) + except WayyClientError as e: + _handle_error(e) + except ValueError as e: + print_error(str(e)) + raise typer.Exit(1) + + print_success(f"Ingested {result.get('ingested', 0)} ticks into '{table}'") + + +# --- KV Store --- + +kv_app = typer.Typer(name="kv", help="Key-value store operations", no_args_is_help=True) +app.add_typer(kv_app) + + +@kv_app.command("get") +def kv_get(key: str = typer.Argument(..., help="Key to get")) -> None: + """Get a value by key.""" + try: + with WayyClient() as client: + value = client.kv_get(key) + except WayyClientError as e: + _handle_error(e) + + print_kv(key, value) + + +@kv_app.command("set") +def kv_set( + key: str = typer.Argument(..., help="Key to set"), + value: str = typer.Argument(..., help="Value (JSON or string)"), + ttl: Optional[float] = typer.Option(None, "--ttl", help="TTL in seconds"), +) -> None: + """Set a key-value pair.""" + try: + parsed = json.loads(value) + except json.JSONDecodeError: + parsed = value + + try: + with WayyClient() as client: + client.kv_set(key, parsed, ttl=ttl) + except WayyClientError as e: + _handle_error(e) + + print_success(f"Set '{key}'") + + +@kv_app.command("del") +def kv_del(key: str = typer.Argument(..., help="Key to delete")) -> None: + """Delete a key.""" + try: + with WayyClient() as client: + client.kv_delete(key) + except WayyClientError as e: + _handle_error(e) + + print_success(f"Deleted '{key}'") + + +@kv_app.command("list") +def kv_list(pattern: Optional[str] = typer.Argument(None, help="Glob pattern filter")) -> None: + """List keys, optionally filtered by pattern.""" + try: + with WayyClient() as client: + keys = client.kv_list(pattern) + except WayyClientError as e: + _handle_error(e) + + if not keys: + console.print("[dim]No keys[/dim]") + return + + for k in keys: + console.print(f" {k}") + + +# 
--- Joins --- + +join_app = typer.Typer(name="join", help="Join operations", no_args_is_help=True) +app.add_typer(join_app) + + +@join_app.command("aj") +def join_aj( + left: str = typer.Argument(..., help="Left table"), + right: str = typer.Argument(..., help="Right table"), + on: str = typer.Option(..., "--on", help="Join keys (comma-separated)"), + as_of: str = typer.Option(..., "--as-of", help="Temporal column"), + output_json: bool = typer.Option(False, "--json", "-j", help="Output as JSON"), +) -> None: + """As-of join: find most recent right row for each left row.""" + on_cols = [c.strip() for c in on.split(",")] + + try: + with WayyClient() as client: + result = client.as_of_join(left, right, on_cols, as_of) + except WayyClientError as e: + _handle_error(e) + + if output_json: + print_json_data(result) + else: + print_table_data(result.get("data", {}), title=f"aj({left}, {right}) — {result.get('rows', 0)} rows") + + +@join_app.command("wj") +def join_wj( + left: str = typer.Argument(..., help="Left table"), + right: str = typer.Argument(..., help="Right table"), + on: str = typer.Option(..., "--on", help="Join keys (comma-separated)"), + as_of: str = typer.Option(..., "--as-of", help="Temporal column"), + before: int = typer.Option(..., "--before", help="Window before (nanoseconds)"), + after: int = typer.Option(..., "--after", help="Window after (nanoseconds)"), + output_json: bool = typer.Option(False, "--json", "-j", help="Output as JSON"), +) -> None: + """Window join: find all right rows within time window.""" + on_cols = [c.strip() for c in on.split(",")] + + try: + with WayyClient() as client: + result = client.window_join(left, right, on_cols, as_of, before, after) + except WayyClientError as e: + _handle_error(e) + + if output_json: + print_json_data(result) + else: + print_table_data(result.get("data", {}), title=f"wj({left}, {right}) — {result.get('rows', 0)} rows") + + +# --- Window Functions --- + + +@app.command("window") +def window_fn( + table: str = typer.Argument(..., help="Table name"), + column: str = typer.Argument(..., help="Column name"), + op: str = typer.Argument(..., help="Operation: mavg, msum, mstd, mmin, mmax, ema, diff, pct_change"), + window: Optional[int] = typer.Option(None, "--window", "-w", help="Window size"), + alpha: Optional[float] = typer.Option(None, "--alpha", help="EMA alpha"), + output_json: bool = typer.Option(False, "--json", "-j", help="Output as JSON"), +) -> None: + """Apply a window function to a column.""" + try: + with WayyClient() as client: + result = client.window_function(table, column, op, window=window, alpha=alpha) + except WayyClientError as e: + _handle_error(e) + + if output_json: + print_json_data(result) + else: + values = result.get("result", []) + console.print(f"[bold]{op}[/bold]({table}.{column}) — {len(values)} values") + if len(values) <= 20: + for v in values: + console.print(f" {v}") + else: + for v in values[:5]: + console.print(f" {v}") + console.print(f" ... 
({len(values) - 10} more)") + for v in values[-5:]: + console.print(f" {v}") + + +# --- Checkpoint --- + + +@app.command() +def checkpoint() -> None: + """Flush WAL and save all tables to disk.""" + try: + with WayyClient() as client: + client.checkpoint() + except WayyClientError as e: + _handle_error(e) + + print_success("Checkpoint complete") + + +app.add_typer(deploy_app) + + +if __name__ == "__main__": + app() diff --git a/python/wayy_db/cli/output.py b/python/wayy_db/cli/output.py new file mode 100644 index 0000000000000000000000000000000000000000..0c5a8e2b53a0c81141c6ef96d12dcf24ec259092 --- /dev/null +++ b/python/wayy_db/cli/output.py @@ -0,0 +1,76 @@ +"""Output formatting for the WayyDB CLI.""" + +from __future__ import annotations + +import json +import sys +from typing import Any + +from rich.console import Console +from rich.json import JSON +from rich.table import Table + +console = Console() +err_console = Console(stderr=True) + + +def print_json_data(data: Any) -> None: + """Pretty-print JSON data.""" + console.print(JSON(json.dumps(data, default=str))) + + +def print_table_data(data: dict[str, list[Any]], title: str = "") -> None: + """Render columnar data as a rich table.""" + if not data: + console.print("[dim]No data[/dim]") + return + + table = Table(title=title, show_lines=False) + columns = list(data.keys()) + for col in columns: + table.add_column(col, style="cyan") + + num_rows = len(next(iter(data.values()))) + for i in range(num_rows): + row = [str(data[col][i]) for col in columns] + table.add_row(*row) + + console.print(table) + + +def print_rows(rows: list[dict[str, Any]], title: str = "") -> None: + """Render a list of row dicts as a rich table.""" + if not rows: + console.print("[dim]No rows[/dim]") + return + + columns = list(rows[0].keys()) + table = Table(title=title, show_lines=False) + for col in columns: + table.add_column(col, style="cyan") + + for row in rows: + table.add_row(*[str(row.get(col, "")) for col in columns]) + + console.print(table) + + +def print_kv(key: str, value: Any) -> None: + """Print a KV pair.""" + console.print(f"[bold]{key}[/bold] = ", end="") + if isinstance(value, (dict, list)): + print_json_data(value) + else: + console.print(str(value)) + + +def print_success(msg: str) -> None: + console.print(f"[green]{msg}[/green]") + + +def print_error(msg: str) -> None: + err_console.print(f"[red]{msg}[/red]") + + +def print_info(label: str, value: Any) -> None: + console.print(f"[bold]{label}:[/bold] {value}") diff --git a/python/wayy_db/ops.py b/python/wayy_db/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5e97f704455278c334de869e1aaaf651b1a87308 --- /dev/null +++ b/python/wayy_db/ops.py @@ -0,0 +1,55 @@ +""" +WayyDB Operations + +High-performance operations for time-series analysis: +- Temporal joins (aj, wj) +- SIMD aggregations (sum, avg, min, max, std) +- Window functions (mavg, msum, mstd, ema, etc.) 
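+
+Example (sketch; assumes a float64 column named 'px'):
+
+    >>> import numpy as np
+    >>> import wayy_db as wdb
+    >>> t = wdb.from_dict({"px": np.array([100.0, 101.0, 99.5, 102.0])})
+    >>> ma = wdb.ops.mavg(t["px"], window=2)  # new float64 array, not a view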
+""" + +from wayy_db._core import ops as _ops + +# Re-export all operations from C++ module +from wayy_db._core.ops import ( + # Aggregations + sum, + avg, + min, + max, + std, + # Joins + aj, + wj, + # Window functions + mavg, + msum, + mstd, + mmin, + mmax, + ema, + diff, + pct_change, + shift, +) + +__all__ = [ + # Aggregations + "sum", + "avg", + "min", + "max", + "std", + # Joins + "aj", + "wj", + # Window functions + "mavg", + "msum", + "mstd", + "mmin", + "mmax", + "ema", + "diff", + "pct_change", + "shift", +] diff --git a/src/column.cpp b/src/column.cpp new file mode 100644 index 0000000000000000000000000000000000000000..41492b0f348cdc643edca8fd6cb3e8948f0ff1a3 --- /dev/null +++ b/src/column.cpp @@ -0,0 +1,121 @@ +#include "wayy_db/column.hpp" + +#include + +namespace wayy_db { + +Column::Column(std::string name, DType dtype, std::vector data) + : name_(std::move(name)) + , dtype_(dtype) + , size_(dtype_size(dtype) > 0 ? data.size() / dtype_size(dtype) : 0) + , owns_data_(true) + , owned_data_(std::move(data)) { + data_ = owned_data_.data(); +} + +Column::Column(std::string name, DType dtype, void* data, size_t size, bool owns_data) + : name_(std::move(name)) + , dtype_(dtype) + , data_(data) + , size_(size) + , owns_data_(owns_data) { + if (owns_data && data != nullptr && dtype_size(dtype) > 0) { + // Copy data into owned buffer + size_t byte_size = size * dtype_size(dtype); + owned_data_.resize(byte_size); + std::memcpy(owned_data_.data(), data, byte_size); + data_ = owned_data_.data(); + } +} + +// --- Validity bitmap --- + +void Column::ensure_validity() { + if (has_validity_) return; + size_t num_bytes = (size_ + 7) / 8; + validity_.assign(num_bytes, 0xFF); // All bits set = all valid + // Handle trailing bits in last byte + if (size_ % 8 != 0) { + uint8_t mask = static_cast((1u << (size_ % 8)) - 1); + validity_.back() = mask; + } + has_validity_ = true; +} + +bool Column::is_valid(size_t row) const { + if (!has_validity_) return true; // No bitmap = all valid + if (row >= size_) return false; + return (validity_[row / 8] >> (row % 8)) & 1; +} + +void Column::set_valid(size_t row, bool valid) { + if (!has_validity_) ensure_validity(); + if (row >= size_) return; + if (valid) { + validity_[row / 8] |= (1u << (row % 8)); + } else { + validity_[row / 8] &= ~(1u << (row % 8)); + } +} + +size_t Column::count_valid() const { + if (!has_validity_) return size_; // All valid + size_t count = 0; + for (size_t i = 0; i < validity_.size(); ++i) { + count += std::popcount(validity_[i]); + } + return count; +} + +void Column::set_validity_bitmap(std::vector bitmap) { + validity_ = std::move(bitmap); + has_validity_ = !validity_.empty(); +} + +void Column::append(const void* value, size_t value_size) { + if (!owns_data_) { + throw InvalidOperation("Cannot append to non-owned column"); + } + size_t elem_size = dtype_size(dtype_); + if (elem_size == 0) { + throw InvalidOperation("Cannot append to variable-length column via Column::append"); + } + if (value_size != elem_size) { + throw InvalidOperation("Value size mismatch in append"); + } + + size_t old_byte_size = owned_data_.size(); + owned_data_.resize(old_byte_size + elem_size); + std::memcpy(owned_data_.data() + old_byte_size, value, elem_size); + data_ = owned_data_.data(); + ++size_; + + // Extend validity bitmap if present + if (has_validity_) { + size_t needed_bytes = (size_ + 7) / 8; + if (validity_.size() < needed_bytes) { + validity_.push_back(0); + } + set_valid(size_ - 1, true); + } +} + +void Column::set(size_t row, const 
void* value, size_t value_size) { + if (!owns_data_) { + throw InvalidOperation("Cannot set on non-owned column"); + } + if (row >= size_) { + throw InvalidOperation("Row index out of range in set"); + } + size_t elem_size = dtype_size(dtype_); + if (elem_size == 0) { + throw InvalidOperation("Cannot set on variable-length column via Column::set"); + } + if (value_size != elem_size) { + throw InvalidOperation("Value size mismatch in set"); + } + + std::memcpy(owned_data_.data() + row * elem_size, value, elem_size); +} + +} // namespace wayy_db diff --git a/src/database.cpp b/src/database.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b23a586cb1859633df2471d3049e94b995565c80 --- /dev/null +++ b/src/database.cpp @@ -0,0 +1,156 @@ +#include "wayy_db/database.hpp" + +#include +#include + +namespace fs = std::filesystem; + +namespace wayy_db { + +Database::Database() = default; + +Database::Database(const std::string& path) : path_(path) { + if (!path_.empty()) { + fs::create_directories(path_); + scan_tables(); + + // Initialize WAL + wal_ = std::make_unique(path_); + + // Replay any unprocessed WAL entries from a crash + if (wal_->has_entries()) { + wal_->replay(*this); + } + } +} + +std::vector Database::tables() const { + std::shared_lock lock(mutex_); + std::vector names; + names.reserve(tables_.size()); + for (const auto& [name, _] : tables_) { + names.push_back(name); + } + // Also include tables on disk that aren't loaded yet + for (const auto& [name, _] : loaded_) { + if (!tables_.count(name)) { + names.push_back(name); + } + } + return names; +} + +bool Database::has_table(const std::string& name) const { + std::shared_lock lock(mutex_); + return tables_.count(name) > 0 || loaded_.count(name) > 0; +} + +Table& Database::table(const std::string& name) { + // First try with shared lock (read-only) + { + std::shared_lock lock(mutex_); + auto it = tables_.find(name); + if (it != tables_.end()) { + return it->second; + } + } + + // Need to lazy load - acquire exclusive lock + std::unique_lock lock(mutex_); + + // Double-check after acquiring exclusive lock (another thread may have loaded it) + auto it = tables_.find(name); + if (it != tables_.end()) { + return it->second; + } + + // Try to load from disk + if (is_persistent() && loaded_.count(name)) { + tables_.emplace(name, Table::mmap(table_path(name))); + return tables_.at(name); + } + + throw WayyException("Table not found: " + name); +} + +Table& Database::create_table(const std::string& name) { + std::unique_lock lock(mutex_); + + if (tables_.count(name) > 0 || loaded_.count(name) > 0) { + throw InvalidOperation("Table already exists: " + name); + } + + tables_.emplace(name, Table(name)); + if (is_persistent()) { + loaded_[name] = true; + } + return tables_.at(name); +} + +void Database::add_table(Table table) { + const std::string& name = table.name(); + + std::unique_lock lock(mutex_); + + if (tables_.count(name) > 0 || loaded_.count(name) > 0) { + throw InvalidOperation("Table already exists: " + name); + } + + if (is_persistent()) { + table.save(table_path(name)); + loaded_[name] = true; + } + tables_.emplace(name, std::move(table)); +} + +void Database::drop_table(const std::string& name) { + std::unique_lock lock(mutex_); + + tables_.erase(name); + loaded_.erase(name); + + if (is_persistent()) { + fs::remove_all(table_path(name)); + } +} + +void Database::save() { + if (!is_persistent()) return; + + std::shared_lock lock(mutex_); + for (auto& [name, table] : tables_) { + table.save(table_path(name)); + 
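+        // Saving runs under the shared lock, so concurrent readers are not
+        // blocked; tables are written independently, and a crash mid-save
+        // only affects the table being written at that moment.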
} +} + +void Database::refresh() { + if (!is_persistent()) return; + + std::unique_lock lock(mutex_); + scan_tables(); +} + +void Database::checkpoint() { + if (!wal_) return; + wal_->checkpoint(*this); +} + +std::string Database::table_path(const std::string& name) const { + return path_ + "/" + name; +} + +void Database::scan_tables() { + if (!fs::exists(path_)) return; + + for (const auto& entry : fs::directory_iterator(path_)) { + if (entry.is_directory()) { + std::string meta_path = entry.path().string() + "/_meta.json"; + if (fs::exists(meta_path)) { + std::string name = entry.path().filename().string(); + loaded_[name] = false; // Not loaded into memory yet + } + } + } +} + +} // namespace wayy_db diff --git a/src/hash_index.cpp b/src/hash_index.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29f02e05d124515809a1ff2347cb979ba7562638 --- /dev/null +++ b/src/hash_index.cpp @@ -0,0 +1,62 @@ +#include "wayy_db/hash_index.hpp" +#include "wayy_db/table.hpp" +#include "wayy_db/column.hpp" +#include "wayy_db/string_column.hpp" + +namespace wayy_db { + +void HashIndex::build_int(const Table& table, const std::string& col_name) { + clear(); + const Column& col = table.column(col_name); + auto view = col.as(); + for (size_t i = 0; i < view.size(); ++i) { + if (col.is_valid(i)) { + int_map_[view[i]] = i; + } + } +} + +void HashIndex::build_str(const Table& table, const std::string& col_name) { + clear(); + const StringColumn& col = table.string_column(col_name); + for (size_t i = 0; i < col.size(); ++i) { + if (col.is_valid(i)) { + str_map_[std::string(col.get(i))] = i; + } + } +} + +std::optional HashIndex::find_int(int64_t key) const { + auto it = int_map_.find(key); + if (it != int_map_.end()) return it->second; + return std::nullopt; +} + +std::optional HashIndex::find_str(std::string_view key) const { + auto it = str_map_.find(std::string(key)); + if (it != str_map_.end()) return it->second; + return std::nullopt; +} + +void HashIndex::insert_int(int64_t key, size_t row) { + int_map_[key] = row; +} + +void HashIndex::insert_str(std::string_view key, size_t row) { + str_map_[std::string(key)] = row; +} + +void HashIndex::remove_int(int64_t key) { + int_map_.erase(key); +} + +void HashIndex::remove_str(std::string_view key) { + str_map_.erase(std::string(key)); +} + +void HashIndex::clear() { + int_map_.clear(); + str_map_.clear(); +} + +} // namespace wayy_db diff --git a/src/mmap_file.cpp b/src/mmap_file.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b4ce9ce7db110c017484caa77458ef63e0585939 --- /dev/null +++ b/src/mmap_file.cpp @@ -0,0 +1,154 @@ +#include "wayy_db/mmap_file.hpp" +#include "wayy_db/types.hpp" + +#include +#include +#include +#include + +#include + +namespace wayy_db { + +MmapFile::MmapFile(const std::string& path, Mode mode, size_t size) { + open(path, mode, size); +} + +MmapFile::MmapFile(MmapFile&& other) noexcept + : path_(std::move(other.path_)) + , data_(other.data_) + , size_(other.size_) + , mode_(other.mode_) + , fd_(other.fd_) { + other.data_ = nullptr; + other.size_ = 0; + other.fd_ = -1; +} + +MmapFile& MmapFile::operator=(MmapFile&& other) noexcept { + if (this != &other) { + close(); + path_ = std::move(other.path_); + data_ = other.data_; + size_ = other.size_; + mode_ = other.mode_; + fd_ = other.fd_; + other.data_ = nullptr; + other.size_ = 0; + other.fd_ = -1; + } + return *this; +} + +MmapFile::~MmapFile() { + close(); +} + +void MmapFile::open(const std::string& path, Mode mode, size_t size) { + close(); 
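+    // open() may be called on an object that already holds a mapping
+    // (the constructor delegates here), so release any prior state first.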
+ + path_ = path; + mode_ = mode; + + int flags = 0; + int prot = 0; + + switch (mode) { + case Mode::ReadOnly: + flags = O_RDONLY; + prot = PROT_READ; + break; + case Mode::ReadWrite: + flags = O_RDWR; + prot = PROT_READ | PROT_WRITE; + break; + case Mode::Create: + flags = O_RDWR | O_CREAT | O_TRUNC; + prot = PROT_READ | PROT_WRITE; + break; + } + + fd_ = ::open(path.c_str(), flags, 0644); + if (fd_ < 0) { + throw WayyException("Failed to open file: " + path + " (" + strerror(errno) + ")"); + } + + if (mode == Mode::Create && size > 0) { + // Extend file to requested size + if (ftruncate(fd_, size) < 0) { + ::close(fd_); + fd_ = -1; + throw WayyException("Failed to resize file: " + path); + } + size_ = size; + } else { + // Get file size + struct stat st; + if (fstat(fd_, &st) < 0) { + ::close(fd_); + fd_ = -1; + throw WayyException("Failed to stat file: " + path); + } + size_ = st.st_size; + } + + if (size_ == 0) { + // Can't mmap empty file + return; + } + + data_ = mmap(nullptr, size_, prot, MAP_SHARED, fd_, 0); + if (data_ == MAP_FAILED) { + data_ = nullptr; + ::close(fd_); + fd_ = -1; + throw WayyException("Failed to mmap file: " + path + " (" + strerror(errno) + ")"); + } +} + +void MmapFile::close() { + if (data_ != nullptr) { + munmap(data_, size_); + data_ = nullptr; + } + if (fd_ >= 0) { + ::close(fd_); + fd_ = -1; + } + size_ = 0; + path_.clear(); +} + +void MmapFile::sync() { + if (data_ != nullptr && mode_ != Mode::ReadOnly) { + msync(data_, size_, MS_SYNC); + } +} + +void MmapFile::resize(size_t new_size) { + if (mode_ != Mode::Create && mode_ != Mode::ReadWrite) { + throw InvalidOperation("Cannot resize read-only mmap"); + } + + if (data_ != nullptr) { + munmap(data_, size_); + data_ = nullptr; + } + + if (ftruncate(fd_, new_size) < 0) { + throw WayyException("Failed to resize file: " + path_); + } + + size_ = new_size; + + if (size_ > 0) { + int prot = PROT_READ | PROT_WRITE; + data_ = mmap(nullptr, size_, prot, MAP_SHARED, fd_, 0); + if (data_ == MAP_FAILED) { + data_ = nullptr; + throw WayyException("Failed to remap file: " + path_); + } + } +} + +} // namespace wayy_db diff --git a/src/ops/aggregations.cpp b/src/ops/aggregations.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2cc29a95447bfa9397299cc710c3b8f0f24608c6 --- /dev/null +++ b/src/ops/aggregations.cpp @@ -0,0 +1,200 @@ +#include "wayy_db/ops/aggregations.hpp" + +#include +#include +#include + +#ifdef WAYY_USE_AVX2 +#include +#endif + +namespace wayy_db::ops { + +// Scalar implementations + +template +T sum(const ColumnView& col) { + return std::accumulate(col.begin(), col.end(), T{0}); +} + +template int64_t sum(const ColumnView&); +template double sum(const ColumnView&); + +template +T min(const ColumnView& col) { + if (col.empty()) { + throw InvalidOperation("min() on empty column"); + } + return *std::min_element(col.begin(), col.end()); +} + +template int64_t min(const ColumnView&); +template double min(const ColumnView&); + +template +T max(const ColumnView& col) { + if (col.empty()) { + throw InvalidOperation("max() on empty column"); + } + return *std::max_element(col.begin(), col.end()); +} + +template int64_t max(const ColumnView&); +template double max(const ColumnView&); + +template +double variance(const ColumnView& col) { + if (col.empty()) { + return std::numeric_limits::quiet_NaN(); + } + + double mean = avg(col); + double sum_sq = 0.0; + + for (const auto& val : col) { + double diff = static_cast(val) - mean; + sum_sq += diff * diff; + } + + return sum_sq / 
static_cast(col.size()); +} + +template double variance(const ColumnView&); +template double variance(const ColumnView&); + +template +double std_dev(const ColumnView& col) { + return std::sqrt(variance(col)); +} + +template double std_dev(const ColumnView&); +template double std_dev(const ColumnView&); + +// SIMD implementations + +#ifdef WAYY_USE_AVX2 + +double sum_simd(const ColumnView& col) { + const double* data = col.data(); + size_t n = col.size(); + + __m256d vsum = _mm256_setzero_pd(); + + // Process 4 doubles per iteration + size_t i = 0; + for (; i + 4 <= n; i += 4) { + __m256d v = _mm256_loadu_pd(data + i); + vsum = _mm256_add_pd(vsum, v); + } + + // Horizontal reduction + __m128d vlow = _mm256_castpd256_pd128(vsum); + __m128d vhigh = _mm256_extractf128_pd(vsum, 1); + vlow = _mm_add_pd(vlow, vhigh); + __m128d high64 = _mm_unpackhi_pd(vlow, vlow); + double result = _mm_cvtsd_f64(_mm_add_sd(vlow, high64)); + + // Handle remainder + for (; i < n; ++i) { + result += data[i]; + } + + return result; +} + +int64_t sum_simd(const ColumnView& col) { + const int64_t* data = col.data(); + size_t n = col.size(); + + __m256i vsum = _mm256_setzero_si256(); + + // Process 4 int64s per iteration + size_t i = 0; + for (; i + 4 <= n; i += 4) { + __m256i v = _mm256_loadu_si256(reinterpret_cast(data + i)); + vsum = _mm256_add_epi64(vsum, v); + } + + // Horizontal reduction + alignas(32) int64_t temp[4]; + _mm256_store_si256(reinterpret_cast<__m256i*>(temp), vsum); + int64_t result = temp[0] + temp[1] + temp[2] + temp[3]; + + // Handle remainder + for (; i < n; ++i) { + result += data[i]; + } + + return result; +} + +#else + +double sum_simd(const ColumnView& col) { + return sum(col); +} + +int64_t sum_simd(const ColumnView& col) { + return sum(col); +} + +#endif + +// Type-erased implementations + +double sum(const Column& col) { + switch (col.dtype()) { + case DType::Int64: + case DType::Timestamp: + return static_cast(sum_simd(const_cast(col).as_int64())); + case DType::Float64: + return sum_simd(const_cast(col).as_float64()); + default: + throw InvalidOperation("sum() not supported for this type"); + } +} + +double avg(const Column& col) { + if (col.size() == 0) { + return std::numeric_limits::quiet_NaN(); + } + return sum(col) / static_cast(col.size()); +} + +double min_val(const Column& col) { + switch (col.dtype()) { + case DType::Int64: + case DType::Timestamp: + return static_cast(min(const_cast(col).as_int64())); + case DType::Float64: + return min(const_cast(col).as_float64()); + default: + throw InvalidOperation("min() not supported for this type"); + } +} + +double max_val(const Column& col) { + switch (col.dtype()) { + case DType::Int64: + case DType::Timestamp: + return static_cast(max(const_cast(col).as_int64())); + case DType::Float64: + return max(const_cast(col).as_float64()); + default: + throw InvalidOperation("max() not supported for this type"); + } +} + +double std_dev(const Column& col) { + switch (col.dtype()) { + case DType::Int64: + case DType::Timestamp: + return std_dev(const_cast(col).as_int64()); + case DType::Float64: + return std_dev(const_cast(col).as_float64()); + default: + throw InvalidOperation("std_dev() not supported for this type"); + } +} + +} // namespace wayy_db::ops diff --git a/src/ops/joins.cpp b/src/ops/joins.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6366dd6072eebf52c9d5ddc04c351d7222c6f1a4 --- /dev/null +++ b/src/ops/joins.cpp @@ -0,0 +1,271 @@ +#include "wayy_db/ops/joins.hpp" + +#include +#include +#include +#include 
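+
+// aj ("asof join") pairs each left row with the most recent right row at or
+// before its timestamp for the same key -- the classic trades-to-quotes
+// stamping join. wj ("window join") instead collects every right row inside
+// [ts - before, ts + after]. Both require inputs pre-sorted on the asof
+// column. Illustrative call, with hypothetical in-memory tables:
+//
+//   Table stamped = ops::aj(trades, quotes, {"symbol"}, "timestamp");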
+
+namespace wayy_db::ops {
+
+namespace {
+
+// Hash combine for multi-key joins
+struct KeyHash {
+    size_t operator()(const std::vector<int64_t>& key) const {
+        size_t hash = 0;
+        for (auto val : key) {
+            hash ^= std::hash<int64_t>{}(val) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
+        }
+        return hash;
+    }
+};
+
+// Extract join key values from a row
+std::vector<int64_t> extract_key(const Table& table,
+                                 const std::vector<std::string>& on,
+                                 size_t row) {
+    std::vector<int64_t> key;
+    key.reserve(on.size());
+
+    for (const auto& col_name : on) {
+        const Column& col = table.column(col_name);
+        switch (col.dtype()) {
+            case DType::Int64:
+            case DType::Timestamp:
+                key.push_back(const_cast<Column&>(col).as_int64()[row]);
+                break;
+            case DType::Symbol:
+                key.push_back(const_cast<Column&>(col).as_symbol()[row]);
+                break;
+            default:
+                throw InvalidOperation("Join key column must be Int64, Timestamp, or Symbol");
+        }
+    }
+
+    return key;
+}
+
+// Group row indices by key values
+std::unordered_map<std::vector<int64_t>, std::vector<size_t>, KeyHash>
+group_by_key(const Table& table, const std::vector<std::string>& on) {
+    std::unordered_map<std::vector<int64_t>, std::vector<size_t>, KeyHash> groups;
+
+    for (size_t i = 0; i < table.num_rows(); ++i) {
+        auto key = extract_key(table, on, i);
+        groups[key].push_back(i);
+    }
+
+    return groups;
+}
+
+} // namespace
+
+Table aj(const Table& left, const Table& right,
+         const std::vector<std::string>& on,
+         const std::string& as_of) {
+
+    // Validate inputs
+    if (!left.is_sorted() || left.sorted_by() != as_of) {
+        throw InvalidOperation("Left table must be sorted by " + as_of);
+    }
+    if (!right.is_sorted() || right.sorted_by() != as_of) {
+        throw InvalidOperation("Right table must be sorted by " + as_of);
+    }
+
+    // Group right table by join keys
+    auto right_groups = group_by_key(right, on);
+
+    // Get timestamp columns
+    auto left_ts = const_cast<Table&>(left).column(as_of).as_int64();
+    auto right_ts = const_cast<Table&>(right).column(as_of).as_int64();
+
+    // Result builders - collect matching indices
+    std::vector<size_t> left_indices;
+    std::vector<size_t> right_indices;  // -1 means no match
+    left_indices.reserve(left.num_rows());
+    right_indices.reserve(left.num_rows());
+
+    // For each left row, find the most recent right row
+    for (size_t i = 0; i < left.num_rows(); ++i) {
+        auto key = extract_key(left, on, i);
+        int64_t ts = left_ts[i];
+
+        auto group_it = right_groups.find(key);
+        if (group_it == right_groups.end()) {
+            // No matching key in right table
+            left_indices.push_back(i);
+            right_indices.push_back(static_cast<size_t>(-1));
+            continue;
+        }
+
+        const auto& group = group_it->second;
+
+        // Binary search for largest timestamp <= ts
+        auto it = std::upper_bound(group.begin(), group.end(), ts,
+            [&right_ts](int64_t t, size_t idx) { return t < right_ts[idx]; });
+
+        if (it != group.begin()) {
+            --it;
+            left_indices.push_back(i);
+            right_indices.push_back(*it);
+        } else {
+            // No timestamp <= ts
+            left_indices.push_back(i);
+            right_indices.push_back(static_cast<size_t>(-1));
+        }
+    }
+
+    // Build result table
+    Table result("aj_result");
+
+    // Add left columns
+    for (const auto& col_name : left.column_names()) {
+        const Column& src = left.column(col_name);
+        size_t elem_size = dtype_size(src.dtype());
+        std::vector<uint8_t> data(left_indices.size() * elem_size);
+
+        const uint8_t* src_data = static_cast<const uint8_t*>(src.data());
+        for (size_t i = 0; i < left_indices.size(); ++i) {
+            std::memcpy(data.data() + i * elem_size,
+                        src_data + left_indices[i] * elem_size,
+                        elem_size);
+        }
+
+        result.add_column(Column(col_name, src.dtype(), std::move(data)));
+    }
+
+    // Add right columns (excluding join keys and as_of)
+    for (const auto& col_name :
right.column_names()) { + // Skip if already in left or is a join key + if (result.has_column(col_name)) continue; + if (std::find(on.begin(), on.end(), col_name) != on.end()) continue; + + const Column& src = right.column(col_name); + size_t elem_size = dtype_size(src.dtype()); + std::vector data(right_indices.size() * elem_size, 0); + + const uint8_t* src_data = static_cast(src.data()); + for (size_t i = 0; i < right_indices.size(); ++i) { + if (right_indices[i] != static_cast(-1)) { + std::memcpy(data.data() + i * elem_size, + src_data + right_indices[i] * elem_size, + elem_size); + } + // else: leave as zero (null representation) + } + + result.add_column(Column(col_name, src.dtype(), std::move(data))); + } + + result.set_sorted_by(as_of); + return result; +} + +Table wj(const Table& left, const Table& right, + const std::vector& on, + const std::string& as_of, + int64_t window_before, + int64_t window_after) { + + // Validate inputs + if (!left.is_sorted() || left.sorted_by() != as_of) { + throw InvalidOperation("Left table must be sorted by " + as_of); + } + if (!right.is_sorted() || right.sorted_by() != as_of) { + throw InvalidOperation("Right table must be sorted by " + as_of); + } + + // Group right table by join keys + auto right_groups = group_by_key(right, on); + + // Get timestamp columns + auto left_ts = const_cast(left).column(as_of).as_int64(); + auto right_ts = const_cast(right).column(as_of).as_int64(); + + // Result builders + std::vector left_indices; + std::vector right_indices; + + // For each left row, find all right rows in window + for (size_t i = 0; i < left.num_rows(); ++i) { + auto key = extract_key(left, on, i); + int64_t ts = left_ts[i]; + int64_t ts_min = ts - window_before; + int64_t ts_max = ts + window_after; + + auto group_it = right_groups.find(key); + if (group_it == right_groups.end()) { + continue; // No matching key + } + + const auto& group = group_it->second; + + // Find range [ts_min, ts_max] + auto lower = std::lower_bound(group.begin(), group.end(), ts_min, + [&right_ts](size_t idx, int64_t t) { return right_ts[idx] < t; }); + auto upper = std::upper_bound(group.begin(), group.end(), ts_max, + [&right_ts](int64_t t, size_t idx) { return t < right_ts[idx]; }); + + for (auto it = lower; it != upper; ++it) { + left_indices.push_back(i); + right_indices.push_back(*it); + } + } + + // Build result table (similar to aj) + Table result("wj_result"); + + // Add left columns + for (const auto& col_name : left.column_names()) { + const Column& src = left.column(col_name); + size_t elem_size = dtype_size(src.dtype()); + std::vector data(left_indices.size() * elem_size); + + const uint8_t* src_data = static_cast(src.data()); + for (size_t i = 0; i < left_indices.size(); ++i) { + std::memcpy(data.data() + i * elem_size, + src_data + left_indices[i] * elem_size, + elem_size); + } + + result.add_column(Column(col_name, src.dtype(), std::move(data))); + } + + // Add right columns (excluding join keys) + for (const auto& col_name : right.column_names()) { + if (result.has_column(col_name)) continue; + if (std::find(on.begin(), on.end(), col_name) != on.end()) continue; + + const Column& src = right.column(col_name); + size_t elem_size = dtype_size(src.dtype()); + std::vector data(right_indices.size() * elem_size); + + const uint8_t* src_data = static_cast(src.data()); + for (size_t i = 0; i < right_indices.size(); ++i) { + std::memcpy(data.data() + i * elem_size, + src_data + right_indices[i] * elem_size, + elem_size); + } + + result.add_column(Column(col_name, 
src.dtype(), std::move(data))); + } + + if (!result.column_names().empty()) { + result.set_sorted_by(as_of); + } + return result; +} + +Table inner_join(const Table& left, const Table& right, + const std::vector& on) { + // TODO: Implement inner join + throw InvalidOperation("inner_join not yet implemented"); +} + +Table left_join(const Table& left, const Table& right, + const std::vector& on) { + // TODO: Implement left join + throw InvalidOperation("left_join not yet implemented"); +} + +} // namespace wayy_db::ops diff --git a/src/ops/window.cpp b/src/ops/window.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0eae0e9fcda3942b6a3dbd8525aa94587284efea --- /dev/null +++ b/src/ops/window.cpp @@ -0,0 +1,314 @@ +#include "wayy_db/ops/window.hpp" + +#include +#include +#include + +namespace wayy_db::ops { + +// Moving average + +std::vector mavg(const ColumnView& col, size_t window) { + if (col.empty() || window == 0) return {}; + + std::vector result(col.size()); + double sum = 0.0; + + for (size_t i = 0; i < col.size(); ++i) { + sum += col[i]; + if (i >= window) { + sum -= col[i - window]; + result[i] = sum / static_cast(window); + } else { + result[i] = sum / static_cast(i + 1); + } + } + + return result; +} + +std::vector mavg(const ColumnView& col, size_t window) { + if (col.empty() || window == 0) return {}; + + std::vector result(col.size()); + int64_t sum = 0; + + for (size_t i = 0; i < col.size(); ++i) { + sum += col[i]; + if (i >= window) { + sum -= col[i - window]; + result[i] = static_cast(sum) / static_cast(window); + } else { + result[i] = static_cast(sum) / static_cast(i + 1); + } + } + + return result; +} + +// Moving sum + +std::vector msum(const ColumnView& col, size_t window) { + if (col.empty() || window == 0) return {}; + + std::vector result(col.size()); + double sum = 0.0; + + for (size_t i = 0; i < col.size(); ++i) { + sum += col[i]; + if (i >= window) { + sum -= col[i - window]; + } + result[i] = sum; + } + + return result; +} + +std::vector msum(const ColumnView& col, size_t window) { + if (col.empty() || window == 0) return {}; + + std::vector result(col.size()); + int64_t sum = 0; + + for (size_t i = 0; i < col.size(); ++i) { + sum += col[i]; + if (i >= window) { + sum -= col[i - window]; + } + result[i] = sum; + } + + return result; +} + +// Moving standard deviation (Welford's online algorithm) + +std::vector mstd(const ColumnView& col, size_t window) { + if (col.empty() || window == 0) return {}; + + std::vector result(col.size()); + + for (size_t i = 0; i < col.size(); ++i) { + size_t start = (i >= window) ? i - window + 1 : 0; + size_t count = i - start + 1; + + double mean = 0.0; + double m2 = 0.0; + size_t n = 0; + + for (size_t j = start; j <= i; ++j) { + ++n; + double delta = col[j] - mean; + mean += delta / static_cast(n); + double delta2 = col[j] - mean; + m2 += delta * delta2; + } + + result[i] = (n > 1) ? std::sqrt(m2 / static_cast(n)) : 0.0; + } + + return result; +} + +std::vector mstd(const ColumnView& col, size_t window) { + if (col.empty() || window == 0) return {}; + + std::vector result(col.size()); + + for (size_t i = 0; i < col.size(); ++i) { + size_t start = (i >= window) ? i - window + 1 : 0; + + double mean = 0.0; + double m2 = 0.0; + size_t n = 0; + + for (size_t j = start; j <= i; ++j) { + ++n; + double val = static_cast(col[j]); + double delta = val - mean; + mean += delta / static_cast(n); + double delta2 = val - mean; + m2 += delta * delta2; + } + + result[i] = (n > 1) ? 
std::sqrt(m2 / static_cast(n)) : 0.0; + } + + return result; +} + +// Moving min/max using monotonic deque for O(n) complexity + +template +std::vector monotonic_window(const ColumnView& col, size_t window, Compare cmp) { + if (col.empty() || window == 0) return {}; + + std::vector result(col.size()); + std::deque dq; // Indices + + for (size_t i = 0; i < col.size(); ++i) { + // Remove elements outside window + while (!dq.empty() && dq.front() + window <= i) { + dq.pop_front(); + } + + // Remove elements that won't be min/max + while (!dq.empty() && cmp(col[i], col[dq.back()])) { + dq.pop_back(); + } + + dq.push_back(i); + result[i] = col[dq.front()]; + } + + return result; +} + +std::vector mmin(const ColumnView& col, size_t window) { + return monotonic_window(col, window, std::less{}); +} + +std::vector mmin(const ColumnView& col, size_t window) { + return monotonic_window(col, window, std::less{}); +} + +std::vector mmax(const ColumnView& col, size_t window) { + return monotonic_window(col, window, std::greater{}); +} + +std::vector mmax(const ColumnView& col, size_t window) { + return monotonic_window(col, window, std::greater{}); +} + +// Exponential moving average + +std::vector ema(const ColumnView& col, double alpha) { + if (col.empty()) return {}; + if (alpha <= 0.0 || alpha > 1.0) { + throw std::invalid_argument("EMA alpha must be in (0, 1]"); + } + + std::vector result(col.size()); + result[0] = col[0]; + + for (size_t i = 1; i < col.size(); ++i) { + result[i] = alpha * col[i] + (1.0 - alpha) * result[i - 1]; + } + + return result; +} + +std::vector ema(const ColumnView& col, double alpha) { + if (col.empty()) return {}; + if (alpha <= 0.0 || alpha > 1.0) { + throw std::invalid_argument("EMA alpha must be in (0, 1]"); + } + + std::vector result(col.size()); + result[0] = static_cast(col[0]); + + for (size_t i = 1; i < col.size(); ++i) { + result[i] = alpha * static_cast(col[i]) + (1.0 - alpha) * result[i - 1]; + } + + return result; +} + +std::vector ema_span(const ColumnView& col, size_t span) { + double alpha = 2.0 / (static_cast(span) + 1.0); + return ema(col, alpha); +} + +// Diff + +std::vector diff(const ColumnView& col, size_t periods) { + if (col.empty() || periods >= col.size()) return std::vector(col.size(), 0.0); + + std::vector result(col.size()); + for (size_t i = 0; i < periods; ++i) { + result[i] = std::numeric_limits::quiet_NaN(); + } + for (size_t i = periods; i < col.size(); ++i) { + result[i] = col[i] - col[i - periods]; + } + + return result; +} + +std::vector diff(const ColumnView& col, size_t periods) { + if (col.empty() || periods >= col.size()) return std::vector(col.size(), 0); + + std::vector result(col.size(), 0); + for (size_t i = periods; i < col.size(); ++i) { + result[i] = col[i] - col[i - periods]; + } + + return result; +} + +// Percent change + +std::vector pct_change(const ColumnView& col, size_t periods) { + if (col.empty() || periods >= col.size()) { + return std::vector(col.size(), std::numeric_limits::quiet_NaN()); + } + + std::vector result(col.size()); + for (size_t i = 0; i < periods; ++i) { + result[i] = std::numeric_limits::quiet_NaN(); + } + for (size_t i = periods; i < col.size(); ++i) { + if (col[i - periods] != 0.0) { + result[i] = (col[i] - col[i - periods]) / col[i - periods]; + } else { + result[i] = std::numeric_limits::quiet_NaN(); + } + } + + return result; +} + +// Shift + +std::vector shift(const ColumnView& col, int64_t n) { + if (col.empty()) return {}; + + std::vector result(col.size(), std::numeric_limits::quiet_NaN()); 
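+
+    // Pandas-style semantics: n > 0 shifts values toward later rows (the
+    // first n slots keep their NaN fill); n < 0 pulls later values earlier.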
+    if (n >= 0) {
+        size_t offset = static_cast<size_t>(n);
+        for (size_t i = offset; i < col.size(); ++i) {
+            result[i] = col[i - offset];
+        }
+    } else {
+        size_t offset = static_cast<size_t>(-n);
+        for (size_t i = 0; i + offset < col.size(); ++i) {
+            result[i] = col[i + offset];
+        }
+    }
+
+    return result;
+}
+
+std::vector<int64_t> shift(const ColumnView<int64_t>& col, int64_t n) {
+    if (col.empty()) return {};
+
+    std::vector<int64_t> result(col.size(), 0);
+
+    if (n >= 0) {
+        size_t offset = static_cast<size_t>(n);
+        for (size_t i = offset; i < col.size(); ++i) {
+            result[i] = col[i - offset];
+        }
+    } else {
+        size_t offset = static_cast<size_t>(-n);
+        for (size_t i = 0; i + offset < col.size(); ++i) {
+            result[i] = col[i + offset];
+        }
+    }
+
+    return result;
+}
+
+} // namespace wayy_db::ops
diff --git a/src/string_column.cpp b/src/string_column.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..db3dd8f5b31e2e0ac9df46f9a0f9a5afb99868fd
--- /dev/null
+++ b/src/string_column.cpp
@@ -0,0 +1,224 @@
+#include "wayy_db/string_column.hpp"
+
+#include <bit>
+#include <cstdint>
+#include <cstring>
+#include <filesystem>
+#include <fstream>
+
+namespace fs = std::filesystem;
+
+namespace wayy_db {
+
+StringColumn::StringColumn(std::string name) : name_(std::move(name)) {
+    offsets_.push_back(0);  // Initial offset
+}
+
+std::string_view StringColumn::get(size_t row) const {
+    if (row >= size()) {
+        throw InvalidOperation("StringColumn row out of range");
+    }
+    if (has_validity_ && !is_valid(row)) {
+        return {};  // Null row returns empty view
+    }
+    int64_t start = offsets_[row];
+    int64_t end = offsets_[row + 1];
+    return std::string_view(reinterpret_cast<const char*>(data_.data() + start),
+                            static_cast<size_t>(end - start));
+}
+
+void StringColumn::append(std::string_view val) {
+    int64_t offset = offsets_.back();
+    data_.insert(data_.end(), val.begin(), val.end());
+    offsets_.push_back(offset + static_cast<int64_t>(val.size()));
+
+    if (has_validity_) {
+        size_t row = size() - 1;
+        size_t needed_bytes = (size() + 7) / 8;
+        if (validity_.size() < needed_bytes) {
+            validity_.push_back(0);
+        }
+        set_valid(row, true);
+    }
+}
+
+void StringColumn::append_null() {
+    offsets_.push_back(offsets_.back());  // Zero-length entry
+    ensure_validity();
+    set_valid(size() - 1, false);
+}
+
+void StringColumn::set(size_t row, std::string_view val) {
+    if (row >= size()) {
+        throw InvalidOperation("StringColumn row out of range in set");
+    }
+    int64_t old_start = offsets_[row];
+    int64_t old_end = offsets_[row + 1];
+    int64_t old_len = old_end - old_start;
+    int64_t new_len = static_cast<int64_t>(val.size());
+
+    if (new_len == old_len) {
+        // Same length: overwrite in place, all offsets stay valid.
+        std::memcpy(data_.data() + old_start, val.data(), val.size());
+    } else {
+        // Offsets are shared boundaries (offsets_[row + 1] is also the start
+        // of row + 1), so this entry's end cannot move without relocating the
+        // tail. Splice the new bytes in and shift every later offset by the
+        // length delta, which keeps neighbouring rows intact.
+        data_.erase(data_.begin() + old_start, data_.begin() + old_end);
+        data_.insert(data_.begin() + old_start, val.begin(), val.end());
+        int64_t delta = new_len - old_len;
+        for (size_t i = row + 1; i < offsets_.size(); ++i) {
+            offsets_[i] += delta;
+        }
+    }
+
+    if (has_validity_) {
+        set_valid(row, true);
+    }
+}
+
+// --- Validity bitmap ---
+
+void StringColumn::ensure_validity() {
+    if (has_validity_) return;
+    size_t n = size();
+    size_t num_bytes = (n + 7) / 8;
+    validity_.assign(num_bytes, 0xFF);
+    if (n % 8 != 0) {
+        uint8_t mask = static_cast<uint8_t>((1u << (n % 8)) - 1);
+        validity_.back() = mask;
+    }
+    has_validity_ = true;
+}
+
+bool StringColumn::is_valid(size_t row) const {
+    if (!has_validity_) return true;
+    if (row >= size()) return false;
+    return (validity_[row / 8] >> (row % 8)) & 1;
+}
+
+void StringColumn::set_valid(size_t row, bool valid) {
+    if (!has_validity_) ensure_validity();
+    if (row >= size()) return;
+    if (valid) {
+        validity_[row / 8] |= (1u << (row % 8));
+    } else {
+        validity_[row / 8] &= ~(1u << (row % 8));
+    }
+}
+
+size_t StringColumn::count_valid() const {
+    if (!has_validity_) return size();
+    size_t count = 0;
+    for (auto byte : validity_) {
+        count += std::popcount(byte);
+    }
+    return count;
+}
+
+// --- Persistence ---
+// Files: <dir>/<col>.offsets, <dir>/<col>.data, <dir>/<col>.validity
+
+void StringColumn::save(const std::string& dir_path, const std::string& col_name) const {
+    fs::create_directories(dir_path);
+
+    // Write offsets
+    {
+        std::string path = dir_path + "/" + col_name + ".offsets";
+        std::ofstream f(path, std::ios::binary);
+        if (!f) throw WayyException("Failed to create offsets file: " + path);
+        uint64_t count = offsets_.size();
+        f.write(reinterpret_cast<const char*>(&count), sizeof(count));
+        f.write(reinterpret_cast<const char*>(offsets_.data()),
+                static_cast<std::streamsize>(offsets_.size() * sizeof(int64_t)));
+    }
+
+    // Write data
+    {
+        std::string path = dir_path + "/" + col_name + ".data";
+        std::ofstream f(path, std::ios::binary);
+        if (!f) throw WayyException("Failed to create data file: " + path);
+        uint64_t sz = data_.size();
+        f.write(reinterpret_cast<const char*>(&sz), sizeof(sz));
+        f.write(reinterpret_cast<const char*>(data_.data()),
+                static_cast<std::streamsize>(data_.size()));
+    }
+
+    // Write validity if present
+    if (has_validity_) {
+        std::string path = dir_path + "/" + col_name + ".validity";
+        std::ofstream f(path, std::ios::binary);
+        if (!f) throw WayyException("Failed to create validity file: " + path);
+        uint64_t sz = validity_.size();
+        f.write(reinterpret_cast<const char*>(&sz), sizeof(sz));
+        f.write(reinterpret_cast<const char*>(validity_.data()),
+                static_cast<std::streamsize>(validity_.size()));
+    }
+}
+
+StringColumn StringColumn::load(const std::string& dir_path, const std::string& col_name) {
+    StringColumn sc(col_name);
+    sc.offsets_.clear();
+
+    // Read offsets
+    {
+        std::string path = dir_path + "/" + col_name + ".offsets";
+        std::ifstream f(path, std::ios::binary);
+        if (!f) throw WayyException("Failed to open offsets file: " + path);
+        uint64_t count = 0;
+        f.read(reinterpret_cast<char*>(&count), sizeof(count));
+        sc.offsets_.resize(count);
+        f.read(reinterpret_cast<char*>(sc.offsets_.data()),
+               static_cast<std::streamsize>(count * sizeof(int64_t)));
+    }
+
+    // Read data
+    {
+        std::string path = dir_path + "/" + col_name + ".data";
+        std::ifstream f(path, std::ios::binary);
+        if (!f) throw WayyException("Failed to open data file: " + path);
+        uint64_t sz = 0;
+        f.read(reinterpret_cast<char*>(&sz), sizeof(sz));
+        sc.data_.resize(sz);
+        f.read(reinterpret_cast<char*>(sc.data_.data()),
+               static_cast<std::streamsize>(sz));
+    }
+
+    // Read validity if present
+    {
+        std::string path = dir_path + "/" + col_name +
".validity"; + if (fs::exists(path)) { + std::ifstream f(path, std::ios::binary); + if (f) { + uint64_t sz = 0; + f.read(reinterpret_cast(&sz), sizeof(sz)); + sc.validity_.resize(sz); + f.read(reinterpret_cast(sc.validity_.data()), + static_cast(sz)); + sc.has_validity_ = true; + } + } + } + + return sc; +} + +std::vector StringColumn::to_vector() const { + std::vector result; + result.reserve(size()); + for (size_t i = 0; i < size(); ++i) { + if (is_valid(i)) { + result.emplace_back(get(i)); + } else { + result.emplace_back(); + } + } + return result; +} + +} // namespace wayy_db diff --git a/src/table.cpp b/src/table.cpp new file mode 100644 index 0000000000000000000000000000000000000000..db43071a52aaf5a32d0304f1c49a5c9f066dee7a --- /dev/null +++ b/src/table.cpp @@ -0,0 +1,778 @@ +#include "wayy_db/table.hpp" +#include "wayy_db/hash_index.hpp" + +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +namespace wayy_db { + +Table::Table(std::string name) : name_(std::move(name)) {} + +Table::~Table() = default; + +Table::Table(Table&& other) noexcept + : name_(std::move(other.name_)), + num_rows_(other.num_rows_), + columns_(std::move(other.columns_)), + column_index_(std::move(other.column_index_)), + sorted_by_(std::move(other.sorted_by_)), + string_columns_(std::move(other.string_columns_)), + string_column_index_(std::move(other.string_column_index_)), + primary_key_(std::move(other.primary_key_)), + pk_index_(std::move(other.pk_index_)), + mmap_files_(std::move(other.mmap_files_)) { + other.num_rows_ = 0; +} + +Table& Table::operator=(Table&& other) noexcept { + if (this != &other) { + name_ = std::move(other.name_); + num_rows_ = other.num_rows_; + columns_ = std::move(other.columns_); + column_index_ = std::move(other.column_index_); + sorted_by_ = std::move(other.sorted_by_); + string_columns_ = std::move(other.string_columns_); + string_column_index_ = std::move(other.string_column_index_); + primary_key_ = std::move(other.primary_key_); + pk_index_ = std::move(other.pk_index_); + mmap_files_ = std::move(other.mmap_files_); + other.num_rows_ = 0; + } + return *this; +} + +// --- Fixed-width column management --- + +void Table::add_column(Column column) { + if (columns_.empty() && string_columns_.empty()) { + num_rows_ = column.size(); + } else if (column.size() != num_rows_) { + throw InvalidOperation( + "Column size mismatch: expected " + std::to_string(num_rows_) + + ", got " + std::to_string(column.size())); + } + + const std::string& col_name = column.name(); + if (column_index_.count(col_name) || string_column_index_.count(col_name)) { + throw InvalidOperation("Column already exists: " + col_name); + } + + column_index_[col_name] = columns_.size(); + columns_.push_back(std::move(column)); +} + +void Table::add_column(const std::string& name, DType dtype, void* data, size_t size) { + add_column(Column(name, dtype, data, size, true)); +} + +// --- String column management --- + +void Table::add_string_column(StringColumn col) { + if (columns_.empty() && string_columns_.empty()) { + num_rows_ = col.size(); + } else if (col.size() != num_rows_) { + throw InvalidOperation( + "StringColumn size mismatch: expected " + std::to_string(num_rows_) + + ", got " + std::to_string(col.size())); + } + + const std::string& col_name = col.name(); + if (column_index_.count(col_name) || string_column_index_.count(col_name)) { + throw InvalidOperation("Column already exists: " + col_name); + } + + string_column_index_[col_name] = string_columns_.size(); 
+ string_columns_.push_back(std::move(col)); +} + +bool Table::has_string_column(const std::string& name) const { + return string_column_index_.count(name) > 0; +} + +StringColumn& Table::string_column(const std::string& name) { + auto it = string_column_index_.find(name); + if (it == string_column_index_.end()) { + throw ColumnNotFound(name); + } + return string_columns_[it->second]; +} + +const StringColumn& Table::string_column(const std::string& name) const { + auto it = string_column_index_.find(name); + if (it == string_column_index_.end()) { + throw ColumnNotFound(name); + } + return string_columns_[it->second]; +} + +// --- General column queries --- + +bool Table::has_column(const std::string& name) const { + return column_index_.count(name) > 0 || string_column_index_.count(name) > 0; +} + +Column& Table::column(const std::string& name) { + auto it = column_index_.find(name); + if (it == column_index_.end()) { + throw ColumnNotFound(name); + } + return columns_[it->second]; +} + +const Column& Table::column(const std::string& name) const { + auto it = column_index_.find(name); + if (it == column_index_.end()) { + throw ColumnNotFound(name); + } + return columns_[it->second]; +} + +DType Table::column_dtype(const std::string& name) const { + auto it = column_index_.find(name); + if (it != column_index_.end()) { + return columns_[it->second].dtype(); + } + auto sit = string_column_index_.find(name); + if (sit != string_column_index_.end()) { + return DType::String; + } + throw ColumnNotFound(name); +} + +std::vector Table::column_names() const { + std::vector names; + names.reserve(columns_.size() + string_columns_.size()); + for (const auto& col : columns_) { + names.push_back(col.name()); + } + for (const auto& col : string_columns_) { + names.push_back(col.name()); + } + return names; +} + +void Table::set_sorted_by(const std::string& col) { + if (!has_column(col)) { + throw ColumnNotFound(col); + } + sorted_by_ = col; +} + +// --- Primary key + hash index --- + +void Table::set_primary_key(const std::string& col_name) { + if (!has_column(col_name)) { + throw ColumnNotFound(col_name); + } + primary_key_ = col_name; + rebuild_index(); +} + +void Table::rebuild_index() { + if (!primary_key_) return; + + pk_index_ = std::make_unique(); + DType pk_dtype = column_dtype(*primary_key_); + + if (pk_dtype == DType::String) { + pk_index_->build_str(*this, *primary_key_); + } else if (pk_dtype == DType::Int64 || pk_dtype == DType::Timestamp || pk_dtype == DType::Decimal6) { + pk_index_->build_int(*this, *primary_key_); + } else { + throw InvalidOperation("Primary key must be String, Int64, Timestamp, or Decimal6"); + } +} + +std::optional Table::find_row(int64_t key) const { + if (!pk_index_) return std::nullopt; + auto row = pk_index_->find_int(key); + if (row && !columns_.empty() && columns_[0].has_validity()) { + // Check validity of any fixed column + if (!columns_[0].is_valid(*row)) return std::nullopt; + } + return row; +} + +std::optional Table::find_row(std::string_view key) const { + if (!pk_index_) return std::nullopt; + auto row = pk_index_->find_str(key); + if (row) { + // Check validity via the PK string column itself + const auto& pk_col = string_column(*primary_key_); + if (pk_col.has_validity() && !pk_col.is_valid(*row)) return std::nullopt; + } + return row; +} + +// --- CRUD operations --- + +size_t Table::append_row(const std::unordered_map& values) { + size_t row_idx = num_rows_; + + // Append to each fixed-width column + for (auto& col : columns_) { + auto it = 
values.find(col.name()); + if (it == values.end()) { + // Append default (zero) value + uint8_t zeros[8] = {}; + col.append(zeros, dtype_size(col.dtype())); + col.ensure_validity(); + col.set_valid(row_idx, false); // Mark as null + } else { + const auto& val = it->second; + DType dt = col.dtype(); + + if (dt == DType::Int64 || dt == DType::Timestamp || dt == DType::Decimal6) { + int64_t v = std::any_cast(val); + col.append(&v, sizeof(v)); + } else if (dt == DType::Float64) { + double v = std::any_cast(val); + col.append(&v, sizeof(v)); + } else if (dt == DType::Symbol) { + uint32_t v = std::any_cast(val); + col.append(&v, sizeof(v)); + } else if (dt == DType::Bool) { + uint8_t v = std::any_cast(val); + col.append(&v, sizeof(v)); + } + } + } + + // Append to each string column + for (auto& scol : string_columns_) { + auto it = values.find(scol.name()); + if (it == values.end()) { + scol.append_null(); + } else { + auto sv = std::any_cast(it->second); + scol.append(sv); + } + } + + ++num_rows_; + + // Update index + if (pk_index_ && primary_key_) { + DType pk_dtype = column_dtype(*primary_key_); + auto it = values.find(*primary_key_); + if (it != values.end()) { + if (pk_dtype == DType::String) { + pk_index_->insert_str(std::any_cast(it->second), row_idx); + } else { + pk_index_->insert_int(std::any_cast(it->second), row_idx); + } + } + } + + return row_idx; +} + +bool Table::update_row(int64_t pk, const std::unordered_map& values) { + auto row = find_row(pk); + if (!row) return false; + return update_row_at(*row, values); +} + +bool Table::update_row(std::string_view pk, const std::unordered_map& values) { + auto row = find_row(pk); + if (!row) return false; + return update_row_at(*row, values); +} + +bool Table::update_row_at(size_t row_idx, const std::unordered_map& values) { + if (row_idx >= num_rows_) return false; + + for (const auto& [col_name, val] : values) { + // Check if it's a string column + auto sit = string_column_index_.find(col_name); + if (sit != string_column_index_.end()) { + auto sv = std::any_cast(val); + string_columns_[sit->second].set(row_idx, sv); + continue; + } + + // Fixed-width column + auto it = column_index_.find(col_name); + if (it == column_index_.end()) continue; // Skip unknown columns + + Column& col = columns_[it->second]; + DType dt = col.dtype(); + + if (dt == DType::Int64 || dt == DType::Timestamp || dt == DType::Decimal6) { + int64_t v = std::any_cast(val); + col.set(row_idx, &v, sizeof(v)); + } else if (dt == DType::Float64) { + double v = std::any_cast(val); + col.set(row_idx, &v, sizeof(v)); + } else if (dt == DType::Symbol) { + uint32_t v = std::any_cast(val); + col.set(row_idx, &v, sizeof(v)); + } else if (dt == DType::Bool) { + uint8_t v = std::any_cast(val); + col.set(row_idx, &v, sizeof(v)); + } + } + + return true; +} + +bool Table::delete_row(int64_t pk) { + auto row = find_row(pk); + if (!row) return false; + + // Soft delete: set validity bit to 0 on all columns + for (auto& col : columns_) { + col.ensure_validity(); + col.set_valid(*row, false); + } + for (auto& scol : string_columns_) { + scol.set_valid(*row, false); + } + + // Remove from index + if (pk_index_) { + pk_index_->remove_int(pk); + } + + return true; +} + +bool Table::delete_row(std::string_view pk) { + auto row = find_row(pk); + if (!row) return false; + + for (auto& col : columns_) { + col.ensure_validity(); + col.set_valid(*row, false); + } + for (auto& scol : string_columns_) { + scol.set_valid(*row, false); + } + + if (pk_index_) { + pk_index_->remove_str(pk); + } + + 
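+    // Deletes are soft: the row's bytes stay in place with validity bits
+    // cleared until compact() rewrites the table. Caveat: compact() infers
+    // liveness from the first validity-bearing column, so a row carrying a
+    // legitimate NULL there is swept as well; a dedicated tombstone bitmap
+    // would disambiguate deletion from NULL.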
return true; +} + +// --- Filter --- + +std::vector Table::where_eq(const std::string& col_name, int64_t val) const { + std::vector result; + auto it = column_index_.find(col_name); + if (it == column_index_.end()) throw ColumnNotFound(col_name); + + const Column& col = columns_[it->second]; + auto view = col.as(); + for (size_t i = 0; i < view.size(); ++i) { + if (col.is_valid(i) && view[i] == val) { + result.push_back(i); + } + } + return result; +} + +std::vector Table::where_eq(const std::string& col_name, std::string_view val) const { + std::vector result; + auto sit = string_column_index_.find(col_name); + if (sit == string_column_index_.end()) throw ColumnNotFound(col_name); + + const StringColumn& scol = string_columns_[sit->second]; + for (size_t i = 0; i < scol.size(); ++i) { + if (scol.is_valid(i) && scol.get(i) == val) { + result.push_back(i); + } + } + return result; +} + +// --- Compaction --- + +void Table::compact() { + // Determine which rows are valid (check first available column) + std::vector keep(num_rows_, true); + bool any_deleted = false; + + // Check fixed columns for validity + for (const auto& col : columns_) { + if (col.has_validity()) { + for (size_t i = 0; i < num_rows_; ++i) { + if (!col.is_valid(i)) { + keep[i] = false; + any_deleted = true; + } + } + break; // Only need to check one column + } + } + + // Also check string columns + if (!any_deleted) { + for (const auto& scol : string_columns_) { + if (scol.has_validity()) { + for (size_t i = 0; i < scol.size(); ++i) { + if (!scol.is_valid(i)) { + keep[i] = false; + any_deleted = true; + } + } + break; + } + } + } + + if (!any_deleted) return; // Nothing to compact + + // Count new rows + size_t new_rows = 0; + for (bool k : keep) { + if (k) ++new_rows; + } + + // Compact fixed columns + for (size_t ci = 0; ci < columns_.size(); ++ci) { + Column& col = columns_[ci]; + size_t elem_size = dtype_size(col.dtype()); + std::vector new_data; + new_data.reserve(new_rows * elem_size); + + const uint8_t* src = static_cast(col.data()); + for (size_t i = 0; i < num_rows_; ++i) { + if (keep[i]) { + new_data.insert(new_data.end(), src + i * elem_size, src + (i + 1) * elem_size); + } + } + + // Replace column + std::string cname = col.name(); + DType cdtype = col.dtype(); + columns_[ci] = Column(std::move(cname), cdtype, std::move(new_data)); + } + + // Compact string columns + for (size_t si = 0; si < string_columns_.size(); ++si) { + StringColumn& scol = string_columns_[si]; + StringColumn new_scol(scol.name()); + for (size_t i = 0; i < scol.size(); ++i) { + if (keep[i]) { + if (scol.is_valid(i)) { + new_scol.append(scol.get(i)); + } else { + new_scol.append_null(); + } + } + } + string_columns_[si] = std::move(new_scol); + } + + num_rows_ = new_rows; + + // Rebuild index + rebuild_index(); +} + +// --- Persistence --- + +void Table::save(const std::string& dir_path) const { + fs::create_directories(dir_path); + + // Write metadata + write_metadata(dir_path); + + // Write each fixed-width column + for (const auto& col : columns_) { + std::string col_path = dir_path + "/" + col.name() + ".col"; + std::ofstream file(col_path, std::ios::binary); + + if (!file) { + throw WayyException("Failed to create column file: " + col_path); + } + + // Write header + ColumnHeader header{}; + header.magic = WAYY_MAGIC; + header.version = WAYY_VERSION; + header.dtype = col.dtype(); + header.row_count = col.size(); + header.compression = 0; + header.data_offset = sizeof(ColumnHeader); + + file.write(reinterpret_cast(&header), 
sizeof(header)); + + // Write data + file.write(static_cast(col.data()), col.byte_size()); + + // Write validity bitmap if present + if (col.has_validity()) { + std::string vpath = dir_path + "/" + col.name() + ".validity"; + std::ofstream vf(vpath, std::ios::binary); + if (vf) { + const auto& bmap = col.validity_bitmap(); + uint64_t sz = bmap.size(); + vf.write(reinterpret_cast(&sz), sizeof(sz)); + vf.write(reinterpret_cast(bmap.data()), + static_cast(sz)); + } + } + } + + // Write each string column + for (const auto& scol : string_columns_) { + scol.save(dir_path, scol.name()); + } +} + +void Table::write_metadata(const std::string& dir_path) const { + std::string meta_path = dir_path + "/_meta.json"; + std::ofstream file(meta_path); + + if (!file) { + throw WayyException("Failed to create metadata file: " + meta_path); + } + + file << "{\n"; + file << " \"version\": " << WAYY_VERSION << ",\n"; + file << " \"name\": \"" << name_ << "\",\n"; + file << " \"num_rows\": " << num_rows_ << ",\n"; + + if (sorted_by_) { + file << " \"sorted_by\": \"" << *sorted_by_ << "\",\n"; + } else { + file << " \"sorted_by\": null,\n"; + } + + if (primary_key_) { + file << " \"primary_key\": \"" << *primary_key_ << "\",\n"; + } else { + file << " \"primary_key\": null,\n"; + } + + file << " \"columns\": [\n"; + size_t total_cols = columns_.size() + string_columns_.size(); + size_t idx = 0; + for (const auto& col : columns_) { + file << " {\"name\": \"" << col.name() + << "\", \"dtype\": \"" << dtype_to_string(col.dtype()) << "\"}"; + if (++idx < total_cols) file << ","; + file << "\n"; + } + for (const auto& scol : string_columns_) { + file << " {\"name\": \"" << scol.name() + << "\", \"dtype\": \"string\"}"; + if (++idx < total_cols) file << ","; + file << "\n"; + } + file << " ]\n"; + file << "}\n"; +} + +Table Table::load(const std::string& dir_path) { + auto [name, num_rows, sorted_by, primary_key, col_info] = read_metadata(dir_path); + + Table table(name); + + for (const auto& [col_name, dtype] : col_info) { + if (dtype == DType::String) { + // Load string column + table.add_string_column(StringColumn::load(dir_path, col_name)); + } else { + // Load fixed-width column + std::string col_path = dir_path + "/" + col_name + ".col"; + std::ifstream file(col_path, std::ios::binary); + + if (!file) { + throw WayyException("Failed to open column file: " + col_path); + } + + // Read header + ColumnHeader header; + file.read(reinterpret_cast(&header), sizeof(header)); + + if (header.magic != WAYY_MAGIC) { + throw WayyException("Invalid column file magic: " + col_path); + } + + // Read data + size_t byte_size = header.row_count * dtype_size(header.dtype); + std::vector data(byte_size); + file.read(reinterpret_cast(data.data()), byte_size); + + Column col(col_name, header.dtype, std::move(data)); + + // Load validity bitmap if present + std::string vpath = dir_path + "/" + col_name + ".validity"; + if (fs::exists(vpath)) { + std::ifstream vf(vpath, std::ios::binary); + if (vf) { + uint64_t sz = 0; + vf.read(reinterpret_cast(&sz), sizeof(sz)); + std::vector bitmap(sz); + vf.read(reinterpret_cast(bitmap.data()), + static_cast(sz)); + col.set_validity_bitmap(std::move(bitmap)); + } + } + + table.add_column(std::move(col)); + } + } + + if (sorted_by) { + table.set_sorted_by(*sorted_by); + } + + if (primary_key) { + table.set_primary_key(*primary_key); + } + + return table; +} + +Table Table::mmap(const std::string& dir_path) { + auto [name, num_rows, sorted_by, primary_key, col_info] = read_metadata(dir_path); + + 
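+    // Zero-copy path: fixed-width columns below are backed directly by the
+    // read-only mapped .col files, so "opening" a table is mostly page-table
+    // setup; only string columns and validity bitmaps are read into memory.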
Table table(name); + + for (const auto& [col_name, dtype] : col_info) { + if (dtype == DType::String) { + // String columns are loaded (not mmap'd) since they have complex structure + table.add_string_column(StringColumn::load(dir_path, col_name)); + } else { + std::string col_path = dir_path + "/" + col_name + ".col"; + + MmapFile mmap_file(col_path, MmapFile::Mode::ReadOnly); + + // Validate header + auto* header = static_cast(mmap_file.data()); + if (header->magic != WAYY_MAGIC) { + throw WayyException("Invalid column file magic: " + col_path); + } + + // Create column pointing to mmap'd data + void* data_ptr = static_cast(mmap_file.data()) + header->data_offset; + Column col(col_name, header->dtype, data_ptr, header->row_count, false); + + // Load validity bitmap (always into memory, small) + std::string vpath = dir_path + "/" + col_name + ".validity"; + if (fs::exists(vpath)) { + std::ifstream vf(vpath, std::ios::binary); + if (vf) { + uint64_t sz = 0; + vf.read(reinterpret_cast(&sz), sizeof(sz)); + std::vector bitmap(sz); + vf.read(reinterpret_cast(bitmap.data()), + static_cast(sz)); + col.set_validity_bitmap(std::move(bitmap)); + } + } + + table.add_column(std::move(col)); + + // Keep mmap file alive + table.mmap_files_.push_back(std::move(mmap_file)); + } + } + + if (sorted_by) { + table.set_sorted_by(*sorted_by); + } + + if (primary_key) { + table.set_primary_key(*primary_key); + } + + return table; +} + +std::tuple, + std::optional, + std::vector>> +Table::read_metadata(const std::string& dir_path) { + std::string meta_path = dir_path + "/_meta.json"; + std::ifstream file(meta_path); + + if (!file) { + throw WayyException("Failed to open metadata file: " + meta_path); + } + + // Simple JSON parsing (minimal implementation) + std::stringstream buffer; + buffer << file.rdbuf(); + std::string json = buffer.str(); + + // Extract fields using simple string parsing + auto extract_string = [&json](const std::string& key) -> std::string { + std::string pattern = "\"" + key + "\": \""; + auto pos = json.find(pattern); + if (pos == std::string::npos) return ""; + pos += pattern.size(); + auto end = json.find("\"", pos); + return json.substr(pos, end - pos); + }; + + auto extract_int = [&json](const std::string& key) -> size_t { + std::string pattern = "\"" + key + "\": "; + auto pos = json.find(pattern); + if (pos == std::string::npos) return 0; + pos += pattern.size(); + return std::stoull(json.substr(pos)); + }; + + std::string name = extract_string("name"); + size_t num_rows_val = extract_int("num_rows"); + + std::optional sorted_by; + std::string sorted_str = extract_string("sorted_by"); + if (!sorted_str.empty()) { + sorted_by = sorted_str; + } + + std::optional primary_key; + std::string pk_str = extract_string("primary_key"); + if (!pk_str.empty()) { + primary_key = pk_str; + } + + // Parse columns array + std::vector> columns; + auto cols_start = json.find("\"columns\":"); + if (cols_start != std::string::npos) { + auto arr_start = json.find("[", cols_start); + auto arr_end = json.find("]", arr_start); + std::string arr = json.substr(arr_start, arr_end - arr_start + 1); + + size_t pos = 0; + while ((pos = arr.find("{", pos)) != std::string::npos) { + auto obj_end = arr.find("}", pos); + std::string obj = arr.substr(pos, obj_end - pos + 1); + + // Extract name and dtype from object + auto name_pos = obj.find("\"name\": \""); + if (name_pos != std::string::npos) { + name_pos += 9; + auto name_end = obj.find("\"", name_pos); + std::string col_name = obj.substr(name_pos, name_end - 
name_pos);
+
+                auto dtype_pos = obj.find("\"dtype\": \"");
+                // Guard against a malformed entry: advancing an npos result
+                // by 10 would wrap around and index garbage.
+                if (dtype_pos != std::string::npos) {
+                    dtype_pos += 10;
+                    auto dtype_end = obj.find("\"", dtype_pos);
+                    std::string dtype_str = obj.substr(dtype_pos, dtype_end - dtype_pos);
+
+                    columns.emplace_back(col_name, dtype_from_string(dtype_str));
+                }
+            }
+
+            pos = obj_end + 1;
+        }
+    }
+
+    return {name, num_rows_val, sorted_by, primary_key, columns};
+}
+
+} // namespace wayy_db
diff --git a/src/types.cpp b/src/types.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d8bd72f9aa37aae07ca66bc3a323033b260a182b
--- /dev/null
+++ b/src/types.cpp
@@ -0,0 +1,25 @@
+#include "wayy_db/types.hpp"
+
+#include <unordered_map>
+
+namespace wayy_db {
+
+DType dtype_from_string(std::string_view s) {
+    static const std::unordered_map<std::string_view, DType> map = {
+        {"int64", DType::Int64},
+        {"float64", DType::Float64},
+        {"timestamp", DType::Timestamp},
+        {"symbol", DType::Symbol},
+        {"bool", DType::Bool},
+        {"string", DType::String},
+        {"decimal6", DType::Decimal6},
+    };
+
+    auto it = map.find(s);
+    if (it == map.end()) {
+        throw WayyException("Unknown dtype: " + std::string(s));
+    }
+    return it->second;
+}
+
+} // namespace wayy_db
diff --git a/src/wal.cpp b/src/wal.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..40fc9bfa0bfa8a70863541566bcb197f2659da7c
--- /dev/null
+++ b/src/wal.cpp
@@ -0,0 +1,225 @@
+#include "wayy_db/wal.hpp"
+#include "wayy_db/database.hpp"
+
+#include <array>
+#include <filesystem>
+#include <fstream>
+
+namespace fs = std::filesystem;
+
+namespace wayy_db {
+
+// Simple CRC32 (IEEE polynomial)
+static const std::array<uint32_t, 256> crc32_table = [] {
+    std::array<uint32_t, 256> table{};
+    for (uint32_t i = 0; i < 256; ++i) {
+        uint32_t crc = i;
+        for (int j = 0; j < 8; ++j) {
+            crc = (crc >> 1) ^ ((crc & 1) ? 0xEDB88320u : 0);
+        }
+        table[i] = crc;
+    }
+    return table;
+}();
+
+WriteAheadLog::WriteAheadLog(const std::string& db_path) {
+    fs::create_directories(db_path);
+    path_ = db_path + "/wal.bin";
+    open_for_append();
+}
+
+WriteAheadLog::~WriteAheadLog() {
+    if (file_.is_open()) {
+        file_.flush();
+        file_.close();
+    }
+}
+
+void WriteAheadLog::open_for_append() {
+    if (file_.is_open()) file_.close();
+    file_.open(path_, std::ios::binary | std::ios::app);
+    if (!file_) {
+        throw WayyException("Failed to open WAL file: " + path_);
+    }
+}
+
+uint32_t WriteAheadLog::crc32(const uint8_t* data, size_t len) {
+    uint32_t crc = 0xFFFFFFFF;
+    for (size_t i = 0; i < len; ++i) {
+        crc = crc32_table[(crc ^ data[i]) & 0xFF] ^ (crc >> 8);
+    }
+    return crc ^ 0xFFFFFFFF;
+}
+
+void WriteAheadLog::write_entry(WalOp op, const std::string& table, size_t row,
+                                const std::vector<uint8_t>& payload) {
+    std::lock_guard lock(mu_);
+
+    // Build the entry in a buffer for CRC calculation
+    std::vector<uint8_t> buf;
+    buf.reserve(4 + 1 + 4 + table.size() + 8 + 4 + payload.size());
+
+    // Magic
+    uint32_t magic = WAL_MAGIC;
+    buf.insert(buf.end(), reinterpret_cast<const uint8_t*>(&magic),
+               reinterpret_cast<const uint8_t*>(&magic) + 4);
+
+    // Op type
+    buf.push_back(static_cast<uint8_t>(op));
+
+    // Table name length + name
+    uint32_t tlen = static_cast<uint32_t>(table.size());
+    buf.insert(buf.end(), reinterpret_cast<const uint8_t*>(&tlen),
+               reinterpret_cast<const uint8_t*>(&tlen) + 4);
+    buf.insert(buf.end(), table.begin(), table.end());
+
+    // Row ID
+    uint64_t row_id = static_cast<uint64_t>(row);
+    buf.insert(buf.end(), reinterpret_cast<const uint8_t*>(&row_id),
+               reinterpret_cast<const uint8_t*>(&row_id) + 8);
+
+    // Payload length + payload
+    uint32_t plen = static_cast<uint32_t>(payload.size());
+    buf.insert(buf.end(), reinterpret_cast<const uint8_t*>(&plen),
+               reinterpret_cast<const uint8_t*>(&plen) + 4);
+    buf.insert(buf.end(), payload.begin(), payload.end());
+
+    // CRC32
+    uint32_t checksum = crc32(buf.data(),
diff --git a/src/wal.cpp b/src/wal.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..40fc9bfa0bfa8a70863541566bcb197f2659da7c
--- /dev/null
+++ b/src/wal.cpp
@@ -0,0 +1,225 @@
+#include "wayy_db/wal.hpp"
+#include "wayy_db/database.hpp"
+
+#include <array>
+#include <filesystem>
+#include <fstream>
+
+namespace fs = std::filesystem;
+
+namespace wayy_db {
+
+// Simple CRC32 (IEEE polynomial)
+static const std::array<uint32_t, 256> crc32_table = [] {
+    std::array<uint32_t, 256> table{};
+    for (uint32_t i = 0; i < 256; ++i) {
+        uint32_t crc = i;
+        for (int j = 0; j < 8; ++j) {
+            crc = (crc >> 1) ^ ((crc & 1) ? 0xEDB88320u : 0);
+        }
+        table[i] = crc;
+    }
+    return table;
+}();
+
+WriteAheadLog::WriteAheadLog(const std::string& db_path) {
+    fs::create_directories(db_path);
+    path_ = db_path + "/wal.bin";
+    open_for_append();
+}
+
+WriteAheadLog::~WriteAheadLog() {
+    if (file_.is_open()) {
+        file_.flush();
+        file_.close();
+    }
+}
+
+void WriteAheadLog::open_for_append() {
+    if (file_.is_open()) file_.close();
+    file_.open(path_, std::ios::binary | std::ios::app);
+    if (!file_) {
+        throw WayyException("Failed to open WAL file: " + path_);
+    }
+}
+
+uint32_t WriteAheadLog::crc32(const uint8_t* data, size_t len) {
+    uint32_t crc = 0xFFFFFFFF;
+    for (size_t i = 0; i < len; ++i) {
+        crc = crc32_table[(crc ^ data[i]) & 0xFF] ^ (crc >> 8);
+    }
+    return crc ^ 0xFFFFFFFF;
+}
+
+void WriteAheadLog::write_entry(WalOp op, const std::string& table, size_t row,
+                                const std::vector<uint8_t>& payload) {
+    std::lock_guard<std::mutex> lock(mu_);
+
+    // Build the entry in a buffer for CRC calculation
+    std::vector<uint8_t> buf;
+    buf.reserve(4 + 1 + 4 + table.size() + 8 + 4 + payload.size());
+
+    // Magic
+    uint32_t magic = WAL_MAGIC;
+    buf.insert(buf.end(), reinterpret_cast<const uint8_t*>(&magic),
+               reinterpret_cast<const uint8_t*>(&magic) + 4);
+
+    // Op type
+    buf.push_back(static_cast<uint8_t>(op));
+
+    // Table name length + name
+    uint32_t tlen = static_cast<uint32_t>(table.size());
+    buf.insert(buf.end(), reinterpret_cast<const uint8_t*>(&tlen),
+               reinterpret_cast<const uint8_t*>(&tlen) + 4);
+    buf.insert(buf.end(), table.begin(), table.end());
+
+    // Row ID
+    uint64_t row_id = static_cast<uint64_t>(row);
+    buf.insert(buf.end(), reinterpret_cast<const uint8_t*>(&row_id),
+               reinterpret_cast<const uint8_t*>(&row_id) + 8);
+
+    // Payload length + payload
+    uint32_t plen = static_cast<uint32_t>(payload.size());
+    buf.insert(buf.end(), reinterpret_cast<const uint8_t*>(&plen),
+               reinterpret_cast<const uint8_t*>(&plen) + 4);
+    buf.insert(buf.end(), payload.begin(), payload.end());
+
+    // CRC32
+    uint32_t checksum = crc32(buf.data(), buf.size());
+    buf.insert(buf.end(), reinterpret_cast<const uint8_t*>(&checksum),
+               reinterpret_cast<const uint8_t*>(&checksum) + 4);
+
+    // Write to file
+    file_.write(reinterpret_cast<const char*>(buf.data()),
+                static_cast<std::streamsize>(buf.size()));
+    file_.flush();
+}
+
+void WriteAheadLog::log_insert(const std::string& table, size_t row,
+                               const std::vector<uint8_t>& data) {
+    write_entry(WalOp::Insert, table, row, data);
+}
+
+void WriteAheadLog::log_update(const std::string& table, size_t row,
+                               const std::string& col, const std::vector<uint8_t>& data) {
+    // Encode column name + data as payload
+    std::vector<uint8_t> payload;
+    uint32_t clen = static_cast<uint32_t>(col.size());
+    payload.insert(payload.end(), reinterpret_cast<const uint8_t*>(&clen),
+                   reinterpret_cast<const uint8_t*>(&clen) + 4);
+    payload.insert(payload.end(), col.begin(), col.end());
+    payload.insert(payload.end(), data.begin(), data.end());
+    write_entry(WalOp::Update, table, row, payload);
+}
+
+void WriteAheadLog::log_delete(const std::string& table, size_t row) {
+    write_entry(WalOp::Delete, table, row, {});
+}
+
+void WriteAheadLog::checkpoint(Database& db) {
+    std::lock_guard<std::mutex> lock(mu_);
+
+    // Flush and close WAL
+    if (file_.is_open()) {
+        file_.flush();
+        file_.close();
+    }
+
+    // Save all tables to disk
+    db.save();
+
+    // Truncate WAL (start fresh)
+    file_.open(path_, std::ios::binary | std::ios::trunc);
+    if (!file_) {
+        throw WayyException("Failed to truncate WAL: " + path_);
+    }
+}
+void WriteAheadLog::replay(Database& db) {
+    if (!fs::exists(path_)) return;
+
+    std::ifstream wal(path_, std::ios::binary);
+    if (!wal) return;
+
+    // Get file size
+    wal.seekg(0, std::ios::end);
+    auto file_size = wal.tellg();
+    if (file_size <= 0) return;
+    wal.seekg(0, std::ios::beg);
+
+    size_t entries_replayed = 0;
+
+    while (wal.good() && wal.tellg() < file_size) {
+        auto entry_start = wal.tellg();
+
+        // Read magic
+        uint32_t magic = 0;
+        wal.read(reinterpret_cast<char*>(&magic), 4);
+        if (magic != WAL_MAGIC) break;  // Corrupt or end of valid entries
+
+        // Read op
+        uint8_t op_byte = 0;
+        wal.read(reinterpret_cast<char*>(&op_byte), 1);
+        auto op = static_cast<WalOp>(op_byte);
+
+        // Read table name
+        uint32_t tlen = 0;
+        wal.read(reinterpret_cast<char*>(&tlen), 4);
+        std::string table_name(tlen, '\0');
+        wal.read(table_name.data(), tlen);
+
+        // Read row ID
+        uint64_t row_id = 0;
+        wal.read(reinterpret_cast<char*>(&row_id), 8);
+
+        // Read payload
+        uint32_t plen = 0;
+        wal.read(reinterpret_cast<char*>(&plen), 4);
+        std::vector<uint8_t> payload(plen);
+        if (plen > 0) {
+            wal.read(reinterpret_cast<char*>(payload.data()), plen);
+        }
+
+        // Read CRC
+        uint32_t stored_crc = 0;
+        wal.read(reinterpret_cast<char*>(&stored_crc), 4);
+
+        // Verify CRC (re-read the entry from start to before CRC)
+        auto entry_end = wal.tellg();
+        size_t entry_size = static_cast<size_t>(entry_end - entry_start) - 4;  // Exclude CRC
+        wal.seekg(entry_start);
+        std::vector<uint8_t> entry_data(entry_size);
+        wal.read(reinterpret_cast<char*>(entry_data.data()), entry_size);
+        wal.seekg(entry_end);  // Skip past CRC we already read
+
+        uint32_t computed_crc = crc32(entry_data.data(), entry_data.size());
+        if (computed_crc != stored_crc) {
+            break;  // Corrupt entry, stop replay
+        }
+
+        // Apply operation (best-effort: skip if table doesn't exist)
+        // The actual replay logic depends on the table having been loaded.
+        // For now, we just count replayed entries. Full replay requires
+        // deserializing the payload and calling table CRUD methods.
+        // TODO: Implement full row-level replay when table schema is available.
+        (void)op;
+        (void)row_id;
+        (void)table_name;
+
+        ++entries_replayed;
+    }
+
+    // After replay, truncate WAL
+    wal.close();
+    if (entries_replayed > 0) {
+        // Re-save state and clear WAL
+        std::ofstream truncate(path_, std::ios::binary | std::ios::trunc);
+    }
+}
+
+bool WriteAheadLog::has_entries() const {
+    if (!fs::exists(path_)) return false;
+    return fs::file_size(path_) > 0;
+}
+
+} // namespace wayy_db
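The inverse direction, mirroring the scan-and-verify loop in `replay()` above, with the same placeholder magic and endianness assumptions; like the C++ code, this sketch stops at the first entry whose CRC does not match:

```python
import struct
import zlib

WAL_MAGIC = 0x57414C31  # placeholder, as in the writer sketch earlier

def scan_wal(path: str):
    """Yield (op, table, row_id, payload) for each entry whose CRC checks out,
    stopping at the first corrupt or truncated record."""
    with open(path, "rb") as f:
        data = f.read()
    pos = 0
    while pos + 9 <= len(data):
        magic, op, tlen = struct.unpack_from("<IBI", data, pos)
        if magic != WAL_MAGIC:
            break
        end = pos + 9 + tlen + 12  # past name, row_id (8), plen (4)
        if end > len(data):
            break
        table = data[pos + 9 : pos + 9 + tlen].decode()
        row_id, plen = struct.unpack_from("<QI", data, pos + 9 + tlen)
        crc_off = end + plen
        if crc_off + 4 > len(data):
            break
        payload = data[end:crc_off]
        (stored_crc,) = struct.unpack_from("<I", data, crc_off)
        # CRC covers everything before the checksum itself.
        if zlib.crc32(data[pos:crc_off]) != stored_crc:
            break
        yield op, table, row_id, payload
        pos = crc_off + 4

# usage: for op, table, row, payload in scan_wal("wal.bin"): ...
```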
json={"name": "test_table"}) + assert response.status_code == 200 + assert response.json()["created"] == "test_table" + + # Verify exists + response = api_client.get("/tables/test_table") + assert response.status_code == 200 + + # Delete + response = api_client.delete("/tables/test_table") + assert response.status_code == 200 + + # Verify deleted + response = api_client.get("/tables/test_table") + assert response.status_code == 404 + + def test_upload_table(self, api_client): + """Test uploading a table with data.""" + data = { + "name": "uploaded_table", + "columns": [ + {"name": "timestamp", "dtype": "int64", "data": [1000, 2000, 3000]}, + {"name": "price", "dtype": "float64", "data": [100.0, 101.0, 102.0]}, + ], + "sorted_by": "timestamp", + } + + response = api_client.post("/tables/upload", json=data) + assert response.status_code == 200 + result = response.json() + assert result["created"] == "uploaded_table" + assert result["rows"] == 3 + + # Get data back + response = api_client.get("/tables/uploaded_table/data") + assert response.status_code == 200 + result = response.json() + assert result["total_rows"] == 3 + assert result["data"]["timestamp"] == [1000, 2000, 3000] + assert result["data"]["price"] == [100.0, 101.0, 102.0] + + # Cleanup + api_client.delete("/tables/uploaded_table") + + def test_get_table_info(self, api_client): + """Test getting table metadata.""" + # Upload a table + data = { + "name": "info_test", + "columns": [ + {"name": "ts", "dtype": "int64", "data": [1, 2, 3]}, + {"name": "val", "dtype": "float64", "data": [1.0, 2.0, 3.0]}, + ], + "sorted_by": "ts", + } + api_client.post("/tables/upload", json=data) + + # Get info + response = api_client.get("/tables/info_test") + assert response.status_code == 200 + info = response.json() + assert info["name"] == "info_test" + assert info["num_rows"] == 3 + assert info["num_columns"] == 2 + assert "ts" in info["columns"] + assert info["sorted_by"] == "ts" + + # Cleanup + api_client.delete("/tables/info_test") + + +class TestAppendAPI: + """Tests for the append endpoint.""" + + def test_append_to_table(self, api_client): + """Test appending rows to an existing table.""" + # Create initial table + data = { + "name": "append_test", + "columns": [ + {"name": "timestamp", "dtype": "int64", "data": [1000, 2000]}, + {"name": "price", "dtype": "float64", "data": [100.0, 101.0]}, + ], + "sorted_by": "timestamp", + } + api_client.post("/tables/upload", json=data) + + # Append more data + append_data = { + "columns": [ + {"name": "timestamp", "dtype": "int64", "data": [3000, 4000]}, + {"name": "price", "dtype": "float64", "data": [102.0, 103.0]}, + ] + } + response = api_client.post("/tables/append_test/append", json=append_data) + assert response.status_code == 200 + result = response.json() + assert result["new_rows"] == 2 + assert result["total_rows"] == 4 + + # Verify data + response = api_client.get("/tables/append_test/data") + data = response.json()["data"] + assert len(data["timestamp"]) == 4 + assert data["timestamp"] == [1000, 2000, 3000, 4000] + assert data["price"] == [100.0, 101.0, 102.0, 103.0] + + # Cleanup + api_client.delete("/tables/append_test") + + def test_append_column_mismatch(self, api_client): + """Test that append fails with mismatched columns.""" + # Create table + data = { + "name": "mismatch_test", + "columns": [ + {"name": "timestamp", "dtype": "int64", "data": [1000]}, + {"name": "price", "dtype": "float64", "data": [100.0]}, + ], + } + api_client.post("/tables/upload", json=data) + + # Try to append 
with wrong columns + append_data = { + "columns": [ + {"name": "timestamp", "dtype": "int64", "data": [2000]}, + {"name": "volume", "dtype": "float64", "data": [50.0]}, # Wrong column + ] + } + response = api_client.post("/tables/mismatch_test/append", json=append_data) + assert response.status_code == 400 + assert "mismatch" in response.json()["detail"].lower() + + # Cleanup + api_client.delete("/tables/mismatch_test") + + +class TestIngestAPI: + """Tests for streaming ingestion REST endpoints.""" + + def test_ingest_single_tick(self, api_client): + """Test ingesting a single tick via REST.""" + tick = { + "symbol": "BTC-USD", + "price": 42150.50, + "volume": 1.5, + } + response = api_client.post("/ingest/test_ticks", json=tick) + assert response.status_code == 200 + assert response.json()["ingested"] == 1 + + def test_ingest_batch(self, api_client): + """Test ingesting a batch of ticks via REST.""" + batch = { + "ticks": [ + {"symbol": "BTC-USD", "price": 42150.0}, + {"symbol": "ETH-USD", "price": 2250.0}, + {"symbol": "SOL-USD", "price": 100.0}, + ] + } + response = api_client.post("/ingest/test_ticks/batch", json=batch) + assert response.status_code == 200 + assert response.json()["ingested"] == 3 + + +class TestStreamingStats: + """Tests for streaming statistics endpoints.""" + + def test_streaming_stats(self, api_client): + """Test getting streaming statistics.""" + response = api_client.get("/streaming/stats") + assert response.status_code == 200 + stats = response.json() + assert "ticks_received" in stats + assert "ticks_flushed" in stats + assert "buffer_sizes" in stats + assert "running" in stats + + def test_get_quote(self, api_client): + """Test getting a specific quote.""" + # Ingest a tick first + api_client.post("/ingest/ticks", json={"symbol": "BTC-USD", "price": 42000.0}) + + # Get quote + response = api_client.get("/streaming/quote/ticks/BTC-USD") + assert response.status_code == 200 + quote = response.json() + assert quote["symbol"] == "BTC-USD" + assert quote["price"] == 42000.0 + + def test_get_all_quotes(self, api_client): + """Test getting all quotes for a table.""" + # Ingest multiple ticks + for symbol, price in [("BTC-USD", 42000.0), ("ETH-USD", 2200.0)]: + api_client.post("/ingest/quotes_test", json={"symbol": symbol, "price": price}) + + # Get all quotes + response = api_client.get("/streaming/quotes/quotes_test") + assert response.status_code == 200 + quotes = response.json() + assert "BTC-USD" in quotes + assert "ETH-USD" in quotes + + +class TestAggregations: + """Tests for aggregation endpoints.""" + + @pytest.fixture + def setup_table(self, api_client): + """Set up a table for aggregation tests.""" + data = { + "name": "agg_test", + "columns": [ + {"name": "timestamp", "dtype": "int64", "data": [1, 2, 3, 4, 5]}, + {"name": "price", "dtype": "float64", "data": [10.0, 20.0, 30.0, 40.0, 50.0]}, + ], + } + api_client.post("/tables/upload", json=data) + yield + api_client.delete("/tables/agg_test") + + def test_sum(self, api_client, setup_table): + """Test sum aggregation.""" + response = api_client.get("/tables/agg_test/agg/price/sum") + assert response.status_code == 200 + assert response.json()["result"] == pytest.approx(150.0) + + def test_avg(self, api_client, setup_table): + """Test average aggregation.""" + response = api_client.get("/tables/agg_test/agg/price/avg") + assert response.status_code == 200 + assert response.json()["result"] == pytest.approx(30.0) + + def test_min_max(self, api_client, setup_table): + """Test min/max aggregations.""" + 
response = api_client.get("/tables/agg_test/agg/price/min") + assert response.json()["result"] == pytest.approx(10.0) + + response = api_client.get("/tables/agg_test/agg/price/max") + assert response.json()["result"] == pytest.approx(50.0) + + +class TestWindowFunctions: + """Tests for window function endpoint.""" + + @pytest.fixture + def setup_table(self, api_client): + """Set up a table for window tests.""" + data = { + "name": "window_test", + "columns": [ + {"name": "timestamp", "dtype": "int64", "data": list(range(10))}, + {"name": "price", "dtype": "float64", "data": [float(i) for i in range(10)]}, + ], + } + api_client.post("/tables/upload", json=data) + yield + api_client.delete("/tables/window_test") + + def test_mavg(self, api_client, setup_table): + """Test moving average.""" + response = api_client.post("/window", json={ + "table": "window_test", + "column": "price", + "operation": "mavg", + "window": 3, + }) + assert response.status_code == 200 + result = response.json()["result"] + assert len(result) == 10 + # Third element should be avg of 0, 1, 2 = 1.0 + assert result[2] == pytest.approx(1.0) + + def test_ema(self, api_client, setup_table): + """Test exponential moving average.""" + response = api_client.post("/window", json={ + "table": "window_test", + "column": "price", + "operation": "ema", + "alpha": 0.5, + }) + assert response.status_code == 200 + result = response.json()["result"] + assert len(result) == 10 + assert result[0] == pytest.approx(0.0) # First value unchanged + + +# Cleanup test directory after all tests +def pytest_sessionfinish(session, exitstatus): + """Clean up test data directory.""" + if _test_data_path and os.path.exists(_test_data_path): + shutil.rmtree(_test_data_path, ignore_errors=True) diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c7d2e1a548400f1947b49046866eed92a27c13 --- /dev/null +++ b/tests/python/test_basic.py @@ -0,0 +1,179 @@ +"""Basic functionality tests for WayyDB Python bindings.""" + +import pytest +import numpy as np +import wayy_db as wdb + + +class TestTable: + """Tests for Table class.""" + + def test_create_empty_table(self): + table = wdb.Table("test") + assert table.name == "test" + assert table.num_rows == 0 + assert table.num_columns == 0 + assert len(table) == 0 + + def test_from_dict(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades", sorted_by="timestamp") + + assert table.name == "trades" + assert table.num_rows == 5 + assert table.num_columns == 4 + assert table.sorted_by == "timestamp" + + def test_column_access(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades") + + assert table.has_column("price") + assert not table.has_column("nonexistent") + + price_col = table["price"] + assert price_col.name == "price" + assert price_col.dtype == wdb.DType.Float64 + assert len(price_col) == 5 + + def test_to_numpy_zero_copy(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades") + + prices = table["price"].to_numpy() + + assert isinstance(prices, np.ndarray) + assert prices.dtype == np.float64 + assert len(prices) == 5 + np.testing.assert_array_equal(prices, sample_trades["price"]) + + def test_to_dict(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades") + + result = table.to_dict() + + assert set(result.keys()) == {"timestamp", "symbol", "price", "size"} + np.testing.assert_array_equal(result["price"], sample_trades["price"]) + + def 
test_column_names(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades") + + names = table.column_names() + + assert set(names) == {"timestamp", "symbol", "price", "size"} + + +class TestDatabase: + """Tests for Database class.""" + + def test_in_memory_database(self): + db = wdb.Database() + + assert not db.is_persistent + assert db.tables() == [] + + def test_create_table(self): + db = wdb.Database() + table = db.create_table("trades") + + assert db.has_table("trades") + assert "trades" in db.tables() + + def test_persistent_database(self, temp_dir, sample_trades): + # Create and populate + db = wdb.Database(temp_dir) + table = db.create_table("trades") + + for name, data in sample_trades.items(): + dtype = { + np.dtype("int64"): wdb.DType.Int64, + np.dtype("float64"): wdb.DType.Float64, + np.dtype("uint32"): wdb.DType.Symbol, + }[data.dtype] + table.add_column_from_numpy(name, data, dtype) + + table.set_sorted_by("timestamp") + db.save() + + # Reload and verify + db2 = wdb.Database(temp_dir) + assert db2.has_table("trades") + + loaded = db2["trades"] + assert loaded.num_rows == 5 + assert loaded.sorted_by == "timestamp" + + +class TestOperations: + """Tests for operations module.""" + + def test_aggregations(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades") + price_col = table["price"] + + assert wdb.ops.sum(price_col) == pytest.approx(1214.0) + assert wdb.ops.avg(price_col) == pytest.approx(242.8) + assert wdb.ops.min(price_col) == pytest.approx(150.0) + assert wdb.ops.max(price_col) == pytest.approx(381.0) + + def test_window_functions(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades") + price_col = table["price"] + + mavg = wdb.ops.mavg(price_col, 2) + assert len(mavg) == 5 + assert mavg[1] == pytest.approx((150.0 + 380.0) / 2) + + msum = wdb.ops.msum(price_col, 2) + assert len(msum) == 5 + + def test_ema(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades") + price_col = table["price"] + + ema = wdb.ops.ema(price_col, 0.5) + assert len(ema) == 5 + assert ema[0] == pytest.approx(150.0) # First value unchanged + + def test_diff(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades") + price_col = table["price"] + + diff = wdb.ops.diff(price_col, 1) + assert len(diff) == 5 + assert diff[1] == pytest.approx(380.0 - 150.0) + + +class TestAsOfJoin: + """Tests for as-of join operation.""" + + def test_aj_basic(self, sample_trades, sample_quotes): + trades = wdb.from_dict(sample_trades, name="trades", sorted_by="timestamp") + quotes = wdb.from_dict(sample_quotes, name="quotes", sorted_by="timestamp") + + result = wdb.ops.aj(trades, quotes, on=["symbol"], as_of="timestamp") + + assert result.num_rows == 5 + assert result.has_column("bid") + assert result.has_column("ask") + assert result.has_column("price") + + def test_aj_requires_sorted(self, sample_trades, sample_quotes): + trades = wdb.from_dict(sample_trades, name="trades") # Not sorted + quotes = wdb.from_dict(sample_quotes, name="quotes", sorted_by="timestamp") + + with pytest.raises(wdb.InvalidOperation): + wdb.ops.aj(trades, quotes, on=["symbol"], as_of="timestamp") + + +class TestExceptions: + """Tests for exception handling.""" + + def test_column_not_found(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades") + + with pytest.raises(wdb.ColumnNotFound): + _ = table["nonexistent"] + + def test_invalid_operation(self, sample_trades): + table = wdb.from_dict(sample_trades, name="trades") + + with 
pytest.raises(wdb.ColumnNotFound): + table.set_sorted_by("nonexistent") diff --git a/tests/python/test_pubsub.py b/tests/python/test_pubsub.py new file mode 100644 index 0000000000000000000000000000000000000000..5b3e25d61437679cffe92f52a4189a9a72aaf6b8 --- /dev/null +++ b/tests/python/test_pubsub.py @@ -0,0 +1,408 @@ +"""Tests for WayyDB PubSub abstraction layer.""" + +import asyncio +import time + +import pytest + +import sys +sys.path.insert(0, str(__file__).replace("/tests/python/test_pubsub.py", "")) +from api.pubsub import InMemoryPubSub, RedisPubSub, create_pubsub, Message + + +class TestInMemoryPubSub: + """Tests for InMemoryPubSub backend.""" + + @pytest.fixture + def pubsub(self): + return InMemoryPubSub(max_buffer_per_channel=100) + + @pytest.mark.asyncio + async def test_start_stop(self, pubsub): + await pubsub.start() + assert pubsub._running + await pubsub.stop() + assert not pubsub._running + + @pytest.mark.asyncio + async def test_publish_returns_sequence(self, pubsub): + await pubsub.start() + seq1 = await pubsub.publish("ticks:AAPL", {"price": 150.0}) + seq2 = await pubsub.publish("ticks:AAPL", {"price": 151.0}) + assert seq1 == 1 + assert seq2 == 2 + await pubsub.stop() + + @pytest.mark.asyncio + async def test_sequence_per_channel(self, pubsub): + await pubsub.start() + seq_a = await pubsub.publish("ticks:AAPL", {"price": 150.0}) + seq_m = await pubsub.publish("ticks:MSFT", {"price": 380.0}) + assert seq_a == 1 + assert seq_m == 1 # Separate sequence per channel + await pubsub.stop() + + @pytest.mark.asyncio + async def test_subscribe_receives_messages(self, pubsub): + await pubsub.start() + received = [] + + async def callback(msg): + received.append(msg) + + await pubsub.subscribe("ticks:AAPL", callback, "test_sub") + await pubsub.publish("ticks:AAPL", {"price": 150.0}) + await pubsub.publish("ticks:AAPL", {"price": 151.0}) + + assert len(received) == 2 + assert received[0]["price"] == 150.0 + assert received[1]["price"] == 151.0 + # Messages include metadata + assert received[0]["_seq"] == 1 + assert received[0]["_channel"] == "ticks:AAPL" + await pubsub.stop() + + @pytest.mark.asyncio + async def test_subscribe_wildcard(self, pubsub): + await pubsub.start() + received = [] + + async def callback(msg): + received.append(msg) + + await pubsub.subscribe("ticks:*", callback, "wildcard_sub") + await pubsub.publish("ticks:AAPL", {"symbol": "AAPL"}) + await pubsub.publish("ticks:MSFT", {"symbol": "MSFT"}) + await pubsub.publish("quotes:AAPL", {"symbol": "AAPL"}) # Not ticks:* + + assert len(received) == 2 + assert received[0]["symbol"] == "AAPL" + assert received[1]["symbol"] == "MSFT" + await pubsub.stop() + + @pytest.mark.asyncio + async def test_unsubscribe(self, pubsub): + await pubsub.start() + received = [] + + async def callback(msg): + received.append(msg) + + await pubsub.subscribe("ticks:AAPL", callback, "test_sub") + await pubsub.publish("ticks:AAPL", {"price": 150.0}) + assert len(received) == 1 + + await pubsub.unsubscribe("ticks:AAPL", "test_sub") + await pubsub.publish("ticks:AAPL", {"price": 151.0}) + assert len(received) == 1 # No new message + await pubsub.stop() + + @pytest.mark.asyncio + async def test_multiple_subscribers(self, pubsub): + await pubsub.start() + received_a = [] + received_b = [] + + async def cb_a(msg): + received_a.append(msg) + + async def cb_b(msg): + received_b.append(msg) + + await pubsub.subscribe("ticks:AAPL", cb_a, "sub_a") + await pubsub.subscribe("ticks:AAPL", cb_b, "sub_b") + await pubsub.publish("ticks:AAPL", {"price": 
150.0}) + + assert len(received_a) == 1 + assert len(received_b) == 1 + await pubsub.stop() + + @pytest.mark.asyncio + async def test_backpressure_buffer_overflow(self): + pubsub = InMemoryPubSub(max_buffer_per_channel=5) + await pubsub.start() + + for i in range(10): + await pubsub.publish("ticks:AAPL", {"price": float(i)}) + + # Buffer should only keep last 5 + buf = pubsub.get_channel_buffer("ticks:AAPL") + assert len(buf) == 5 + assert buf[0].data["price"] == 5.0 # Oldest kept + assert buf[-1].data["price"] == 9.0 # Newest + + stats = pubsub.get_stats() + assert stats["messages_dropped"] == 5 + await pubsub.stop() + + @pytest.mark.asyncio + async def test_channel_buffer_replay(self, pubsub): + await pubsub.start() + + for i in range(5): + await pubsub.publish("ticks:AAPL", {"price": float(i)}) + + # Replay from sequence 0 (all messages) + all_msgs = pubsub.get_channel_buffer("ticks:AAPL", since_seq=0) + assert len(all_msgs) == 5 + + # Replay from sequence 3 (only seq 4 and 5) + recent = pubsub.get_channel_buffer("ticks:AAPL", since_seq=3) + assert len(recent) == 2 + assert recent[0].sequence == 4 + await pubsub.stop() + + @pytest.mark.asyncio + async def test_dead_subscriber_removal(self, pubsub): + await pubsub.start() + + async def bad_callback(msg): + raise ConnectionError("WebSocket closed") + + received_good = [] + + async def good_callback(msg): + received_good.append(msg) + + await pubsub.subscribe("ticks:AAPL", bad_callback, "bad_sub") + await pubsub.subscribe("ticks:AAPL", good_callback, "good_sub") + + await pubsub.publish("ticks:AAPL", {"price": 150.0}) + + # Bad subscriber should be removed, good one still works + assert len(received_good) == 1 + stats = pubsub.get_stats() + assert stats["active_subscriptions"] == 1 + await pubsub.stop() + + @pytest.mark.asyncio + async def test_publish_batch(self, pubsub): + await pubsub.start() + received = [] + + async def callback(msg): + received.append(msg) + + await pubsub.subscribe("ticks:AAPL", callback, "batch_sub") + messages = [{"price": float(i)} for i in range(5)] + last_seq = await pubsub.publish_batch("ticks:AAPL", messages) + + assert last_seq == 5 + assert len(received) == 5 + await pubsub.stop() + + @pytest.mark.asyncio + async def test_stats(self, pubsub): + await pubsub.start() + + async def noop(msg): + pass + + await pubsub.subscribe("ticks:AAPL", noop, "s1") + await pubsub.subscribe("ticks:MSFT", noop, "s2") + await pubsub.publish("ticks:AAPL", {"price": 150.0}) + await pubsub.publish("ticks:MSFT", {"price": 380.0}) + + stats = pubsub.get_stats() + assert stats["backend"] == "in_memory" + assert stats["messages_published"] == 2 + assert stats["messages_delivered"] == 2 + assert stats["active_subscriptions"] == 2 + assert stats["channels"] == 2 + await pubsub.stop() + + +class TestCreatePubSub: + """Tests for the factory function.""" + + def test_creates_inmemory_by_default(self): + ps = create_pubsub(None) + assert isinstance(ps, InMemoryPubSub) + + def test_creates_inmemory_for_empty_string(self): + ps = create_pubsub("") + assert isinstance(ps, InMemoryPubSub) + + def test_creates_redis_with_url(self): + ps = create_pubsub("redis://localhost:6379") + assert isinstance(ps, RedisPubSub) + + +class TestPubSubPerformance: + """Performance tests for InMemoryPubSub.""" + + @pytest.mark.asyncio + async def test_publish_throughput(self): + """Test raw publish throughput without subscribers.""" + pubsub = InMemoryPubSub(max_buffer_per_channel=100000) + await pubsub.start() + + num_msgs = 50000 + start = time.time() + + for 
i in range(num_msgs): + await pubsub.publish(f"ticks:SYM-{i % 100}", {"price": float(i)}) + + elapsed = time.time() - start + rate = num_msgs / elapsed + + print(f"\nPublish throughput (no subscribers): {rate:.0f} msgs/sec") + assert rate > 50000 # Should handle at least 50K msgs/sec + + await pubsub.stop() + + @pytest.mark.asyncio + async def test_publish_with_subscribers_throughput(self): + """Test publish throughput with active subscribers.""" + pubsub = InMemoryPubSub(max_buffer_per_channel=100000) + await pubsub.start() + + counter = {"count": 0} + + async def counting_callback(msg): + counter["count"] += 1 + + # Subscribe to 10 channels + for i in range(10): + await pubsub.subscribe(f"ticks:SYM-{i}", counting_callback, f"sub_{i}") + + num_msgs = 10000 + start = time.time() + + for i in range(num_msgs): + await pubsub.publish(f"ticks:SYM-{i % 10}", {"price": float(i)}) + + elapsed = time.time() - start + rate = num_msgs / elapsed + + print(f"\nPublish throughput (10 subscribers): {rate:.0f} msgs/sec") + print(f"Messages delivered: {counter['count']}") + assert rate > 5000 + assert counter["count"] == num_msgs + + await pubsub.stop() + + @pytest.mark.asyncio + async def test_many_channels(self): + """Test performance with many channels.""" + pubsub = InMemoryPubSub(max_buffer_per_channel=1000) + await pubsub.start() + + num_channels = 1000 + msgs_per_channel = 10 + + start = time.time() + for ch in range(num_channels): + for m in range(msgs_per_channel): + await pubsub.publish(f"ticks:SYM-{ch}", {"price": float(m)}) + + elapsed = time.time() - start + total = num_channels * msgs_per_channel + rate = total / elapsed + + print(f"\nMany channels throughput: {rate:.0f} msgs/sec ({num_channels} channels)") + + stats = pubsub.get_stats() + assert stats["channels"] == num_channels + assert stats["messages_published"] == total + + await pubsub.stop() + + @pytest.mark.asyncio + async def test_wildcard_subscriber_performance(self): + """Test wildcard subscriber doesn't destroy performance.""" + pubsub = InMemoryPubSub(max_buffer_per_channel=100000) + await pubsub.start() + + counter = {"count": 0} + + async def callback(msg): + counter["count"] += 1 + + # Wildcard subscriber for all ticks + await pubsub.subscribe("ticks:*", callback, "wildcard") + + num_msgs = 10000 + start = time.time() + + for i in range(num_msgs): + await pubsub.publish(f"ticks:SYM-{i % 50}", {"price": float(i)}) + + elapsed = time.time() - start + rate = num_msgs / elapsed + + print(f"\nWildcard subscriber throughput: {rate:.0f} msgs/sec") + assert counter["count"] == num_msgs + assert rate > 3000 + + await pubsub.stop() + + +class TestPubSubStress: + """Stress tests for PubSub under adverse conditions.""" + + @pytest.mark.asyncio + async def test_concurrent_publish_subscribe(self): + """Test concurrent publishers and subscribers.""" + pubsub = InMemoryPubSub(max_buffer_per_channel=10000) + await pubsub.start() + + results = {"total_received": 0} + + async def subscriber_cb(msg): + results["total_received"] += 1 + + # Register 5 subscribers on different channels + for i in range(5): + await pubsub.subscribe(f"ticks:SYM-{i}", subscriber_cb, f"sub_{i}") + + # Concurrent publishers + async def publisher(channel_idx, count): + for j in range(count): + await pubsub.publish(f"ticks:SYM-{channel_idx}", {"price": float(j)}) + + tasks = [publisher(i, 200) for i in range(5)] + await asyncio.gather(*tasks) + + assert results["total_received"] == 1000 # 5 channels x 200 msgs + await pubsub.stop() + + @pytest.mark.asyncio + async def 
test_subscribe_unsubscribe_churn(self): + """Test rapid subscribe/unsubscribe cycles.""" + pubsub = InMemoryPubSub() + await pubsub.start() + + async def noop(msg): + pass + + for cycle in range(100): + sub_id = f"churn_{cycle}" + await pubsub.subscribe("ticks:AAPL", noop, sub_id) + await pubsub.publish("ticks:AAPL", {"cycle": cycle}) + await pubsub.unsubscribe("ticks:AAPL", sub_id) + + stats = pubsub.get_stats() + assert stats["active_subscriptions"] == 0 + assert stats["messages_published"] == 100 + await pubsub.stop() + + @pytest.mark.asyncio + async def test_buffer_overflow_under_load(self): + """Test backpressure behavior under high load.""" + pubsub = InMemoryPubSub(max_buffer_per_channel=100) + await pubsub.start() + + # Publish 1000 messages to a channel with buffer size 100 + for i in range(1000): + await pubsub.publish("ticks:AAPL", {"price": float(i)}) + + stats = pubsub.get_stats() + buf = pubsub.get_channel_buffer("ticks:AAPL") + + assert len(buf) == 100 # Buffer capped + assert stats["messages_published"] == 1000 + assert stats["messages_dropped"] == 900 + # Buffer should have the latest 100 messages + assert buf[-1].data["price"] == 999.0 + await pubsub.stop() diff --git a/tests/python/test_streaming.py b/tests/python/test_streaming.py new file mode 100644 index 0000000000000000000000000000000000000000..e07ca41767ed9dd34c97805a0e6312f32a6fb646 --- /dev/null +++ b/tests/python/test_streaming.py @@ -0,0 +1,553 @@ +"""Tests for WayyDB streaming functionality with PubSub integration.""" + +import asyncio +import time +from datetime import datetime, timezone + +import numpy as np +import pytest + +import wayy_db as wdb + +import sys +sys.path.insert(0, str(__file__).replace("/tests/python/test_streaming.py", "")) +from api.pubsub import InMemoryPubSub +from api.streaming import StreamingManager, TickBuffer + + +class TestTickBuffer: + """Tests for TickBuffer data structure.""" + + def test_empty_buffer(self): + buffer = TickBuffer() + assert len(buffer) == 0 + + def test_append_single_tick(self): + buffer = TickBuffer() + buffer.append( + timestamp=1704067200000000000, + symbol="BTC-USD", + price=42150.50, + volume=1.5, + bid=42150.00, + ask=42151.00, + ) + assert len(buffer) == 1 + assert buffer.timestamps[0] == 1704067200000000000 + assert buffer.symbols[0] == "BTC-USD" + assert buffer.prices[0] == 42150.50 + + def test_append_multiple_ticks(self): + buffer = TickBuffer() + for i in range(100): + buffer.append( + timestamp=1704067200000000000 + i * 1000000, + symbol=f"SYM-{i % 5}", + price=100.0 + i * 0.1, + volume=float(i), + ) + assert len(buffer) == 100 + + def test_clear(self): + buffer = TickBuffer() + for i in range(10): + buffer.append( + timestamp=i, + symbol="BTC-USD", + price=100.0, + ) + assert len(buffer) == 10 + buffer.clear() + assert len(buffer) == 0 + + def test_to_columnar(self): + buffer = TickBuffer() + buffer.append(1000, "BTC-USD", 42150.0, 1.0, 42149.0, 42151.0) + buffer.append(2000, "ETH-USD", 2250.0, 10.0, 2249.0, 2251.0) + + data = buffer.to_columnar() + + assert "timestamp" in data + assert "symbol" in data + assert "price" in data + assert "volume" in data + assert "bid" in data + assert "ask" in data + + assert data["timestamp"].dtype == np.int64 + assert data["symbol"].dtype == np.uint32 + assert data["price"].dtype == np.float64 + + assert len(data["timestamp"]) == 2 + np.testing.assert_array_equal(data["timestamp"], [1000, 2000]) + np.testing.assert_array_equal(data["price"], [42150.0, 2250.0]) + + +class TestStreamingManager: + """Tests for 
StreamingManager with PubSub integration.""" + + @pytest.fixture + def pubsub(self): + return InMemoryPubSub(max_buffer_per_channel=10000) + + @pytest.fixture + def manager(self, pubsub): + """Create a streaming manager with in-memory pubsub.""" + manager = StreamingManager( + flush_interval=0.1, + max_buffer_size=100, + batch_broadcast_interval=0.01, + pubsub=pubsub, + ) + return manager + + @pytest.fixture + def temp_db(self, temp_dir): + """Create a temporary database for testing.""" + return wdb.Database(temp_dir) + + @pytest.mark.asyncio + async def test_ingest_single_tick(self, manager): + """Test ingesting a single tick.""" + await manager.start() + await manager.ingest_tick( + table="ticks", + symbol="BTC-USD", + price=42150.50, + volume=1.5, + ) + + stats = manager.get_stats() + assert stats["ticks_received"] == 1 + assert stats["buffer_sizes"]["ticks"] == 1 + await manager.stop() + + @pytest.mark.asyncio + async def test_ingest_publishes_to_pubsub(self, manager): + """Test that ingestion publishes to PubSub channels.""" + await manager.start() + + received = [] + + async def on_tick(msg): + received.append(msg) + + await manager._pubsub.subscribe("ticks:BTC-USD", on_tick, "test") + + await manager.ingest_tick( + table="ticks", + symbol="BTC-USD", + price=42150.50, + ) + + assert len(received) == 1 + assert received[0]["price"] == 42150.50 + assert received[0]["_channel"] == "ticks:BTC-USD" + assert received[0]["_seq"] == 1 + await manager.stop() + + @pytest.mark.asyncio + async def test_ingest_batch(self, manager): + """Test ingesting a batch of ticks.""" + await manager.start() + ticks = [ + {"symbol": "BTC-USD", "price": 42150.0, "volume": 1.0}, + {"symbol": "ETH-USD", "price": 2250.0, "volume": 10.0}, + {"symbol": "SOL-USD", "price": 100.0, "volume": 100.0}, + ] + await manager.ingest_batch(table="ticks", ticks=ticks) + + stats = manager.get_stats() + assert stats["ticks_received"] == 3 + assert stats["buffer_sizes"]["ticks"] == 3 + await manager.stop() + + @pytest.mark.asyncio + async def test_batch_publishes_to_channels(self, manager): + """Test that batch ingestion publishes to per-symbol channels.""" + await manager.start() + + btc_received = [] + eth_received = [] + + async def on_btc(msg): + btc_received.append(msg) + + async def on_eth(msg): + eth_received.append(msg) + + await manager._pubsub.subscribe("ticks:BTC-USD", on_btc, "btc_sub") + await manager._pubsub.subscribe("ticks:ETH-USD", on_eth, "eth_sub") + + ticks = [ + {"symbol": "BTC-USD", "price": 42150.0}, + {"symbol": "ETH-USD", "price": 2250.0}, + {"symbol": "BTC-USD", "price": 42160.0}, + ] + await manager.ingest_batch(table="ticks", ticks=ticks) + + assert len(btc_received) == 2 + assert len(eth_received) == 1 + assert btc_received[0]["price"] == 42150.0 + assert btc_received[1]["price"] == 42160.0 + await manager.stop() + + @pytest.mark.asyncio + async def test_latest_quotes(self, manager): + """Test that latest quotes are cached.""" + await manager.start() + await manager.ingest_tick(table="ticks", symbol="BTC-USD", price=42150.50) + await manager.ingest_tick(table="ticks", symbol="BTC-USD", price=42200.00) + + quote = manager.get_latest_quote("ticks", "BTC-USD") + assert quote is not None + assert quote["price"] == 42200.00 + await manager.stop() + + @pytest.mark.asyncio + async def test_get_all_quotes(self, manager): + """Test getting all quotes for a table.""" + await manager.start() + await manager.ingest_tick(table="ticks", symbol="BTC-USD", price=42150.0) + await manager.ingest_tick(table="ticks", 
symbol="ETH-USD", price=2250.0) + await manager.ingest_tick(table="ticks", symbol="SOL-USD", price=100.0) + + quotes = manager.get_all_quotes("ticks") + assert len(quotes) == 3 + assert "BTC-USD" in quotes + assert "ETH-USD" in quotes + assert "SOL-USD" in quotes + await manager.stop() + + @pytest.mark.asyncio + async def test_flush_to_database(self, manager, temp_db): + """Test flushing buffered data to database.""" + manager.set_database(temp_db) + await manager.start() + + for i in range(10): + await manager.ingest_tick( + table="test_ticks", + symbol="BTC-USD", + price=42000.0 + i, + timestamp=1704067200000000000 + i * 1000000000, + ) + + await manager._flush_table("test_ticks") + + assert temp_db.has_table("test_ticks") + table = temp_db["test_ticks"] + assert table.num_rows == 10 + + prices = table["price"].to_numpy() + assert prices[0] == pytest.approx(42000.0) + assert prices[9] == pytest.approx(42009.0) + await manager.stop() + + @pytest.mark.asyncio + async def test_append_to_existing_table(self, manager, temp_db): + """Test appending to an existing table.""" + manager.set_database(temp_db) + await manager.start() + + for i in range(5): + await manager.ingest_tick( + table="test_ticks", + symbol="BTC-USD", + price=42000.0 + i, + timestamp=1704067200000000000 + i * 1000000000, + ) + await manager._flush_table("test_ticks") + + for i in range(5, 10): + await manager.ingest_tick( + table="test_ticks", + symbol="BTC-USD", + price=42000.0 + i, + timestamp=1704067200000000000 + i * 1000000000, + ) + await manager._flush_table("test_ticks") + + table = temp_db["test_ticks"] + assert table.num_rows == 10 + + prices = table["price"].to_numpy() + for i in range(10): + assert prices[i] == pytest.approx(42000.0 + i) + await manager.stop() + + @pytest.mark.asyncio + async def test_auto_flush_on_buffer_full(self, manager, temp_db): + """Test that buffer auto-flushes when full.""" + manager.set_database(temp_db) + manager.max_buffer_size = 50 + await manager.start() + + for i in range(60): + await manager.ingest_tick( + table="test_ticks", + symbol="BTC-USD", + price=42000.0 + i, + timestamp=1704067200000000000 + i * 1000000000, + ) + + stats = manager.get_stats() + assert stats["ticks_flushed"] >= 50 + await manager.stop() + + @pytest.mark.asyncio + async def test_start_stop(self, manager, temp_db): + """Test starting and stopping the streaming manager.""" + manager.set_database(temp_db) + + await manager.start() + assert manager._running + + await manager.ingest_tick(table="ticks", symbol="BTC-USD", price=42150.0) + + await manager.stop() + assert not manager._running + + @pytest.mark.asyncio + async def test_stats_includes_pubsub(self, manager): + """Test that stats include pubsub info.""" + await manager.start() + + for i in range(10): + await manager.ingest_tick( + table="ticks", + symbol=f"SYM-{i % 3}", + price=100.0 + i, + ) + + stats = manager.get_stats() + + assert stats["ticks_received"] == 10 + assert "buffer_sizes" in stats + assert "subscriber_counts" in stats + assert stats["latest_quotes"] == 3 + assert "pubsub" in stats + assert stats["pubsub"]["backend"] == "in_memory" + assert stats["pubsub"]["messages_published"] == 10 + await manager.stop() + + +class TestStreamingManagerNoPubSub: + """Tests for StreamingManager without PubSub (backward compat).""" + + @pytest.fixture + def manager(self): + return StreamingManager( + flush_interval=0.1, + max_buffer_size=100, + batch_broadcast_interval=0.01, + pubsub=None, + ) + + @pytest.mark.asyncio + async def 
test_works_without_pubsub(self, manager): + """Test that streaming still works without a PubSub backend.""" + await manager.start() + + await manager.ingest_tick(table="ticks", symbol="BTC-USD", price=42150.0) + + stats = manager.get_stats() + assert stats["ticks_received"] == 1 + assert "pubsub" not in stats + + await manager.stop() + + +class TestStreamingPerformance: + """Performance tests for streaming with PubSub.""" + + @pytest.mark.asyncio + async def test_high_throughput_ingestion(self): + """Test high-throughput tick ingestion.""" + pubsub = InMemoryPubSub(max_buffer_per_channel=100000) + manager = StreamingManager( + flush_interval=10.0, + max_buffer_size=100000, + pubsub=pubsub, + ) + await manager.start() + + num_ticks = 10000 + start = time.time() + + for i in range(num_ticks): + await manager.ingest_tick( + table="benchmark", + symbol=f"SYM-{i % 100}", + price=100.0 + (i % 1000) * 0.01, + timestamp=1704067200000000000 + i * 1000, + ) + + elapsed = time.time() - start + ticks_per_second = num_ticks / elapsed + + print(f"\nIngestion w/ PubSub: {ticks_per_second:.0f} ticks/second") + print(f" Total ticks: {num_ticks}") + print(f" Elapsed: {elapsed:.3f}s") + + assert ticks_per_second > 5000 + + stats = manager.get_stats() + assert stats["ticks_received"] == num_ticks + assert stats["pubsub"]["messages_published"] == num_ticks + await manager.stop() + + @pytest.mark.asyncio + async def test_batch_vs_single_ingestion(self): + """Compare batch vs single tick ingestion performance.""" + pubsub = InMemoryPubSub(max_buffer_per_channel=100000) + manager = StreamingManager( + flush_interval=10.0, + max_buffer_size=100000, + pubsub=pubsub, + ) + await manager.start() + + num_ticks = 5000 + + # Single tick ingestion + start = time.time() + for i in range(num_ticks): + await manager.ingest_tick( + table="single", + symbol="BTC-USD", + price=42000.0 + i * 0.01, + ) + single_elapsed = time.time() - start + + # Batch ingestion + batch_size = 100 + batches = num_ticks // batch_size + start = time.time() + for b in range(batches): + ticks = [ + {"symbol": "BTC-USD", "price": 42000.0 + (b * batch_size + i) * 0.01} + for i in range(batch_size) + ] + await manager.ingest_batch(table="batch", ticks=ticks) + batch_elapsed = time.time() - start + + print(f"\nSingle ingestion: {num_ticks / single_elapsed:.0f} ticks/second") + print(f"Batch ingestion: {num_ticks / batch_elapsed:.0f} ticks/second") + print(f"Batch speedup: {single_elapsed / batch_elapsed:.1f}x") + + assert batch_elapsed <= single_elapsed * 1.5 + await manager.stop() + + +class TestMultipleSymbols: + """Tests for multi-symbol streaming.""" + + @pytest.mark.asyncio + async def test_multiple_symbols_separate_quotes(self): + pubsub = InMemoryPubSub() + manager = StreamingManager(pubsub=pubsub) + await manager.start() + + symbols = ["BTC-USD", "ETH-USD", "SOL-USD", "ADA-USD", "DOT-USD"] + + for i, symbol in enumerate(symbols): + await manager.ingest_tick( + table="ticks", + symbol=symbol, + price=1000.0 * (i + 1), + ) + + for i, symbol in enumerate(symbols): + quote = manager.get_latest_quote("ticks", symbol) + assert quote is not None + assert quote["price"] == pytest.approx(1000.0 * (i + 1)) + + await manager.stop() + + @pytest.mark.asyncio + async def test_pubsub_channels_per_symbol(self): + """Test that each symbol gets its own PubSub channel.""" + pubsub = InMemoryPubSub() + manager = StreamingManager(pubsub=pubsub) + await manager.start() + + btc_msgs = [] + eth_msgs = [] + + async def on_btc(msg): + btc_msgs.append(msg) + + async def 
on_eth(msg):
+            eth_msgs.append(msg)
+
+        await pubsub.subscribe("ticks:BTC-USD", on_btc, "btc")
+        await pubsub.subscribe("ticks:ETH-USD", on_eth, "eth")
+
+        await manager.ingest_tick(table="ticks", symbol="BTC-USD", price=42000.0)
+        await manager.ingest_tick(table="ticks", symbol="ETH-USD", price=2200.0)
+        await manager.ingest_tick(table="ticks", symbol="BTC-USD", price=42100.0)
+
+        assert len(btc_msgs) == 2
+        assert len(eth_msgs) == 1
+        await manager.stop()
+
+
+class TestTimestamps:
+    """Tests for timestamp handling."""
+
+    @pytest.mark.asyncio
+    async def test_auto_timestamp(self):
+        pubsub = InMemoryPubSub()
+        manager = StreamingManager(pubsub=pubsub)
+        await manager.start()
+
+        before = int(datetime.now(timezone.utc).timestamp() * 1e9)
+        await manager.ingest_tick(table="ticks", symbol="BTC-USD", price=42000.0)
+        after = int(datetime.now(timezone.utc).timestamp() * 1e9)
+
+        quote = manager.get_latest_quote("ticks", "BTC-USD")
+        assert before <= quote["timestamp"] <= after
+        await manager.stop()
+
+    @pytest.mark.asyncio
+    async def test_explicit_timestamp(self):
+        pubsub = InMemoryPubSub()
+        manager = StreamingManager(pubsub=pubsub)
+        await manager.start()
+
+        ts = 1704067200000000000
+        await manager.ingest_tick(
+            table="ticks",
+            symbol="BTC-USD",
+            price=42000.0,
+            timestamp=ts,
+        )
+
+        quote = manager.get_latest_quote("ticks", "BTC-USD")
+        assert quote["timestamp"] == ts
+        await manager.stop()
+
+    @pytest.mark.asyncio
+    async def test_timestamp_ordering_in_flush(self, temp_dir):
+        pubsub = InMemoryPubSub()
+        manager = StreamingManager(pubsub=pubsub)
+        db = wdb.Database(temp_dir)
+        manager.set_database(db)
+        await manager.start()
+
+        timestamps = [3000, 1000, 4000, 2000, 5000]
+        for ts in timestamps:
+            await manager.ingest_tick(
+                table="ticks",
+                symbol="BTC-USD",
+                price=42000.0,
+                timestamp=ts,
+            )
+
+        await manager._flush_table("ticks")
+
+        table = db["ticks"]
+        stored_ts = table["timestamp"].to_numpy()
+        np.testing.assert_array_equal(stored_ts, timestamps)
+        await manager.stop()
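The fixtures and `TickBuffer.to_columnar()` above represent symbols as `uint32` codes (matching the 4-byte `DType::Symbol`), which implies some string-to-code interning on the ingest path. The actual mechanism is not part of this diff, so the following is only an illustrative sketch; the class and method names are hypothetical, not WayyDB API:

```python
import numpy as np

class SymbolInterner:
    """Hypothetical string-to-uint32 interning table (illustration only)."""

    def __init__(self):
        self._codes: dict[str, int] = {}
        self._names: list[str] = []

    def intern(self, symbol: str) -> int:
        # Assign the next code on first sight, reuse it afterwards.
        code = self._codes.get(symbol)
        if code is None:
            code = len(self._names)
            self._codes[symbol] = code
            self._names.append(symbol)
        return code

    def decode(self, code: int) -> str:
        return self._names[code]

interner = SymbolInterner()
codes = np.array([interner.intern(s) for s in ["AAPL", "MSFT", "AAPL"]],
                 dtype=np.uint32)
print(codes, interner.decode(1))  # [0 1 0] MSFT
```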
diff --git a/tests/test_column.cpp b/tests/test_column.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..651db29ae18abd9506fee13261c6f725f951f9ea
--- /dev/null
+++ b/tests/test_column.cpp
@@ -0,0 +1,80 @@
+#include <gtest/gtest.h>
+#include "wayy_db/column.hpp"
+
+using namespace wayy_db;
+
+TEST(ColumnViewTest, BasicOperations) {
+    std::vector<double> data = {1.0, 2.0, 3.0, 4.0, 5.0};
+    ColumnView<double> view(data.data(), data.size());
+
+    EXPECT_EQ(view.size(), 5);
+    EXPECT_FALSE(view.empty());
+    EXPECT_EQ(view[0], 1.0);
+    EXPECT_EQ(view[4], 5.0);
+    EXPECT_EQ(view.front(), 1.0);
+    EXPECT_EQ(view.back(), 5.0);
+}
+
+TEST(ColumnViewTest, Iteration) {
+    std::vector<int64_t> data = {10, 20, 30};
+    ColumnView<int64_t> view(data.data(), data.size());
+
+    int64_t sum = 0;
+    for (auto val : view) {
+        sum += val;
+    }
+    EXPECT_EQ(sum, 60);
+}
+
+TEST(ColumnViewTest, Subview) {
+    std::vector<double> data = {1.0, 2.0, 3.0, 4.0, 5.0};
+    ColumnView<double> view(data.data(), data.size());
+
+    auto sub = view.subview(1, 3);
+    EXPECT_EQ(sub.size(), 3);
+    EXPECT_EQ(sub[0], 2.0);
+    EXPECT_EQ(sub[2], 4.0);
+}
+
+TEST(ColumnViewTest, OutOfRange) {
+    std::vector<double> data = {1.0, 2.0};
+    ColumnView<double> view(data.data(), data.size());
+
+    EXPECT_THROW(view.at(5), std::out_of_range);
+    EXPECT_THROW(view.subview(1, 5), std::out_of_range);
+}
+
+TEST(ColumnTest, ConstructWithOwnedData) {
+    std::vector<uint8_t> data(40);  // 5 doubles
+    auto* ptr = reinterpret_cast<double*>(data.data());
+    for (int i = 0; i < 5; ++i) ptr[i] = static_cast<double>(i);
+
+    Column col("test", DType::Float64, std::move(data));
+
+    EXPECT_EQ(col.name(), "test");
+    EXPECT_EQ(col.dtype(), DType::Float64);
+    EXPECT_EQ(col.size(), 5);
+    EXPECT_EQ(col.byte_size(), 40);
+}
+
+TEST(ColumnTest, TypedAccess) {
+    std::vector<uint8_t> data(24);
+    auto* ptr = reinterpret_cast<int64_t*>(data.data());
+    ptr[0] = 100;
+    ptr[1] = 200;
+    ptr[2] = 300;
+
+    Column col("ints", DType::Int64, std::move(data));
+    auto view = col.as_int64();
+
+    EXPECT_EQ(view.size(), 3);
+    EXPECT_EQ(view[0], 100);
+    EXPECT_EQ(view[2], 300);
+}
+
+TEST(ColumnTest, TypeMismatch) {
+    std::vector<uint8_t> data(24);
+    Column col("ints", DType::Int64, std::move(data));
+
+    EXPECT_THROW(col.as_float64(), TypeMismatch);
+}
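The join tests that follow encode the as-of semantics: for each left (trade) row, `ops::aj` picks the most recent right (quote) row whose timestamp is at or before the trade's timestamp and whose key columns match. A tiny Python reference of that rule, using the same numbers as the test data below (illustrative only, not the WayyDB API, and quadratic where the real implementation exploits the sorted columns):

```python
def asof_join(left, right, key, ts):
    """For each left row, take the latest right row with right[ts] <= left[ts]
    and a matching key. Both inputs are lists of dicts sorted by ts."""
    out = []
    for l in left:
        match = None
        for r in right:
            if r[ts] > l[ts]:
                break  # right side is sorted; nothing later can qualify
            if r[key] == l[key]:
                match = r
        # Left row wins on shared column names (timestamp, key).
        out.append({**(match or {}), **l})
    return out

trades = [{"ts": 100, "sym": 0}, {"ts": 150, "sym": 1}]
quotes = [{"ts": 50, "sym": 0, "bid": 149.5}, {"ts": 90, "sym": 1, "bid": 379.5}]
print(asof_join(trades, quotes, "sym", "ts"))
# trade@100/sym0 -> bid 149.5 (quote@50); trade@150/sym1 -> bid 379.5 (quote@90)
```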
diff --git a/tests/test_joins.cpp b/tests/test_joins.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f9f7abd67bb03ac3273b659c18badacec3b94537
--- /dev/null
+++ b/tests/test_joins.cpp
@@ -0,0 +1,138 @@
+#include <gtest/gtest.h>
+#include "wayy_db/table.hpp"
+#include "wayy_db/ops/joins.hpp"
+
+using namespace wayy_db;
+
+class JoinsTest : public ::testing::Test {
+protected:
+    Table create_trades() {
+        Table trades("trades");
+
+        // Trades at times 100..300 for symbols 0 (AAPL) and 1 (MSFT)
+        std::vector<int64_t> timestamps = {100, 150, 200, 250, 300};
+        std::vector<uint32_t> symbols = {0, 1, 0, 1, 0};  // AAPL, MSFT, AAPL, MSFT, AAPL
+        std::vector<double> prices = {150.0, 380.0, 151.0, 381.0, 152.0};
+        std::vector<int64_t> sizes = {100, 200, 150, 250, 100};
+
+        trades.add_column("timestamp", DType::Timestamp, timestamps.data(), timestamps.size());
+        trades.add_column("symbol", DType::Symbol, symbols.data(), symbols.size());
+        trades.add_column("price", DType::Float64, prices.data(), prices.size());
+        trades.add_column("size", DType::Int64, sizes.data(), sizes.size());
+        trades.set_sorted_by("timestamp");
+
+        return trades;
+    }
+
+    Table create_quotes() {
+        Table quotes("quotes");
+
+        // Quotes at times 50, 90, 140, 190, 280
+        std::vector<int64_t> timestamps = {50, 90, 140, 190, 280};
+        std::vector<uint32_t> symbols = {0, 1, 0, 1, 0};
+        std::vector<double> bids = {149.5, 379.5, 150.5, 380.5, 151.5};
+        std::vector<double> asks = {150.0, 380.0, 151.0, 381.0, 152.0};
+
+        quotes.add_column("timestamp", DType::Timestamp, timestamps.data(), timestamps.size());
+        quotes.add_column("symbol", DType::Symbol, symbols.data(), symbols.size());
+        quotes.add_column("bid", DType::Float64, bids.data(), bids.size());
+        quotes.add_column("ask", DType::Float64, asks.data(), asks.size());
+        quotes.set_sorted_by("timestamp");
+
+        return quotes;
+    }
+};
+
+TEST_F(JoinsTest, AsOfJoinBasic) {
+    auto trades = create_trades();
+    auto quotes = create_quotes();
+
+    auto result = ops::aj(trades, quotes, {"symbol"}, "timestamp");
+
+    // Result should have same number of rows as trades
+    EXPECT_EQ(result.num_rows(), 5);
+
+    // Check that we have columns from both tables
+    EXPECT_TRUE(result.has_column("timestamp"));
+    EXPECT_TRUE(result.has_column("symbol"));
+    EXPECT_TRUE(result.has_column("price"));
+    EXPECT_TRUE(result.has_column("bid"));
+    EXPECT_TRUE(result.has_column("ask"));
+
+    // Verify as-of semantics:
+    // Trade at t=100, symbol=AAPL should get the AAPL quote at t=50
+    // (the t=90 quote belongs to MSFT, so it must not match)
+    auto bids = result.column("bid").as_float64();
+    EXPECT_DOUBLE_EQ(bids[0], 149.5);  // AAPL trade at 100 -> AAPL quote at 50
+
+    // Trade at t=150, symbol=MSFT should get quote at t=90 (MSFT)
+    EXPECT_DOUBLE_EQ(bids[1], 379.5);
+
+    // Trade at t=200, symbol=AAPL should get quote at t=140 (AAPL)
+    EXPECT_DOUBLE_EQ(bids[2], 150.5);
+}
+
+TEST_F(JoinsTest, AsOfJoinRequiresSorted) {
+    Table left("left");
+    Table right("right");
+
+    std::vector<int64_t> ts = {1, 2, 3};
+    left.add_column("ts", DType::Timestamp, ts.data(), ts.size());
+    right.add_column("ts", DType::Timestamp, ts.data(), ts.size());
+
+    // Neither is sorted
+    EXPECT_THROW(ops::aj(left, right, {}, "ts"), InvalidOperation);
+
+    // Only left is sorted
+    left.set_sorted_by("ts");
+    EXPECT_THROW(ops::aj(left, right, {}, "ts"), InvalidOperation);
+}
+
+TEST_F(JoinsTest, WindowJoinBasic) {
+    auto trades = create_trades();
+    auto quotes = create_quotes();
+
+    // Window: 60ns before, 0ns after
+    auto result = ops::wj(trades, quotes, {"symbol"}, "timestamp", 60, 0);
+
+    // Window join may have more rows than left table
+    EXPECT_GT(result.num_rows(), 0);
+
+    // Check columns exist
+    EXPECT_TRUE(result.has_column("bid"));
+    EXPECT_TRUE(result.has_column("price"));
+}
+
+TEST_F(JoinsTest, AsOfJoinNoMatches) {
+    Table trades("trades");
+    Table quotes("quotes");
+
+    // Trades for symbol 0
+    std::vector<int64_t> trade_ts = {100, 200};
+    std::vector<uint32_t> trade_sym = {0, 0};
+    std::vector<double> trade_px = {100.0, 101.0};
+
+    trades.add_column("timestamp", DType::Timestamp, trade_ts.data(), trade_ts.size());
+    trades.add_column("symbol", DType::Symbol, trade_sym.data(), trade_sym.size());
+    trades.add_column("price", DType::Float64, trade_px.data(), trade_px.size());
+    trades.set_sorted_by("timestamp");
+
+    // Quotes for symbol 1 (different symbol)
+    std::vector<int64_t> quote_ts = {50, 150};
+    std::vector<uint32_t> quote_sym = {1, 1};
+    std::vector<double> quote_bid = {99.0, 100.0};
+
+    quotes.add_column("timestamp", DType::Timestamp, quote_ts.data(), quote_ts.size());
+    quotes.add_column("symbol", DType::Symbol, quote_sym.data(), quote_sym.size());
+    quotes.add_column("bid", DType::Float64, quote_bid.data(), quote_bid.size());
+    quotes.set_sorted_by("timestamp");
+
+    auto result = ops::aj(trades, quotes, {"symbol"}, "timestamp");
+
+    // Should still have 2 rows, but bid should be 0 (null)
+    EXPECT_EQ(result.num_rows(), 2);
+
+    auto bids = result.column("bid").as_float64();
+    EXPECT_DOUBLE_EQ(bids[0], 0.0);
+    EXPECT_DOUBLE_EQ(bids[1], 0.0);
+}
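The window join exercised above pairs each left row with every right row inside `[t - before, t + after]` on the same key, which is why it can produce more rows than the left table. A rough Python sketch of that matching rule; the handling of unmatched left rows and of shared column names is an assumption, since only the row-count behavior is visible in the test:

```python
def window_join(left, right, key, ts, before, after):
    """Pair each left row with all right rows in [l[ts]-before, l[ts]+after]
    that share its key; may yield more rows than the left table."""
    out = []
    for l in left:
        lo, hi = l[ts] - before, l[ts] + after
        matches = [r for r in right if lo <= r[ts] <= hi and r[key] == l[key]]
        if matches:
            # Left row wins on shared column names (timestamp, key).
            out.extend({**r, **l} for r in matches)
        else:
            out.append(dict(l))  # unmatched left rows kept (null policy assumed)
    return out

trades = [{"ts": 100, "sym": 0, "px": 150.0}]
quotes = [{"ts": 50, "sym": 0, "bid": 149.5}, {"ts": 90, "sym": 0, "bid": 150.0}]
print(window_join(trades, quotes, "sym", "ts", before=60, after=0))
# both quotes fall in [40, 100] -> two output rows for one trade
```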
diff --git a/tests/test_mmap.cpp b/tests/test_mmap.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2d4743290cb71f82bbef5f760cfd7581c404e4d8
--- /dev/null
+++ b/tests/test_mmap.cpp
@@ -0,0 +1,132 @@
+#include <gtest/gtest.h>
+#include "wayy_db/mmap_file.hpp"
+#include "wayy_db/types.hpp"
+
+#include <cstring>
+#include <filesystem>
+#include <unistd.h>
+
+using namespace wayy_db;
+namespace fs = std::filesystem;
+
+class MmapFileTest : public ::testing::Test {
+protected:
+    void SetUp() override {
+        test_dir_ = "/tmp/wayy_mmap_test_" + std::to_string(getpid());
+        fs::create_directories(test_dir_);
+    }
+
+    void TearDown() override {
+        fs::remove_all(test_dir_);
+    }
+
+    std::string test_dir_;
+};
+
+TEST_F(MmapFileTest, CreateAndWrite) {
+    std::string path = test_dir_ + "/test.bin";
+
+    {
+        MmapFile file(path, MmapFile::Mode::Create, 1024);
+
+        EXPECT_TRUE(file.is_open());
+        EXPECT_EQ(file.size(), 1024);
+        EXPECT_EQ(file.path(), path);
+
+        // Write some data (256 x uint32_t == 1024 bytes)
+        auto* data = static_cast<uint32_t*>(file.data());
+        for (int i = 0; i < 256; ++i) {
+            data[i] = i * 2;
+        }
+
+        file.sync();
+    }
+
+    // Verify data persisted
+    {
+        MmapFile file(path, MmapFile::Mode::ReadOnly);
+
+        EXPECT_EQ(file.size(), 1024);
+
+        auto* data = static_cast<uint32_t*>(file.data());
+        EXPECT_EQ(data[0], 0);
+        EXPECT_EQ(data[100], 200);
+        EXPECT_EQ(data[255], 510);
+    }
+}
+
+TEST_F(MmapFileTest, ReadWrite) {
+    std::string path = test_dir_ + "/rw.bin";
+
+    // Create initial file
+    {
+        MmapFile file(path, MmapFile::Mode::Create, 100);
+        std::memset(file.data(), 0, 100);
+    }
+
+    // Open for read-write and modify
+    {
+        MmapFile file(path, MmapFile::Mode::ReadWrite);
+        auto* data = static_cast<uint8_t*>(file.data());
+        data[50] = 42;
+        file.sync();
+    }
+
+    // Verify modification
+    {
+        MmapFile file(path, MmapFile::Mode::ReadOnly);
+        auto* data = static_cast<uint8_t*>(file.data());
+        EXPECT_EQ(data[50], 42);
+    }
+}
+
+TEST_F(MmapFileTest, Resize) {
+    std::string path = test_dir_ + "/resize.bin";
+
+    MmapFile file(path, MmapFile::Mode::Create, 100);
+    EXPECT_EQ(file.size(), 100);
+
+    file.resize(500);
+    EXPECT_EQ(file.size(), 500);
+
+    // Can still write to expanded region
+    auto* data = static_cast<uint8_t*>(file.data());
+    data[400] = 123;
+    file.sync();
+}
+
+TEST_F(MmapFileTest, MoveSemantics) {
+    std::string path = test_dir_ + "/move.bin";
+
+    MmapFile file1(path, MmapFile::Mode::Create, 256);
+    void* original_data = file1.data();
+
+    MmapFile file2 = std::move(file1);
+
+    EXPECT_FALSE(file1.is_open());
+    EXPECT_TRUE(file2.is_open());
+    EXPECT_EQ(file2.data(), original_data);
+    EXPECT_EQ(file2.size(), 256);
+}
+
+TEST_F(MmapFileTest, OpenNonexistent) {
+    std::string path = test_dir_ + "/nonexistent.bin";
+
+    EXPECT_THROW(
+        MmapFile file(path, MmapFile::Mode::ReadOnly),
+        WayyException
+    );
+}
+
+TEST_F(MmapFileTest, CloseAndReopen) {
+    std::string path = test_dir_ + "/close.bin";
+
+    MmapFile file(path, MmapFile::Mode::Create, 100);
+    EXPECT_TRUE(file.is_open());
+
+    file.close();
+    EXPECT_FALSE(file.is_open());
+
+    file.open(path, MmapFile::Mode::ReadOnly);
+    EXPECT_TRUE(file.is_open());
+    EXPECT_EQ(file.size(), 100);
+}
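The `.col` files that `Table::save`/`Table::mmap` (tested next) write and map begin with the 64-byte `ColumnHeader` that `Table::load` validates earlier in this diff. The exact field order and widths are not visible here, so the sketch below is an assumption: it relies only on the fields the code actually touches (`magic`, `dtype`, `row_count`, `data_offset`), on `sizeof(ColumnHeader) == 64` from test_types.cpp, and on the `WAYY_MAGIC` value asserted there.

```python
import struct

HEADER_SIZE = 64                 # sizeof(ColumnHeader), per test_types.cpp
WAYY_MAGIC = 0x5741595944420001  # per test_types.cpp

def read_col_header(path: str):
    """Hypothetical .col header reader; the field layout is an assumption."""
    with open(path, "rb") as f:
        raw = f.read(HEADER_SIZE)
    # Assumed layout: magic u64 | dtype u32 | pad u32 | row_count u64 |
    # data_offset u64 | reserved space up to 64 bytes.
    magic, dtype, _pad, row_count, data_offset = struct.unpack_from("<QIIQQ", raw)
    if magic != WAYY_MAGIC:
        raise ValueError("not a WayyDB column file")
    return dtype, row_count, data_offset
```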
table.add_column("price", DType::Float64, prices.data(), prices.size()); + table.add_column("size", DType::Int64, sizes.data(), sizes.size()); + + EXPECT_EQ(table.num_columns(), 3); + EXPECT_EQ(table.num_rows(), 3); + + auto names = table.column_names(); + EXPECT_EQ(names.size(), 3); +} + +TEST_F(TableTest, ColumnSizeMismatch) { + Table table("test"); + + std::vector col1 = {1.0, 2.0, 3.0}; + std::vector col2 = {1.0, 2.0}; // Different size + + table.add_column("col1", DType::Float64, col1.data(), col1.size()); + EXPECT_THROW( + table.add_column("col2", DType::Float64, col2.data(), col2.size()), + InvalidOperation + ); +} + +TEST_F(TableTest, SortedBy) { + Table table("test"); + + std::vector timestamps = {1000, 2000, 3000}; + table.add_column("timestamp", DType::Timestamp, timestamps.data(), timestamps.size()); + + EXPECT_FALSE(table.is_sorted()); + + table.set_sorted_by("timestamp"); + EXPECT_TRUE(table.is_sorted()); + EXPECT_EQ(table.sorted_by(), "timestamp"); +} + +TEST_F(TableTest, SortedByNonexistent) { + Table table("test"); + + std::vector data = {1, 2, 3}; + table.add_column("col", DType::Int64, data.data(), data.size()); + + EXPECT_THROW(table.set_sorted_by("nonexistent"), ColumnNotFound); +} + +TEST_F(TableTest, SaveAndLoad) { + std::string table_dir = test_dir_ + "/trades"; + + // Create and save + { + Table table("trades"); + + std::vector timestamps = {1000, 2000, 3000}; + std::vector prices = {100.0, 101.0, 102.0}; + + table.add_column("timestamp", DType::Timestamp, timestamps.data(), timestamps.size()); + table.add_column("price", DType::Float64, prices.data(), prices.size()); + table.set_sorted_by("timestamp"); + + table.save(table_dir); + } + + // Load and verify + { + Table loaded = Table::load(table_dir); + + EXPECT_EQ(loaded.name(), "trades"); + EXPECT_EQ(loaded.num_rows(), 3); + EXPECT_EQ(loaded.num_columns(), 2); + EXPECT_EQ(loaded.sorted_by(), "timestamp"); + + auto ts = loaded.column("timestamp").as_int64(); + EXPECT_EQ(ts[0], 1000); + EXPECT_EQ(ts[2], 3000); + + auto prices = loaded.column("price").as_float64(); + EXPECT_DOUBLE_EQ(prices[1], 101.0); + } +} + +TEST_F(TableTest, Mmap) { + std::string table_dir = test_dir_ + "/mmap_test"; + + // Create and save + { + Table table("mmap_test"); + + std::vector data = {10, 20, 30, 40, 50}; + table.add_column("values", DType::Int64, data.data(), data.size()); + table.save(table_dir); + } + + // Memory-map and verify + { + Table mapped = Table::mmap(table_dir); + + EXPECT_EQ(mapped.num_rows(), 5); + + auto values = mapped.column("values").as_int64(); + EXPECT_EQ(values[0], 10); + EXPECT_EQ(values[4], 50); + } +} diff --git a/tests/test_types.cpp b/tests/test_types.cpp new file mode 100644 index 0000000000000000000000000000000000000000..eeaa509bf8ea9277be291f5f62785f4112894a10 --- /dev/null +++ b/tests/test_types.cpp @@ -0,0 +1,40 @@ +#include +#include "wayy_db/types.hpp" + +using namespace wayy_db; + +TEST(TypesTest, DTypeSizes) { + EXPECT_EQ(dtype_size(DType::Int64), 8); + EXPECT_EQ(dtype_size(DType::Float64), 8); + EXPECT_EQ(dtype_size(DType::Timestamp), 8); + EXPECT_EQ(dtype_size(DType::Symbol), 4); + EXPECT_EQ(dtype_size(DType::Bool), 1); +} + +TEST(TypesTest, DTypeToString) { + EXPECT_EQ(dtype_to_string(DType::Int64), "int64"); + EXPECT_EQ(dtype_to_string(DType::Float64), "float64"); + EXPECT_EQ(dtype_to_string(DType::Timestamp), "timestamp"); + EXPECT_EQ(dtype_to_string(DType::Symbol), "symbol"); + EXPECT_EQ(dtype_to_string(DType::Bool), "bool"); +} + +TEST(TypesTest, DTypeFromString) { + 
+    EXPECT_EQ(dtype_from_string("int64"), DType::Int64);
+    EXPECT_EQ(dtype_from_string("float64"), DType::Float64);
+    EXPECT_EQ(dtype_from_string("timestamp"), DType::Timestamp);
+    EXPECT_EQ(dtype_from_string("symbol"), DType::Symbol);
+    EXPECT_EQ(dtype_from_string("bool"), DType::Bool);
+}
+
+TEST(TypesTest, DTypeFromStringInvalid) {
+    EXPECT_THROW(dtype_from_string("invalid"), WayyException);
+}
+
+TEST(TypesTest, ColumnHeaderSize) {
+    EXPECT_EQ(sizeof(ColumnHeader), 64);
+}
+
+TEST(TypesTest, MagicNumber) {
+    EXPECT_EQ(WAYY_MAGIC, 0x5741595944420001ULL);
+}
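+
+// The upper six bytes of WAYY_MAGIC spell "WAYYDB" in ASCII
+// (0x57 0x41 0x59 0x59 0x44 0x42); the trailing 0x0001 is presumably a
+// format-version suffix.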