# engram/kvcos/core/compression.py
"""
ENGRAM Protocol β€” KV Cache Compression Layer
Implements:
- FP16 passthrough (no compression).
- Q8_0: group quantization matching llama.cpp GGML_TYPE_Q8_0.
  Phase 1 production fallback. ~2x compression, <5% speed hit (D5).
- PolarQuant: MSE-optimal random rotation + Lloyd-Max codebook at 3 bits.
  QJL REMOVED — confirmed harmful by 6+ independent implementations (D5).
  Softmax amplifies QJL variance, making two-stage worse than MSE-only.
- INT8: per-row symmetric int8 with fp16 scales (true on-disk 2x via the serializer path).
- LAYER_DELTA: fp16 layer-0 baseline plus int8 inter-layer deltas.
Reference: TheTom/turboquant_plus (511+ tests, most mature impl)
"""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import torch
from kvcos.core.types import CompressionMethod
# ── Q8_0 Constants ────────────────────────────────────────────────────────────
Q8_GROUP_SIZE = 32
@dataclass(frozen=True)
class CompressionResult:
"""Result of compressing a KV cache tensor."""
data: torch.Tensor
method: CompressionMethod
original_dtype: torch.dtype
compression_ratio: float
metadata: dict[str, str]
# ── FP16 Passthrough ──────────────────────────────────────────────────────────
def compress_fp16(kv: torch.Tensor) -> CompressionResult:
"""No-op compression: ensure tensor is FP16."""
data = kv.to(torch.float16).contiguous()
return CompressionResult(
data=data,
method=CompressionMethod.FP16,
original_dtype=kv.dtype,
compression_ratio=1.0,
metadata={},
)
def decompress_fp16(data: torch.Tensor) -> torch.Tensor:
return data.to(torch.float16)
# ── Q8_0 Quantization ────────────────────────────────────────────────────────
# Matches llama.cpp GGML_TYPE_Q8_0 layout:
# 32-element groups, 1 float16 scale per group, 32 int8 values
# Storage: 34 bytes per group (32 int8 + one 2-byte fp16 scale) vs 64 bytes of
# fp16, i.e. 64/34 ≈ 1.88x compression.
def compress_q8_0(kv: torch.Tensor) -> CompressionResult:
"""Quantize KV cache to Q8_0 (int8 with per-group scale).
Stores dequantized bfloat16 for safetensors compatibility β€”
safetensors doesn't support int8+scale pairs natively.
"""
original_dtype = kv.dtype
original_bytes = kv.numel() * kv.element_size()
kv_flat = kv.float().contiguous()
orig_shape = kv_flat.shape
last_dim = orig_shape[-1]
pad_amount = (Q8_GROUP_SIZE - last_dim % Q8_GROUP_SIZE) % Q8_GROUP_SIZE
if pad_amount > 0:
kv_flat = torch.nn.functional.pad(kv_flat, (0, pad_amount))
new_shape = kv_flat.shape[:-1] + (-1, Q8_GROUP_SIZE)
grouped = kv_flat.reshape(new_shape)
scales = grouped.abs().amax(dim=-1, keepdim=True) / 127.0
scales = scales.clamp(min=1e-10)
quantized = torch.clamp(torch.round(grouped / scales), -127, 127)
dequantized = (quantized * scales).reshape(kv_flat.shape)
if pad_amount > 0:
dequantized = dequantized[..., :last_dim]
dequantized = dequantized.reshape(orig_shape).to(torch.bfloat16)
compressed_bytes = dequantized.numel() * 2
return CompressionResult(
data=dequantized,
method=CompressionMethod.Q8_0,
original_dtype=original_dtype,
compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
metadata={"q8_group_size": str(Q8_GROUP_SIZE)},
)
def decompress_q8_0(data: torch.Tensor) -> torch.Tensor:
return data.to(torch.float16)
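# Illustrative usage sketch (not part of the library API): round-trips a random
# KV tensor through Q8_0 and reports the worst-case reconstruction error. The
# tensor shape is an assumption for demonstration; the reported ratio reflects
# the dequantized bf16 storage used here, not the on-disk GGML layout above.
def _example_q8_0_roundtrip() -> None:
    kv = torch.randn(2, 8, 64, 128).to(torch.float16)  # hypothetical [layers, heads, cells, head_dim]
    result = compress_q8_0(kv)
    restored = decompress_q8_0(result.data)
    max_err = (restored.float() - kv.float()).abs().max().item()
    print(f"Q8_0: ratio={result.compression_ratio:.2f}, max abs err={max_err:.5f}")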
# ── PolarQuant (Phase 2 β€” TurboQuant without QJL) ────────────────────────────
# QJL is INTENTIONALLY ABSENT per D5.
class PolarQuantConfig:
"""Configuration for PolarQuant compression."""
def __init__(self, bits: int = 3, seed: int = 42):
self.bits = bits
self.n_centroids = 2**bits
self.seed = seed
self._rotation_cache: dict[int, torch.Tensor] = {}
self._codebook_cache: dict[int, torch.Tensor] = {}
def get_rotation_matrix(self, dim: int, device: torch.device) -> torch.Tensor:
"""Get fixed random orthogonal rotation matrix R ∈ R^(dΓ—d)."""
if dim not in self._rotation_cache:
rng = np.random.RandomState(self.seed)
gaussian = rng.randn(dim, dim).astype(np.float32)
q, r = np.linalg.qr(gaussian)
d = np.diag(r)
ph = np.sign(d)
q *= ph[np.newaxis, :]
self._rotation_cache[dim] = torch.from_numpy(q)
return self._rotation_cache[dim].to(device)
def get_lloyd_max_codebook(self, dim: int) -> torch.Tensor:
"""Lloyd-Max optimal centroids for N(0,1), 3-bit (8 levels)."""
if dim not in self._codebook_cache:
            # Reconstruction levels of the 8-level Lloyd-Max quantizer for N(0,1)
            # (the decision thresholds are 0, ±0.5006, ±1.050, ±1.748).
            codebook = torch.tensor(
                [-2.1520, -1.3439, -0.7560, -0.2451, 0.2451, 0.7560, 1.3439, 2.1520],
                dtype=torch.float32,
            )
self._codebook_cache[dim] = codebook
return self._codebook_cache[dim]
_POLAR_CONFIG = PolarQuantConfig()
def compress_polarquant(kv: torch.Tensor) -> CompressionResult:
"""Compress using PolarQuant (3-bit Lloyd-Max after random rotation).
Phase 2 implementation. Currently stores dequantized bfloat16.
True 3-bit packed storage is Phase 2+.
"""
original_dtype = kv.dtype
original_bytes = kv.numel() * kv.element_size()
device = kv.device
kv_float = kv.float().contiguous()
orig_shape = kv_float.shape
head_dim = orig_shape[-1]
flat = kv_float.reshape(-1, head_dim)
R = _POLAR_CONFIG.get_rotation_matrix(head_dim, device)
rotated = flat @ R
dim_std = rotated.std(dim=0, keepdim=True).clamp(min=1e-10)
normalized = rotated / dim_std
codebook = _POLAR_CONFIG.get_lloyd_max_codebook(head_dim).to(device)
distances = (normalized.unsqueeze(-1) - codebook.unsqueeze(0).unsqueeze(0)) ** 2
indices = distances.argmin(dim=-1)
dequantized = codebook[indices]
dequantized = dequantized * dim_std
R_inv = R.T
dequantized = dequantized @ R_inv
dequantized = dequantized.reshape(orig_shape).to(torch.bfloat16)
compressed_bytes = dequantized.numel() * 2
return CompressionResult(
data=dequantized,
method=CompressionMethod.POLARQUANT,
original_dtype=original_dtype,
compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
metadata={
"polarquant_bits": "3",
"polarquant_seed": str(_POLAR_CONFIG.seed),
"qjl_enabled": "false", # D5: QJL permanently disabled
},
)
def decompress_polarquant(data: torch.Tensor) -> torch.Tensor:
return data.to(torch.float16)
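# Illustrative sketch (not part of the library API): applies PolarQuant to a
# random tensor and measures per-vector cosine similarity against the input.
# The shape is an assumption for demonstration only.
def _example_polarquant_fidelity() -> None:
    kv = torch.randn(4, 256, 128)  # hypothetical [heads, cells, head_dim]
    result = compress_polarquant(kv)
    restored = decompress_polarquant(result.data).float()
    cos = torch.nn.functional.cosine_similarity(
        restored.reshape(-1, kv.shape[-1]),
        kv.reshape(-1, kv.shape[-1]),
        dim=-1,
    )
    print(f"PolarQuant: mean cos_sim={cos.mean().item():.4f}, ratio={result.compression_ratio:.2f}")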
# ── INT8 Quantization (Phase 2 β€” true on-disk compression) ───────────────────
# Stores actual int8 tensors in safetensors (1 byte/element vs 2 for fp16).
# Per-row symmetric quantization: scale = max(abs(row)) / 127.
# Separate scale tensor stored alongside quantized data.
# 2x on-disk compression with cos_sim > 0.999.
@dataclass(frozen=True)
class Int8CompressedPair:
"""INT8 quantized tensor + per-row scales."""
quantized: torch.Tensor # int8 [same shape as input]
scales: torch.Tensor # float16 [shape[:-1]] β€” one scale per row
def compress_int8_tensor(kv: torch.Tensor) -> Int8CompressedPair:
"""Quantize a KV tensor to int8 with per-row scales.
Args:
kv: [..., head_dim] tensor (any dtype)
Returns:
Int8CompressedPair with int8 data and float16 scales
"""
orig_shape = kv.shape
flat = kv.float().reshape(-1, orig_shape[-1])
row_max = flat.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
scales = row_max / 127.0
quantized = (flat / scales).round().clamp(-127, 127).to(torch.int8)
scales_f16 = scales.squeeze(1).to(torch.float16)
return Int8CompressedPair(
quantized=quantized.reshape(orig_shape),
scales=scales_f16.reshape(orig_shape[:-1]),
)
def decompress_int8_tensor(quantized: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
"""Dequantize int8 tensor using per-row scales.
Returns float16 tensor of the original shape.
"""
return (quantized.float() * scales.float().unsqueeze(-1)).to(torch.float16)
def compress_int8(kv: torch.Tensor) -> CompressionResult:
"""INT8 compression β€” returns dequantized float16 for CompressionResult compat.
The actual int8 storage is handled by the serializer which calls
compress_int8_tensor() directly for true on-disk compression.
This wrapper exists for the dispatcher API.
"""
pair = compress_int8_tensor(kv)
dequantized = decompress_int8_tensor(pair.quantized, pair.scales)
original_bytes = kv.numel() * kv.element_size()
# True on-disk: int8 data + float16 scales
compressed_bytes = pair.quantized.numel() * 1 + pair.scales.numel() * 2
return CompressionResult(
data=dequantized,
method=CompressionMethod.INT8,
original_dtype=kv.dtype,
compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
metadata={"int8_scale_dtype": "float16"},
)
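# Illustrative sketch (not part of the library API): shows how a serializer
# could round-trip the int8 payload and fp16 scales from compress_int8_tensor().
# Shapes and the suggested tensor names are assumptions for demonstration.
def _example_int8_roundtrip() -> None:
    kv = torch.randn(8, 512, 128).to(torch.float16)  # hypothetical [heads, cells, head_dim]
    pair = compress_int8_tensor(kv)
    # On disk this would be two tensors, e.g. "kv_q" (int8) and "kv_scales" (fp16).
    restored = decompress_int8_tensor(pair.quantized, pair.scales)
    cos = torch.nn.functional.cosine_similarity(
        restored.float().reshape(-1, kv.shape[-1]),
        kv.float().reshape(-1, kv.shape[-1]),
        dim=-1,
    )
    print(f"INT8: mean cos_sim={cos.mean().item():.4f}")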
# ── LAYER_DELTA Compression ──────────────────────────────────────────────────
# Stores layer 0 as an fp16 baseline and layers 1..N-1 as int8 deltas from the
# previous layer. Inter-layer residuals are typically small (adjacent layers
# are correlated), so int8 quantization of the deltas achieves better fidelity
# than direct int8 quantization of the values.
# On-disk: one fp16 layer + (N-1) int8 layers + (N-1) fp16 scale sets, which is
# marginally larger than straight INT8, traded for the fidelity gain.
@dataclass(frozen=True)
class LayerDeltaCompressed:
"""Layer-delta compressed: fp16 baseline + int8 deltas."""
baseline: torch.Tensor # [n_kv_heads, n_cells, head_dim] fp16
delta_quantized: list[torch.Tensor] # each int8 [n_kv_heads, n_cells, head_dim]
delta_scales: list[torch.Tensor] # each fp16 [n_kv_heads, n_cells]
n_layers: int
def compress_layer_delta(kv: torch.Tensor) -> LayerDeltaCompressed:
"""Compress KV tensor using inter-layer delta encoding.
Args:
kv: [n_layers, n_kv_heads, n_cells, head_dim]
Returns:
LayerDeltaCompressed with fp16 baseline + int8 deltas
"""
n_layers = kv.shape[0]
baseline = kv[0].to(torch.float16)
deltas: list[torch.Tensor] = []
scales: list[torch.Tensor] = []
for i in range(1, n_layers):
delta = (kv[i].float() - kv[i - 1].float())
flat = delta.reshape(-1, delta.shape[-1])
row_max = flat.abs().amax(dim=1).clamp(min=1e-8) / 127.0
q = (flat / row_max.unsqueeze(1)).round().clamp(-127, 127).to(torch.int8)
deltas.append(q.reshape(delta.shape))
scales.append(row_max.to(torch.float16).reshape(delta.shape[:-1]))
return LayerDeltaCompressed(
baseline=baseline, delta_quantized=deltas,
delta_scales=scales, n_layers=n_layers,
)
def decompress_layer_delta(data: LayerDeltaCompressed) -> torch.Tensor:
"""Decompress layer-delta encoded KV tensor."""
layers = [data.baseline.float()]
for dq, ds in zip(data.delta_quantized, data.delta_scales):
flat = dq.float().reshape(-1, dq.shape[-1])
delta = (flat * ds.float().reshape(-1).unsqueeze(1)).reshape(dq.shape)
layers.append(layers[-1] + delta)
return torch.stack(layers).to(torch.float16)
def compress_layer_delta_result(kv: torch.Tensor) -> CompressionResult:
"""Layer-delta wrapper for CompressionResult API."""
compressed = compress_layer_delta(kv)
decompressed = decompress_layer_delta(compressed)
original_bytes = kv.numel() * kv.element_size()
# On-disk: baseline fp16 + (N-1) int8 deltas + (N-1) fp16 scales
n = compressed.n_layers
per_layer_elements = kv[0].numel()
scale_elements = kv.shape[1] * kv.shape[2] # n_kv_heads * n_cells
compressed_bytes = (
per_layer_elements * 2 # baseline fp16
+ (n - 1) * per_layer_elements * 1 # int8 deltas
+ (n - 1) * scale_elements * 2 # fp16 scales
)
return CompressionResult(
data=decompressed,
method=CompressionMethod.LAYER_DELTA,
original_dtype=kv.dtype,
compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
metadata={"delta_n_layers": str(n)},
)
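# Illustrative sketch (not part of the library API): compares layer-delta
# reconstruction error against direct INT8 on a tensor with correlated layers,
# the regime this encoding targets. Shapes and the noise scale are assumptions.
def _example_layer_delta_vs_int8() -> None:
    base = torch.randn(8, 256, 128)
    # Hypothetical correlated layers: each layer is the previous plus small noise.
    layers = [base]
    for _ in range(7):
        layers.append(layers[-1] + 0.05 * torch.randn_like(base))
    kv = torch.stack(layers)  # [n_layers, n_kv_heads, n_cells, head_dim]
    delta_out = decompress_layer_delta(compress_layer_delta(kv)).float()
    int8_pair = compress_int8_tensor(kv)
    int8_out = decompress_int8_tensor(int8_pair.quantized, int8_pair.scales).float()
    err_delta = (delta_out - kv).abs().mean().item()
    err_int8 = (int8_out - kv).abs().mean().item()
    print(f"LAYER_DELTA mean abs err={err_delta:.6f} vs INT8 {err_int8:.6f}")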
# ── Dispatcher ────────────────────────────────────────────────────────────────
def compress(kv: torch.Tensor, method: CompressionMethod) -> CompressionResult:
"""Compress a KV cache tensor using the specified method."""
match method:
case CompressionMethod.FP16:
return compress_fp16(kv)
case CompressionMethod.Q8_0:
return compress_q8_0(kv)
case CompressionMethod.POLARQUANT:
return compress_polarquant(kv)
case CompressionMethod.INT8:
return compress_int8(kv)
case CompressionMethod.LAYER_DELTA:
return compress_layer_delta_result(kv)
case CompressionMethod.Q4_0:
import warnings
warnings.warn(
"Q4_0 has 92% dequantization slowdown at 64K+ context. "
"Using Q8_0 instead. See D5.",
UserWarning,
stacklevel=2,
)
return compress_q8_0(kv)
case _:
raise ValueError(f"Unknown compression method: {method}")
def decompress(data: torch.Tensor, method: CompressionMethod) -> torch.Tensor:
"""Decompress a KV cache tensor."""
match method:
case CompressionMethod.FP16:
return decompress_fp16(data)
case CompressionMethod.Q8_0 | CompressionMethod.Q4_0:
return decompress_q8_0(data)
case CompressionMethod.POLARQUANT:
return decompress_polarquant(data)
case CompressionMethod.INT8 | CompressionMethod.LAYER_DELTA:
# Already dequantized float16 in CompressionResult
return data.to(torch.float16)
case _:
raise ValueError(f"Unknown compression method: {method}")
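# Illustrative smoke test (not part of the library API): exercises the
# dispatcher over a hypothetical KV tensor when the module is run as a script.
# The shape and the chosen methods are assumptions for demonstration only.
if __name__ == "__main__":
    sample = torch.randn(4, 8, 128, 128).to(torch.float16)  # [layers, heads, cells, head_dim]
    for m in (CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.POLARQUANT):
        res = compress(sample, m)
        out = decompress(res.data, m)
        err = (out.float() - sample.float()).abs().mean().item()
        print(f"{m}: ratio={res.compression_ratio:.2f}, mean abs err={err:.5f}")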