# contextforge-demo / tests / test_mcp_server.py
# Author: Pablo — commit cf0a8ed
# feat: APOHARA: Context Forge V5 — synthesis + rebrand complete
# MERGED: OpenCode (deep KV physics) + CC (surface coverage)
# All tests hermetic: no GPU, no TCP, no downloaded weights required
from __future__ import annotations
import logging
from datetime import datetime, timedelta, timezone
import numpy as np
import pytest
# Optional dep guard — skip entire module if fastapi not installed
fastapi = pytest.importorskip("fastapi", reason="fastapi not installed — install with: pip install fastapi")
from fastapi.testclient import TestClient
from apohara_context_forge.mcp import server as srv
from apohara_context_forge.mcp.server import (
app,
get_compressor,
get_coordinator,
get_metrics,
get_registry,
)
from apohara_context_forge.models import (
CompressionDecision,
ContextEntry,
Degradation,
MetricsSnapshot,
)
from apohara_context_forge.registry.context_registry import ContextRegistry
# ---- Fakes (module-level so dependency_overrides + lifespan patches both work) -----
class FakeMetrics:
    """Hermetic stand-in for the metrics collector.

    Records every call the server makes (register events, decisions, snapshot
    kwargs) so tests can assert on them, and fabricates a fixed snapshot."""

    def __init__(self, *, gpu_label: str = "cuda", raise_on_label: bool = False) -> None:
        # Keyword-only knobs: which GPU label to report, and whether the
        # label probe should simulate a failure.
        self._gpu_label = gpu_label
        self._raise_on_label = raise_on_label
        # Call recorders inspected by the tests.
        self.register_calls: list[bool] = []
        self.decision_calls: list[CompressionDecision] = []
        self._snapshot_kwargs: dict | None = None

    def _resolve_gpu_label(self) -> str:
        """Return the configured GPU label, or blow up when told to."""
        if not self._raise_on_label:
            return self._gpu_label
        raise RuntimeError("gpu probe blew up")

    def record_register(self, matched: bool) -> None:
        """Remember whether a register call reported a dedup match."""
        self.register_calls.append(matched)

    def record_decision(self, decision: CompressionDecision) -> None:
        """Remember each compression decision the server reported."""
        self.decision_calls.append(decision)

    async def snapshot(
        self, *, current_compressor_model, compressor_degradations
    ) -> MetricsSnapshot:
        """Capture the keyword arguments, then hand back a canned snapshot."""
        self._snapshot_kwargs = dict(
            current_compressor_model=current_compressor_model,
            compressor_degradations=compressor_degradations,
        )
        return MetricsSnapshot(
            vram_source="psutil",
            compressor_model=current_compressor_model,
            vram_used_gb=1.0,
            vram_total_gb=8.0,
            ttft_ms=0.0,
            tokens_processed=0,
            tokens_saved=0,
            dedup_rate=0.0,
            compression_ratio=0.0,
            degradations=list(compressor_degradations),
        )
class FakeCompressor:
    """Minimal compressor double exposing only the attributes the server reads:
    the active model name and its recorded degradations."""

    def __init__(
        self,
        current_model: str = "xlm-roberta-large",
        degradations: list[Degradation] | None = None,
    ) -> None:
        self.current_model = current_model
        # A falsy value (None / empty list) collapses to a fresh empty list.
        self.degradations = degradations or []
class FakeRegistry:
    """Registry double: records every register() call and either replays a
    canned entry or synthesizes one with a 300-second TTL."""

    def __init__(self, entry: ContextEntry | None = None) -> None:
        self._entry = entry
        self.register_calls: list[tuple[str, str]] = []
        self.cleared = False

    async def register(self, agent_id: str, context: str) -> ContextEntry:
        """Record the call; return the stub entry if provided, else build one."""
        self.register_calls.append((agent_id, context))
        if self._entry is None:
            issued = datetime.now(timezone.utc)
            return ContextEntry(
                agent_id=agent_id,
                context=context,
                token_count=len(context.split()),
                created_at=issued,
                expires_at=issued + timedelta(seconds=300),
            )
        return self._entry

    async def clear(self) -> None:
        """Flag that the registry was asked to drop all entries."""
        self.cleared = True
