| """ |
| ENGRAM Protocol — Compression Tests |
| |
| |
| Tests for kvcos.core.compression: |
| - FP16 passthrough |
| - Q8_0 round-trip accuracy & shape preservation |
| - PolarQuant round-trip accuracy & rotation invariants |
| - Dispatcher routing and Q4_0 fallback warning |
| - Edge cases: padding, single-element groups |
| """ |
|
|
| from __future__ import annotations |
|
|
| import warnings |
|
|
| import pytest |
| import torch |
|
|
| from kvcos.core.compression import ( |
| Q8_GROUP_SIZE, |
| CompressionResult, |
| compress, |
| compress_fp16, |
| compress_polarquant, |
| compress_q8_0, |
| decompress, |
| decompress_fp16, |
| decompress_polarquant, |
| decompress_q8_0, |
| ) |
| from kvcos.core.types import CompressionMethod |
|
|
|
|
| |
|
|
|
|
class TestFP16:
    """Checks for the FP16 path: a dtype-normalizing passthrough, no quantization."""

    def test_fp16_passthrough_shape(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """The compressed payload has the same shape as the input keys."""
        key_tensor, _ = llama_kv_256
        assert compress_fp16(key_tensor).data.shape == key_tensor.shape

    def test_fp16_passthrough_dtype(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """The compressed payload is stored as float16."""
        key_tensor, _ = llama_kv_256
        assert compress_fp16(key_tensor).data.dtype == torch.float16

    def test_fp16_passthrough_exact(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """The payload equals the input cast to float16, element for element."""
        key_tensor, _ = llama_kv_256
        expected = key_tensor.to(torch.float16)
        assert torch.equal(compress_fp16(key_tensor).data, expected)

    def test_fp16_compression_ratio_one(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """The reported compression ratio for the keys fixture is exactly 1.0."""
        key_tensor, _ = llama_kv_256
        assert compress_fp16(key_tensor).compression_ratio == 1.0

    def test_fp16_method_tag(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """The result is tagged with CompressionMethod.FP16."""
        key_tensor, _ = llama_kv_256
        assert compress_fp16(key_tensor).method == CompressionMethod.FP16

    def test_fp16_from_fp32(self) -> None:
        """FP32 input is cast to FP16 while the original dtype is recorded."""
        fp32_input = torch.randn(4, 8, 32, 128, dtype=torch.float32)
        compressed = compress_fp16(fp32_input)
        assert compressed.data.dtype == torch.float16
        assert compressed.original_dtype == torch.float32

    def test_fp16_decompress_identity(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Decompressing an FP16 payload returns an equal tensor."""
        key_tensor, _ = llama_kv_256
        payload = compress_fp16(key_tensor).data
        assert torch.equal(decompress_fp16(payload), payload)
|
|
|
|
| |
|
|
|
|
class TestQ8_0:
    """Q8_0: group quantization matching llama.cpp GGML_TYPE_Q8_0."""

    def test_q8_0_shape_preserved(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Compressed data keeps the shape of the input keys."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        assert result.data.shape == keys.shape

    def test_q8_0_output_dtype(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Q8_0 stores dequantized bfloat16 for safetensors compat."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        assert result.data.dtype == torch.bfloat16

    def test_q8_0_method_tag(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """The result is tagged with CompressionMethod.Q8_0."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        assert result.method == CompressionMethod.Q8_0

    def test_q8_0_metadata_group_size(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Metadata records the group size as the string form of Q8_GROUP_SIZE."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        assert result.metadata["q8_group_size"] == str(Q8_GROUP_SIZE)

    def test_q8_0_round_trip_low_error(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Q8_0 quantization error should be < 1% relative MSE."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        decompressed = decompress_q8_0(result.data)

        # Compare in float32 so the error metric is not itself limited by
        # half-precision arithmetic.
        original = keys.float()
        restored = decompressed.float()

        mse = ((original - restored) ** 2).mean()
        signal_power = (original**2).mean()
        relative_mse = (mse / signal_power).item()
        assert relative_mse < 0.01, f"Q8_0 relative MSE {relative_mse:.6f} > 1%"

    def test_q8_0_round_trip_values(
        self, phi3_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Q8_0 round-trip on the Phi-3 fixture (head_dim=96 per the original note).

        Both keys and values are compressed and then decompressed; the shape
        must survive each stage. Previously only the compress half was
        exercised, so the test did not actually round-trip.
        """
        keys, values = phi3_kv_256
        for tensor in (keys, values):
            result = compress_q8_0(tensor)
            assert result.data.shape == tensor.shape

            restored = decompress_q8_0(result.data)
            assert restored.shape == tensor.shape

    def test_q8_0_compression_ratio_fp32(self) -> None:
        """FP32 input -> bfloat16 output gives 2x compression ratio."""
        t = torch.randn(2, 4, 64, 128, dtype=torch.float32)
        result = compress_q8_0(t)
        assert abs(result.compression_ratio - 2.0) < 0.01

    def test_q8_0_compression_ratio_fp16(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """FP16 input -> bfloat16 output gives 1x ratio (same byte width)."""
        keys, _ = llama_kv_256
        result = compress_q8_0(keys)
        assert abs(result.compression_ratio - 1.0) < 0.01

    def test_q8_0_preserves_original_dtype(self) -> None:
        """The dtype of the input tensor is recorded on the result."""
        t = torch.randn(4, 8, 32, 128, dtype=torch.float32)
        result = compress_q8_0(t)
        assert result.original_dtype == torch.float32

    def test_q8_0_padding_dim_not_divisible(self) -> None:
        """Output shape matches the input for head dims 96 and 100.

        NOTE(review): if Q8_GROUP_SIZE is 32, then 96 is an exact multiple and
        only the 100 case exercises the pad/unpad path -- confirm.
        """
        t = torch.randn(2, 4, 16, 96, dtype=torch.float16)
        result = compress_q8_0(t)
        assert result.data.shape == t.shape

        t2 = torch.randn(2, 4, 16, 100, dtype=torch.float16)
        result2 = compress_q8_0(t2)
        assert result2.data.shape == t2.shape

    def test_q8_0_zero_tensor(self) -> None:
        """All-zero tensor should round-trip to (near) zero."""
        t = torch.zeros(2, 4, 16, 128, dtype=torch.float16)
        result = compress_q8_0(t)
        decompressed = decompress_q8_0(result.data)
        # t is already float16, so no extra cast is needed for the comparison.
        assert torch.allclose(decompressed, t, atol=1e-6)
|
|
|
|
| |
|
|
|
|
class TestPolarQuant:
    """PolarQuant: MSE-optimal random rotation followed by 3-bit Lloyd-Max.

    QJL is intentionally absent (D5).
    """

    @staticmethod
    def _mse(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
        """Return the mean squared difference of two tensors, computed in float32."""
        return (reference.float() - candidate.float()).pow(2).mean()

    def test_polarquant_shape_preserved(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Compressed data keeps the shape of the input keys."""
        key_tensor, _ = llama_kv_256
        assert compress_polarquant(key_tensor).data.shape == key_tensor.shape

    def test_polarquant_output_dtype(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Compressed data is stored as bfloat16."""
        key_tensor, _ = llama_kv_256
        assert compress_polarquant(key_tensor).data.dtype == torch.bfloat16

    def test_polarquant_method_tag(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """The result is tagged with CompressionMethod.POLARQUANT."""
        key_tensor, _ = llama_kv_256
        assert compress_polarquant(key_tensor).method == CompressionMethod.POLARQUANT

    def test_polarquant_metadata_qjl_disabled(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """D5: metadata must flag QJL as disabled and report 3 bits."""
        key_tensor, _ = llama_kv_256
        meta = compress_polarquant(key_tensor).metadata
        assert meta["qjl_enabled"] == "false"
        assert meta["polarquant_bits"] == "3"

    def test_polarquant_round_trip_bounded_error(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """PolarQuant 3-bit round-trip must stay below 15% relative MSE.

        The original author cites roughly 10% as the theoretical figure for
        3-bit Lloyd-Max (8 centroids) on rotated Gaussian data; the threshold
        leaves margin for rotation and dtype casting.
        """
        key_tensor, _ = llama_kv_256
        compressed = compress_polarquant(key_tensor)
        restored = decompress_polarquant(compressed.data)

        reference = key_tensor.float()
        error_power = self._mse(reference, restored)
        signal_power = reference.pow(2).mean()
        relative_mse = (error_power / signal_power).item()
        assert relative_mse < 0.15, f"PolarQuant relative MSE {relative_mse:.4f} > 15%"

    def test_polarquant_worse_than_q8_0(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """3-bit PolarQuant is expected to lose more accuracy than 8-bit Q8_0."""
        key_tensor, _ = llama_kv_256
        reference = key_tensor.float()

        # Compress with both schemes first, then decompress in the same order.
        via_q8 = compress_q8_0(key_tensor)
        via_pq = compress_polarquant(key_tensor)

        q8_mse = self._mse(reference, decompress_q8_0(via_q8.data))
        pq_mse = self._mse(reference, decompress_polarquant(via_pq.data))

        assert pq_mse > q8_mse, "PolarQuant 3-bit should be less accurate than Q8_0"

    def test_polarquant_deterministic(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Compressing the same input twice yields identical data."""
        key_tensor, _ = llama_kv_256
        first = compress_polarquant(key_tensor)
        second = compress_polarquant(key_tensor)
        assert torch.equal(first.data, second.data)

    def test_polarquant_phi3_shape(
        self, phi3_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """PolarQuant preserves shape on the Phi-3 keys fixture."""
        key_tensor, _ = phi3_kv_256
        assert compress_polarquant(key_tensor).data.shape == key_tensor.shape
|
|
|
|
| |
|
|
|
|
class TestDispatcher:
    """compress() and decompress() dispatch to correct implementations."""

    @pytest.mark.parametrize(
        "method",
        [CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.POLARQUANT],
    )
    def test_compress_dispatches(self, method: CompressionMethod) -> None:
        """compress() returns a CompressionResult tagged with the requested method."""
        t = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        result = compress(t, method)
        assert isinstance(result, CompressionResult)
        assert result.method == method

    @pytest.mark.parametrize(
        "method",
        [CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.POLARQUANT],
    )
    def test_decompress_returns_fp16(self, method: CompressionMethod) -> None:
        """decompress() yields float16 for every supported method."""
        t = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        result = compress(t, method)
        out = decompress(result.data, method)
        assert out.dtype == torch.float16

    def test_q4_0_warns_and_falls_back(self) -> None:
        """D5: Q4_0 emits warning and uses Q8_0 instead.

        Only warnings that mention Q4_0 are counted. With the "always" filter
        active, unrelated warnings raised by torch or other dependencies are
        recorded too, and must not make this test fail.
        """
        t = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            result = compress(t, CompressionMethod.Q4_0)

        q4_messages = [
            str(entry.message) for entry in caught if "Q4_0" in str(entry.message)
        ]
        assert len(q4_messages) == 1
        assert "92%" in q4_messages[0]
        assert result.method == CompressionMethod.Q8_0

    def test_unknown_method_raises(self) -> None:
        """compress() rejects an unrecognized method with ValueError."""
        t = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        with pytest.raises(ValueError, match="Unknown compression method"):
            compress(t, "invalid_method")

    def test_decompress_unknown_raises(self) -> None:
        """decompress() rejects an unrecognized method with ValueError."""
        t = torch.randn(2, 4, 16, 128, dtype=torch.float16)
        with pytest.raises(ValueError, match="Unknown compression method"):
            decompress(t, "invalid_method")
|
|
|
|
| |
|
|
|
|
class TestRoundTrip:
    """End-to-end compress -> decompress cycles driven through the dispatcher."""

    @pytest.mark.parametrize(
        "method",
        [CompressionMethod.FP16, CompressionMethod.Q8_0, CompressionMethod.POLARQUANT],
    )
    def test_round_trip_shape_preserved(self, method: CompressionMethod) -> None:
        """A random FP16 tensor keeps its shape after a full round-trip."""
        source = torch.randn(4, 8, 64, 128, dtype=torch.float16)
        compressed = compress(source, method)
        restored = decompress(compressed.data, method)
        assert restored.shape == source.shape

    def test_round_trip_both_kv(
        self, llama_kv_256: tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        """Keys and values both round-trip via FP16 and Q8_0 with shape kept and FP16 out."""
        keys, values = llama_kv_256
        methods = (CompressionMethod.FP16, CompressionMethod.Q8_0)
        # Tensor-major ordering: every method is tried on keys before values.
        cases = [(tensor, method) for tensor in (keys, values) for method in methods]
        for tensor, method in cases:
            compressed = compress(tensor, method)
            restored = decompress(compressed.data, method)
            assert restored.shape == tensor.shape
            assert restored.dtype == torch.float16
|
|