"""
Base-3 packing utilities for memory-efficient ternary weight storage.

Ternary weights ({-1, 0, +1}) can be represented in base-3, allowing
multiple ternary values to be packed into a single byte or integer.
This provides significant memory savings over storing each value as a float32.

Theoretical packing:
- 1 ternary value requires log2(3) ≈ 1.58 bits
- 5 ternary values fit in 1 byte (3^5 = 243 < 256)
- Compression ratio: 32 bits (float32) → ~1.6 bits (packed) = 20x compression
"""

import math

import torch
from typing import Tuple


def pack_ternary_base3(W_ternary: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
    """
    Pack ternary weights into base-3 representation for memory efficiency.

    Packs multiple ternary values ({-1, 0, +1}) into uint8 storage using base-3
    encoding. This achieves near-optimal compression for ternary data.

    Encoding scheme:
        -1 → 0 (base 3)
         0 → 1 (base 3)
        +1 → 2 (base 3)

    Then pack 5 base-3 digits into one byte:
        packed_byte = d0 + d1*3 + d2*9 + d3*27 + d4*81

    Args:
        W_ternary: Ternary weight tensor with values in {-1, 0, +1}.
            Shape: [out_features, in_features] or [k, out_features, in_features]

    Returns:
        packed: Packed weights as a uint8 tensor (5 values per byte)
        original_shape: Shape of the original tensor, needed for unpacking

    Notes:
        - 5 ternary values per byte (3^5 = 243 < 256)
        - Input is zero-padded when the element count is not divisible by 5
        - This is the primary memory optimization for ternary weights
    """
    original_shape = tuple(W_ternary.shape)

    # Shift {-1, 0, +1} to base-3 digits {0, 1, 2}.
    base3 = (W_ternary + 1).flatten().to(torch.uint8)

    # Zero-pad so the flat length is a multiple of 5; the padding digits are
    # sliced off again during unpacking.
    numel = base3.numel()
    pad_size = (5 - numel % 5) % 5
    if pad_size > 0:
        base3 = torch.cat([base3, torch.zeros(pad_size, dtype=torch.uint8, device=base3.device)])

    # Group into rows of 5 digits and combine them positionally:
    # packed_byte = d0 + d1*3 + d2*9 + d3*27 + d4*81 (max 242, fits in uint8).
    base3 = base3.view(-1, 5)
    powers_of_3 = torch.tensor([1, 3, 9, 27, 81], dtype=torch.uint8, device=base3.device)
    # Integer .sum() promotes to int64, so cast back to uint8 for compact storage.
    packed = (base3 * powers_of_3).sum(dim=1).to(torch.uint8)

    return packed, original_shape


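# Illustrative usage sketch (hypothetical helper, not part of the module API):
# packs a small random ternary matrix and checks the expected packed size.
def _example_pack_usage() -> None:
    # 64*64 = 4096 ternary values -> ceil(4096 / 5) = 820 packed bytes.
    W = torch.randint(-1, 2, (64, 64)).to(torch.float32)
    packed, shape = pack_ternary_base3(W)
    assert packed.dtype == torch.uint8
    assert packed.numel() == 820
    assert shape == (64, 64)

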
def unpack_ternary_base3(
    packed: torch.Tensor,
    original_shape: Tuple[int, ...],
) -> torch.Tensor:
    """
    Unpack base-3 encoded ternary weights back to the full representation.

    Reverses the packing operation to recover the ternary weights exactly.

    Args:
        packed: Packed uint8 tensor (5 values per byte)
        original_shape: Original shape of the ternary tensor

    Returns:
        W_ternary: Ternary weight tensor with values in {-1, 0, +1}
    """
    # Extract the five base-3 digits from each byte.
    d0 = packed % 3
    d1 = (packed // 3) % 3
    d2 = (packed // 9) % 3
    d3 = (packed // 27) % 3
    d4 = (packed // 81) % 3

    # Interleave back to the original flat digit order: [N, 5] -> [5N].
    base3 = torch.stack([d0, d1, d2, d3, d4], dim=1).flatten()

    # Drop any padding digits added during packing.
    numel = math.prod(original_shape)
    base3 = base3[:numel]

    # Map base-3 digits {0, 1, 2} back to ternary values {-1, 0, +1}.
    W_ternary = base3.to(torch.float32) - 1

    return W_ternary.view(original_shape)


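# Illustrative round-trip check (hypothetical helper, not part of the module
# API): packing followed by unpacking reproduces the ternary tensor exactly,
# including shapes whose element count is not divisible by 5.
def _example_roundtrip() -> None:
    W = torch.randint(-1, 2, (7, 13)).to(torch.float32)  # 91 values, padded to 95
    packed, shape = pack_ternary_base3(W)
    assert torch.equal(unpack_ternary_base3(packed, shape), W)

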
def compute_compression_ratio(
    original_size: int,
    packed_size: int,
) -> float:
    """
    Compute the compression ratio for packed ternary weights.

    Args:
        original_size: Size in bytes of the original float32 weights
        packed_size: Size in bytes of the packed ternary weights

    Returns:
        Compression ratio (e.g., 20.0 means 20x compression)

    Examples:
        >>> # 512 x 512 float32 weights = 512*512*4 bytes = 1,048,576 bytes
        >>> # Packed: 512*512 ternary values / 5 per byte ≈ 52,429 bytes
        >>> ratio = compute_compression_ratio(1048576, 52429)
        >>> print(f"Compression: {ratio:.1f}x")
        Compression: 20.0x
    """
    return original_size / packed_size if packed_size > 0 else 0.0


def estimate_memory_savings(
    in_features: int,
    out_features: int,
    num_layers: int = 1,
) -> dict:
    """
    Estimate memory savings from ternary packing for a given layer configuration.

    Args:
        in_features: Input dimension
        out_features: Output dimension
        num_layers: Number of layers (for cumulative savings)

    Returns:
        Dictionary with memory statistics:
        - float32_bytes: Memory for float32 weights
        - packed_bytes: Memory for packed ternary weights
        - savings_bytes: Absolute memory saved
        - compression_ratio: Ratio of compression

    Examples:
        >>> stats = estimate_memory_savings(768, 3072, num_layers=12)
        >>> print(f"Total savings: {stats['savings_bytes'] / 1e6:.1f} MB")
    """
    weights_per_layer = in_features * out_features
    float32_bytes_per_layer = weights_per_layer * 4  # 4 bytes per float32

    # Ceiling division: 5 ternary values per packed byte.
    packed_bytes_per_layer = (weights_per_layer + 4) // 5

    float32_bytes = float32_bytes_per_layer * num_layers
    packed_bytes = packed_bytes_per_layer * num_layers

    savings_bytes = float32_bytes - packed_bytes
    compression_ratio = compute_compression_ratio(float32_bytes, packed_bytes)

    return {
        'float32_bytes': float32_bytes,
        'packed_bytes': packed_bytes,
        'savings_bytes': savings_bytes,
        'compression_ratio': compression_ratio,
    }


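# Hypothetical sanity check for the docstring example above: a 768 -> 3072
# projection repeated over 12 layers packs ~113.2 MB of float32 weights into
# ~5.7 MB, a compression ratio of roughly 20x.
def _example_savings() -> None:
    stats = estimate_memory_savings(768, 3072, num_layers=12)
    assert stats['float32_bytes'] == 12 * 768 * 3072 * 4  # 113,246,208 bytes
    assert abs(stats['compression_ratio'] - 20.0) < 0.01

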
def pack_ternary_bitwise(W_ternary: torch.Tensor) -> torch.Tensor:
    """
    Alternative packing using 2 bits per ternary value.

    Simpler but less efficient than base-3 packing:
        -1 → 00
         0 → 01
        +1 → 10

    This uses 2 bits per value (4 values per byte) instead of the optimal
    ~1.58 bits. Easier to implement, but it stores 20% fewer values per
    byte (4 vs. 5) than base-3 packing.

    TODO:
    - Implement the 2-bit packing scheme (see the sketch below)
    - Compare with base-3 for the speed vs. compression trade-off
    """
    raise NotImplementedError("pack_ternary_bitwise not yet implemented")


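# A minimal sketch of the 2-bit scheme described above, under the stated
# -1 -> 00, 0 -> 01, +1 -> 10 encoding with 4 values per byte, packed low
# bits first. One possible implementation, not the final API; the helper
# name _pack_ternary_bitwise_sketch is made up for illustration.
def _pack_ternary_bitwise_sketch(W_ternary: torch.Tensor) -> torch.Tensor:
    # Shift {-1, 0, +1} to 2-bit codes {0, 1, 2}.
    vals = (W_ternary + 1).flatten().to(torch.uint8)
    # Zero-pad to a multiple of 4 values (4 values per byte at 2 bits each).
    pad = (4 - vals.numel() % 4) % 4
    if pad > 0:
        vals = torch.cat([vals, torch.zeros(pad, dtype=torch.uint8, device=vals.device)])
    # Placing each code at bit offsets 0, 2, 4, 6 is multiplication by 4^i.
    place = torch.tensor([1, 4, 16, 64], dtype=torch.uint8, device=vals.device)
    return (vals.view(-1, 4) * place).sum(dim=1).to(torch.uint8)

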
def unpack_ternary_bitwise(packed: torch.Tensor, original_shape: Tuple[int, ...]) -> torch.Tensor:
    """
    Unpack 2-bit encoded ternary weights.

    TODO:
    - Implement bitwise unpacking (see the sketch below)
    """
    raise NotImplementedError("unpack_ternary_bitwise not yet implemented")
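

# Matching unpack sketch for the 2-bit scheme (same assumptions as
# _pack_ternary_bitwise_sketch above): extract the four 2-bit fields of each
# byte, drop padding, and map codes {0, 1, 2} back to {-1, 0, +1}.
def _unpack_ternary_bitwise_sketch(packed: torch.Tensor, original_shape: Tuple[int, ...]) -> torch.Tensor:
    # (packed // 4**i) % 4 reads the 2-bit field at bit offset 2*i.
    digits = torch.stack([(packed // s) % 4 for s in (1, 4, 16, 64)], dim=1).flatten()
    numel = math.prod(original_shape)
    return (digits[:numel].to(torch.float32) - 1).view(original_shape)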