class FakeCoordinator:
    """Coordinator double that replays one fixed outcome: a decision to
    return, or an exception instance to raise."""

    def __init__(self, decision: CompressionDecision | Exception) -> None:
        self._decision = decision
        self.decide_calls: list[tuple[str, str]] = []

    async def decide(self, agent_id: str, context: str) -> CompressionDecision:
        """Record the call, then raise or return the configured outcome."""
        self.decide_calls.append((agent_id, context))
        outcome = self._decision
        if isinstance(outcome, Exception):
            raise outcome
        return outcome
# ---- FakeDedupEngine for the full-flow test (re-uses test_registry pattern) ---------
class FakeDedupEngine:
    """Deterministic dedup double: the same text always embeds identically,
    distinct texts never match (similarity is exactly 0.0 or 1.0)."""

    def __init__(self) -> None:
        # Monotonically assigned numeric id per distinct text.
        self._key_for_text: dict[str, float] = {}
        self._next_key: float = 1.0

    def _key(self, text: str) -> float:
        """Assign (at most once) and return the numeric id for *text*."""
        key = self._key_for_text.get(text)
        if key is None:
            key = self._next_key
            self._key_for_text[text] = key
            self._next_key += 1.0
        return key

    async def embed(self, text: str) -> np.ndarray:
        """Embed as an 8-dim float32 vector whose first slot carries the id."""
        vec = np.zeros(8, dtype=np.float32)
        vec[0] = self._key(text)
        return vec

    async def similarity(self, e1: np.ndarray, e2: np.ndarray) -> float:
        """1.0 iff both embeddings carry the same text id, else 0.0."""
        return float(float(e1[0]) == float(e2[0]))

    def find_shared_prefix(self, a: str, b: str) -> str:
        """Return the longest common leading substring of *a* and *b*."""
        limit = min(len(a), len(b))
        idx = 0
        while idx < limit and a[idx] == b[idx]:
            idx += 1
        return a[:idx]

    def count_prefix_tokens(self, prefix: str) -> int:
        """Whitespace-token count of *prefix*."""
        return len(prefix.split())
# ---- Helpers ------------------------------------------------------------------------
def _client_with_overrides(overrides: dict) -> TestClient:
    """Install *overrides* into ``app.dependency_overrides`` and hand back a
    TestClient built WITHOUT entering its context manager, so the production
    lifespan (and the heavy ContextCompressor / VLLMClient construction it
    performs) never runs. Override keys must be the dependency callables
    themselves (e.g. ``get_registry``): FastAPI resolves them by identity,
    not by name."""
    app.dependency_overrides.update(overrides)
    return TestClient(app)
@pytest.fixture(autouse=True)
def _clear_overrides():
    """Autouse fixture: wipe FastAPI dependency overrides after every test so
    overrides installed by one test never leak into the next."""
    yield
    app.dependency_overrides.clear()
# ---- Tests --------------------------------------------------------------------------
def test_health_returns_ok_with_gpu_label() -> None:
    """/health reports status=ok plus the resolved GPU label."""
    fake_metrics = FakeMetrics(gpu_label="cuda")
    client = _client_with_overrides({get_metrics: lambda: fake_metrics})
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "ok", "gpu": "cuda"}
def test_health_returns_degraded_on_internal_error() -> None:
    """A failing GPU probe must degrade /health gracefully, never 500."""
    failing_metrics = FakeMetrics(raise_on_label=True)
    client = _client_with_overrides({get_metrics: lambda: failing_metrics})
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "degraded", "gpu": "unknown"}
def test_metrics_snapshot_returns_valid_pydantic() -> None:
    """/metrics/snapshot round-trips through the MetricsSnapshot model and
    forwards the compressor state to the metrics collector."""
    fake_metrics = FakeMetrics()
    fake_compressor = FakeCompressor(
        current_model="xlm-roberta-large",
        degradations=[Degradation(component="compressor", reason="OOM", fallback="cpu")],
    )
    client = _client_with_overrides(
        {get_metrics: lambda: fake_metrics, get_compressor: lambda: fake_compressor}
    )
    response = client.get("/metrics/snapshot")
    assert response.status_code == 200
    snapshot = MetricsSnapshot.model_validate(response.json())
    assert snapshot.compressor_model == "xlm-roberta-large"
    assert any(item.component == "compressor" for item in snapshot.degradations)
    # The endpoint must have passed the compressor's state into snapshot().
    captured = fake_metrics._snapshot_kwargs
    assert captured is not None
    assert captured["current_compressor_model"] == "xlm-roberta-large"
