"""
Unit tests for quantization utilities.
These tests validate the ternary quantization, scaling, and packing utilities. The test cases are:
TestAbsmaxScale (3 tests)
1. test_global_scale - Tests global absmax scaling computation
2. test_per_channel_scale - Tests per-channel (per-row) absmax scaling
3. test_zero_tensor - Validates behavior with zero tensors (numerical stability)
TestTernaryQuantize (3 tests)
1. test_quantization_values - Ensures output contains only {-1, 0, +1}
2. test_sign_preservation - Validates sign preservation for large values
3. test_threshold_behavior - Tests threshold-based zero assignment
TestWeightToTernary (3 tests)
1. test_output_shapes - Verifies correct output tensor shapes
2. test_per_channel_vs_global - Tests per-channel vs. global scaling modes
3. test_reconstruction_quality - Validates reconstruction error is reasonable
TestActivationQuantization (2 tests)
1. test_quantization_range - Tests 8-bit quantization range
2. test_per_token_scaling - Validates per-token vs. global scaling
TestDequantization (1 test)
1. test_dequantize_inverse - Tests the quantize → dequantize inverse operation
TestBase3Packing (3 tests)
1. test_pack_unpack_roundtrip - Validates that pack → unpack recovers the original
2. test_memory_efficiency - Tests that packing achieves ~20x compression
3. test_packing_with_padding - Tests padding for non-multiple-of-5 dimensions
TestCompressionUtilities (2 tests)
1. test_compression_ratio_calculation - Tests compression ratio computation
2. test_memory_savings_estimation - Validates memory savings estimation
TestQuantizationIntegration (2 tests)
1. test_full_quantization_pipeline - Tests dense → ternary → packed → unpacked
2. test_quantization_preserves_functionality - Validates quantized layer outputs
"""
import pytest
import torch
from bitlinear.quantization import (
absmax_scale,
ternary_quantize,
weight_to_ternary,
quantize_activations_absmax,
dequantize_scale,
)
from bitlinear.packing import (
pack_ternary_base3,
unpack_ternary_base3,
compute_compression_ratio,
estimate_memory_savings,
)
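# Illustrative reference for the scheme these tests exercise: absmax scaling
# (global or per output row) followed by threshold-based ternarization to
# {-1, 0, +1}. This is a sketch, not the bitlinear implementation; the
# 0.5 * scale threshold (inferred from the sign-preservation test below) and
# the 1e-5 epsilon are assumptions.
def _reference_weight_to_ternary(W, per_channel=True, threshold=0.5, eps=1e-5):
    """Sketch only; see bitlinear.quantization for the actual implementation."""
    if per_channel:
        gamma = W.abs().amax(dim=1).clamp(min=eps)  # one scale per output row
        scale = gamma.unsqueeze(1)                  # broadcast over columns
    else:
        gamma = W.abs().max().clamp(min=eps)        # single global scale
        scale = gamma
    # Entries within +/- threshold * scale collapse to 0; the rest keep their sign.
    W_ternary = torch.where(W.abs() > threshold * scale, W.sign(), torch.zeros_like(W))
    return W_ternary, gamma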
class TestAbsmaxScale:
"""Tests for absmax_scale function."""
def test_global_scale(self):
"""Test global absmax scaling."""
W = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]])
scale = absmax_scale(W, dim=None)
assert torch.isclose(scale, torch.tensor(6.0))
def test_per_channel_scale(self):
"""Test per-channel (per-row) absmax scaling."""
W = torch.tensor([[1.0, -2.0, 3.0], [4.0, -5.0, 6.0]])
scale = absmax_scale(W, dim=1)
expected = torch.tensor([3.0, 6.0])
assert torch.allclose(scale, expected)
def test_zero_tensor(self):
"""Test behavior with zero tensor."""
W = torch.zeros(10, 10)
scale = absmax_scale(W, dim=None)
# Should handle division by zero gracefully (clamped to epsilon)
assert scale > 0
assert scale < 1e-4
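# Note: the tests above pin down only the observable behaviour of absmax_scale;
# a plausible reading (an assumption, not the library contract) is
# W.abs().max() for dim=None and W.abs().amax(dim=dim) otherwise, with the
# result clamped to a small positive epsilon so all-zero tensors never divide by zero.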
class TestTernaryQuantize:
"""Tests for ternary_quantize function."""
def test_quantization_values(self):
"""Test that output contains only {-1, 0, +1}."""
W = torch.randn(100, 100)
W_ternary = ternary_quantize(W)
unique_values = torch.unique(W_ternary)
assert set(unique_values.tolist()).issubset({-1.0, 0.0, 1.0})
def test_sign_preservation(self):
"""Test that signs are preserved correctly."""
# Use values well above threshold (> 0.5 * max)
W = torch.tensor([[10.0, -10.0, 0.01], [-8.0, 8.0, -0.01]])
W_ternary = ternary_quantize(W)
# Large positive values should be +1
assert W_ternary[0, 0] == 1.0
# Large negative values should be -1
assert W_ternary[0, 1] == -1.0
assert W_ternary[1, 0] == -1.0
# Large positive value should also be +1
assert W_ternary[1, 1] == 1.0
def test_threshold_behavior(self):
"""Test that threshold determines zero assignment."""
# Create tensor with known values
W = torch.tensor([[10.0, 0.1, -10.0], [0.2, -0.2, 5.0]])
W_ternary = ternary_quantize(W)
# Small values near zero should become 0
# Exact behavior depends on threshold, but there should be some zeros
assert 0.0 in W_ternary
class TestWeightToTernary:
"""Tests for weight_to_ternary function."""
def test_output_shapes(self):
"""Test that output shapes are correct."""
W = torch.randn(512, 768)
W_ternary, gamma = weight_to_ternary(W, per_channel=True)
assert W_ternary.shape == (512, 768)
assert gamma.shape == (512,)
def test_per_channel_vs_global(self):
"""Test difference between per-channel and global scaling."""
W = torch.randn(512, 768)
W_t_pc, gamma_pc = weight_to_ternary(W, per_channel=True)
W_t_g, gamma_g = weight_to_ternary(W, per_channel=False)
assert gamma_pc.shape == (512,)
assert gamma_g.shape == torch.Size([]) # Scalar
def test_reconstruction_quality(self):
"""Test that reconstruction W_ternary * gamma approximates W."""
W = torch.randn(512, 768)
W_ternary, gamma = weight_to_ternary(W, per_channel=True)
W_reconstructed = W_ternary * gamma.unsqueeze(1)
error = torch.norm(W - W_reconstructed) / torch.norm(W)
# Ternary quantization of dense weights has substantial inherent error, but the
# relative error should stay below 1.0, i.e. the reconstruction must beat an
# all-zero baseline
assert error < 1.0
class TestActivationQuantization:
"""Tests for activation quantization."""
def test_quantization_range(self):
"""Test that quantized values are in expected range."""
x = torch.randn(16, 32, 512)
x_quant = quantize_activations_absmax(x, bits=8, per_token=True)
# Quantized values should stay in roughly the same range as the input
assert x_quant.abs().max() <= x.abs().max() * 1.1
def test_per_token_scaling(self):
"""Test per-token vs. global scaling."""
x = torch.randn(16, 32, 512)
x_quant_per_token = quantize_activations_absmax(x, bits=8, per_token=True)
x_quant_global = quantize_activations_absmax(x, bits=8, per_token=False)
# Both should work without errors
assert x_quant_per_token.shape == x.shape
assert x_quant_global.shape == x.shape
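# Illustrative sketch of the 8-bit absmax activation quantization exercised
# above (scale, round, clamp, rescale back, i.e. fake quantization). The
# per-token reduction over the last dimension and the symmetric [-127, 127]
# range are assumptions, not necessarily what quantize_activations_absmax does.
def _reference_quantize_activations(x, bits=8, per_token=True, eps=1e-5):
    """Sketch only; returns fake-quantized activations in the input's scale."""
    qmax = 2 ** (bits - 1) - 1                       # 127 for 8 bits
    if per_token:
        absmax = x.abs().amax(dim=-1, keepdim=True)  # one scale per token
    else:
        absmax = x.abs().max()                       # single global scale
    scale = qmax / absmax.clamp(min=eps)
    return torch.round(x * scale).clamp(-qmax, qmax) / scale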
class TestDequantization:
"""Tests for dequantization."""
def test_dequantize_inverse(self):
"""Test that quantize β†’ dequantize is approximately identity."""
W = torch.randn(512, 768)
W_quant, scale = weight_to_ternary(W, per_channel=True)
W_dequant = dequantize_scale(W_quant, scale)
# Should be close to W_quant * scale reconstruction
W_expected = W_quant * scale.unsqueeze(1)
assert torch.allclose(W_dequant, W_expected)
class TestBase3Packing:
"""Tests for base-3 packing utilities."""
def test_pack_unpack_roundtrip(self):
"""Test that pack β†’ unpack recovers original ternary weights."""
W_ternary = torch.randint(-1, 2, (512, 768)).float()
packed, shape = pack_ternary_base3(W_ternary)
W_unpacked = unpack_ternary_base3(packed, shape)
assert torch.allclose(W_ternary, W_unpacked)
def test_memory_efficiency(self):
"""Test that packing achieves expected compression."""
W_ternary = torch.randint(-1, 2, (512, 768)).float()
original_size = W_ternary.numel() * 4 # float32 = 4 bytes
packed, shape = pack_ternary_base3(W_ternary)
packed_size = packed.numel() * 1 # uint8 = 1 byte
compression = original_size / packed_size
# Should achieve ~20x compression (32 bits → 1.6 bits)
assert compression > 15 # Allow some overhead
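# Arithmetic behind the ~20x target: float32 spends 32 bits per weight, while
# base-3 packing stores 5 ternary weights per byte (8 / 5 = 1.6 bits each),
# and 32 / 1.6 = 20; the looser > 15 bound leaves room for padding overhead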
def test_packing_with_padding(self):
"""Test packing when dimensions are not multiples of 5."""
# Test with various sizes to ensure padding is handled correctly
for size in [(13, 17), (100, 203), (7, 11)]:
W_ternary = torch.randint(-1, 2, size).float()
packed, shape = pack_ternary_base3(W_ternary)
W_unpacked = unpack_ternary_base3(packed, shape)
assert torch.allclose(W_ternary, W_unpacked)
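# Illustrative base-3 packing: five ternary digits ("trits") fit in one uint8
# because 3 ** 5 = 243 <= 256, which is where the ~20x compression over float32
# comes from. The digit order, zero padding, and (packed, shape) return
# convention below mirror the tests but are assumptions about, not copies of,
# the bitlinear.packing implementation.
def _reference_pack_ternary_base3(W_ternary):
    """Sketch only: pack {-1, 0, +1} values, five per byte."""
    digits = (W_ternary.flatten() + 1).to(torch.uint8)              # map to {0, 1, 2}
    pad = (-digits.numel()) % 5                                      # pad up to a multiple of 5
    digits = torch.cat([digits, digits.new_zeros(pad)]).view(-1, 5).to(torch.int64)
    weights = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int64)     # powers of 3
    packed = (digits * weights).sum(dim=1).to(torch.uint8)           # max value 242 fits in a byte
    return packed, W_ternary.shape
def _reference_unpack_ternary_base3(packed, shape):
    """Sketch only: recover ternary values from packed bytes."""
    weights = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int64)
    digits = (packed.to(torch.int64).unsqueeze(1) // weights) % 3    # base-3 digits
    values = digits.flatten().float() - 1.0                          # back to {-1, 0, +1}
    return values[: torch.Size(shape).numel()].view(shape)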
class TestCompressionUtilities:
"""Tests for compression ratio and memory estimation utilities."""
def test_compression_ratio_calculation(self):
"""Test compression ratio calculation."""
ratio = compute_compression_ratio(1024, 51)
assert abs(ratio - 20.0) < 0.5
def test_memory_savings_estimation(self):
"""Test memory savings estimation for layer."""
stats = estimate_memory_savings(768, 3072, num_layers=12)
assert 'float32_bytes' in stats
assert 'packed_bytes' in stats
assert 'savings_bytes' in stats
assert 'compression_ratio' in stats
assert stats['compression_ratio'] > 15
class TestQuantizationIntegration:
"""Integration tests for quantization pipeline."""
def test_full_quantization_pipeline(self):
"""Test complete pipeline: dense β†’ ternary β†’ packed β†’ unpacked."""
# 1. Start with dense weights
W = torch.randn(128, 256)
# 2. Quantize to ternary
W_ternary, gamma = weight_to_ternary(W, per_channel=True)
# 3. Pack to base-3
packed, shape = pack_ternary_base3(W_ternary)
# 4. Unpack
W_unpacked = unpack_ternary_base3(packed, shape)
# 5. Verify correctness
assert torch.allclose(W_ternary, W_unpacked)
assert set(W_unpacked.unique().tolist()).issubset({-1.0, 0.0, 1.0})
def test_quantization_preserves_functionality(self):
"""Test that quantized layer produces reasonable outputs."""
from bitlinear import BitLinear
import torch.nn as nn
# Create dense layer
dense = nn.Linear(256, 128)
# Test input
x = torch.randn(16, 256)
out_dense = dense(x)
# Quantize to BitLinear
bitlinear = BitLinear.from_linear(dense)
out_quantized = bitlinear(x)
# Outputs should have same shape
assert out_dense.shape == out_quantized.shape
# Outputs should be correlated (similar but not identical)
# Calculate correlation
correlation = torch.corrcoef(torch.stack([out_dense.flatten(), out_quantized.flatten()]))[0, 1]
assert correlation > 0.5 # Should have reasonable correlation