def test_register_context_happy_path() -> None:
    """Registering a context echoes the entry back and records matched=False."""
    issued = datetime.now(timezone.utc)
    canned_entry = ContextEntry(
        agent_id="alice",
        context="hello world",
        token_count=2,
        created_at=issued,
        expires_at=issued + timedelta(seconds=300),
    )
    fake_registry = FakeRegistry(entry=canned_entry)
    fake_metrics = FakeMetrics()
    client = _client_with_overrides(
        {get_registry: lambda: fake_registry, get_metrics: lambda: fake_metrics}
    )
    response = client.post(
        "/tools/register_context",
        json={"agent_id": "alice", "context": "hello world"},
    )
    assert response.status_code == 200
    entry = ContextEntry.model_validate_json(response.text)
    assert entry.agent_id == "alice"
    assert entry.context == "hello world"
    # The server recorded a non-matched registration and hit the registry once.
    assert fake_metrics.register_calls == [False]
    assert fake_registry.register_calls == [("alice", "hello world")]
def test_register_context_422_on_empty_agent_id() -> None:
    """An empty agent_id must be rejected by request validation."""
    client = _client_with_overrides(
        {get_registry: lambda: FakeRegistry(), get_metrics: lambda: FakeMetrics()}
    )
    payload = {"agent_id": "", "context": "x"}
    response = client.post("/tools/register_context", json=payload)
    assert response.status_code == 422
def test_register_context_422_on_extra_field() -> None:
    """Unknown fields in the payload must trip strict validation."""
    client = _client_with_overrides(
        {get_registry: lambda: FakeRegistry(), get_metrics: lambda: FakeMetrics()}
    )
    payload = {"agent_id": "a", "context": "x", "hostile": 1}
    response = client.post("/tools/register_context", json=payload)
    assert response.status_code == 422
def test_register_context_422_on_missing_field() -> None:
    """Omitting the required context field must be rejected."""
    client = _client_with_overrides(
        {get_registry: lambda: FakeRegistry(), get_metrics: lambda: FakeMetrics()}
    )
    response = client.post("/tools/register_context", json={"agent_id": "a"})
    assert response.status_code == 422
def test_get_optimized_context_happy_path() -> None:
    """The coordinator's decision is returned verbatim and recorded in metrics."""
    canned_decision = CompressionDecision(
        strategy="compress",
        final_context="compressed body",
        shared_prefix="",
        original_tokens=1000,
        final_tokens=500,
        tokens_saved=500,
        rationale="ctx_tokens > threshold",
    )
    fake_coordinator = FakeCoordinator(decision=canned_decision)
    fake_metrics = FakeMetrics()
    client = _client_with_overrides(
        {get_coordinator: lambda: fake_coordinator, get_metrics: lambda: fake_metrics}
    )
    response = client.post(
        "/tools/get_optimized_context",
        json={"agent_id": "alice", "context": "hello"},
    )
    assert response.status_code == 200
    returned = CompressionDecision.model_validate(response.json())
    assert returned == canned_decision
    # One decision recorded; the coordinator saw exactly the request args.
    assert len(fake_metrics.decision_calls) == 1
    assert fake_coordinator.decide_calls == [("alice", "hello")]
def test_get_optimized_context_503_fallback_on_handler_exception() -> None:
    """A coordinator crash must yield a 503 passthrough fallback, not a 500."""
    exploding_coordinator = FakeCoordinator(decision=RuntimeError("boom"))
    fake_metrics = FakeMetrics()
    client = _client_with_overrides(
        {get_coordinator: lambda: exploding_coordinator, get_metrics: lambda: fake_metrics}
    )
    response = client.post(
        "/tools/get_optimized_context",
        json={"agent_id": "alice", "context": "the original body"},
    )
    assert response.status_code == 503
    fallback = CompressionDecision.model_validate(response.json())
    assert fallback.strategy == "passthrough"
    # The original body is passed through untouched.
    assert fallback.final_context == "the original body"
    # Fallback reports zero token accounting and records no decision.
    assert fallback.original_tokens == 0
    assert fallback.final_tokens == 0
    assert fallback.tokens_saved == 0
    assert fake_metrics.decision_calls == []
def test_get_optimized_context_422_on_malformed_body() -> None:
    """A payload missing the context field must fail validation up front."""
    idle_decision = CompressionDecision(
        strategy="passthrough",
        final_context="",
        shared_prefix="",
        original_tokens=0,
        final_tokens=0,
        tokens_saved=0,
        rationale="",
    )
    overrides = {
        get_coordinator: lambda: FakeCoordinator(decision=idle_decision),
        get_metrics: lambda: FakeMetrics(),
    }
    client = _client_with_overrides(overrides)
    response = client.post("/tools/get_optimized_context", json={"agent_id": "a"})
    assert response.status_code == 422
def test_no_log_includes_request_body(caplog: pytest.LogCaptureFixture) -> None:
    """Privacy guard: raw request context must never appear in log output.

    Plants a unique sentinel string in the request bodies of both the
    happy-path register endpoint and the 503 fallback path, then scans every
    captured log record — the rendered message AND all raw record
    attributes — for the sentinel.
    """
    sentinel = "REDACTION-SENTINEL-XYZZY-9F3A2B7C-do-not-log"
    registry = FakeRegistry()
    metrics = FakeMetrics()
    client = _client_with_overrides(
        {get_registry: lambda: registry, get_metrics: lambda: metrics}
    )
    # Capture at DEBUG so even the chattiest log surface is inspected.
    with caplog.at_level(logging.DEBUG):
        # Trigger both happy-path register AND the 503 warning path so any
        # mishandled log surface is exercised.
        client.post(
            "/tools/register_context",
            json={"agent_id": "alice", "context": sentinel},
        )
        # Now exercise the 503 path with the sentinel in the body
        bad_coord = FakeCoordinator(decision=RuntimeError("boom"))
        app.dependency_overrides[get_coordinator] = lambda: bad_coord
        client.post(
            "/tools/get_optimized_context",
            json={"agent_id": "alice", "context": sentinel},
        )
    for record in caplog.records:
        assert sentinel not in record.getMessage()
        # Also scan raw LogRecord attributes so lazily-formatted %-style
        # args cannot smuggle the sentinel past getMessage().
        for value in record.__dict__.values():
            assert sentinel not in str(value)
def test_lifespan_constructs_and_disposes(monkeypatch: pytest.MonkeyPatch) -> None:
    """Lifespan wiring test: entering TestClient as a context manager runs the
    production lifespan, which must construct every component onto app.state,
    wire the coordinator to the SAME registry/compressor instances, and on
    exit call registry.clear() and vllm.aclose().
    """
    # Replace the heavy production classes the lifespan reaches for so
    # `with TestClient(app) as client:` does not download model weights or
    # touch the network.
    class _LifeReg:
        # Class-level log of constructed instances so teardown is verifiable.
        instances: list = []
        def __init__(self) -> None:
            self.cleared = False
            type(self).instances.append(self)
        async def clear(self) -> None:
            # Lifespan shutdown is expected to call this.
            self.cleared = True
    class _LifeComp:
        def __init__(self) -> None:
            pass
    class _LifeCoord:
        # Captures its wiring so identity checks below can inspect it.
        def __init__(self, registry=None, compressor=None) -> None:
            self.registry = registry
            self.compressor = compressor
    class _LifeMetr:
        def __init__(self) -> None:
            pass
    class _LifeVllm:
        # Class-level log of instances so aclose() can be verified after exit.
        instances: list = []
        def __init__(self) -> None:
            self.closed = False
            type(self).instances.append(self)
        async def aclose(self) -> None:
            self.closed = True
    # Patch the server module's own references — the lifespan resolves the
    # class names through `srv`, not through this test module.
    monkeypatch.setattr(srv, "ContextRegistry", _LifeReg)
    monkeypatch.setattr(srv, "ContextCompressor", _LifeComp)
    monkeypatch.setattr(srv, "CompressionCoordinator", _LifeCoord)
    monkeypatch.setattr(srv, "MetricsCollector", _LifeMetr)
    monkeypatch.setattr(srv, "VLLMClient", _LifeVllm)
    with TestClient(app) as client:
        assert isinstance(client.app.state.registry, _LifeReg)
        assert isinstance(client.app.state.compressor, _LifeComp)
        assert isinstance(client.app.state.coordinator, _LifeCoord)
        assert isinstance(client.app.state.metrics, _LifeMetr)
        assert isinstance(client.app.state.vllm, _LifeVllm)
        # Coordinator must be wired to the SAME registry+compressor instances
        assert client.app.state.coordinator.registry is client.app.state.registry
        assert client.app.state.coordinator.compressor is client.app.state.compressor
    # On context exit the lifespan ran cleanup
    assert _LifeReg.instances and _LifeReg.instances[-1].cleared is True
    assert _LifeVllm.instances and _LifeVllm.instances[-1].closed is True
def test_full_flow_register_then_optimize_passthrough() -> None:
    """End-to-end flow: register via a REAL ContextRegistry backed by the
    hermetic FakeDedupEngine (no model download), optimize via a stub
    coordinator that always passes through, then read the metrics snapshot."""
    registry = ContextRegistry(dedup=FakeDedupEngine())
    fake_metrics = FakeMetrics()
    fake_compressor = FakeCompressor()
    short_ctx = "this is a short context"
    passthrough_decision = CompressionDecision(
        strategy="passthrough",
        final_context=short_ctx,
        shared_prefix="",
        original_tokens=5,
        final_tokens=5,
        tokens_saved=0,
        rationale="ctx_tokens <= threshold AND no long shared prefix",
    )
    stub_coordinator = FakeCoordinator(decision=passthrough_decision)
    client = _client_with_overrides(
        {
            get_registry: lambda: registry,
            get_metrics: lambda: fake_metrics,
            get_compressor: lambda: fake_compressor,
            get_coordinator: lambda: stub_coordinator,
        }
    )
    # Step 1: register the context.
    register_response = client.post(
        "/tools/register_context",
        json={"agent_id": "alice", "context": short_ctx},
    )
    assert register_response.status_code == 200
    registered = ContextEntry.model_validate_json(register_response.text)
    assert registered.agent_id == "alice"
    # Step 2: ask for the optimized context.
    optimize_response = client.post(
        "/tools/get_optimized_context",
        json={"agent_id": "alice", "context": short_ctx},
    )
    assert optimize_response.status_code == 200
    returned = CompressionDecision.model_validate(optimize_response.json())
    assert returned.strategy == "passthrough"
    # Step 3: the snapshot endpoint still yields a valid model.
    snapshot_response = client.get("/metrics/snapshot")
    assert snapshot_response.status_code == 200
    snapshot = MetricsSnapshot.model_validate(snapshot_response.json())
    # passthrough records (0,0) — tokens_processed stays 0; that's fine
    assert snapshot.tokens_processed == 0
    assert fake_metrics.register_calls == [False]
    assert len(fake_metrics.decision_calls) == 1