# Source: BitLinear / bitlinear/packing.py
# Uploaded by krisaujla ("Upload folder using huggingface_hub"), commit fd8c8b9 (verified)
"""
Base-3 packing utilities for memory-efficient ternary weight storage.
Ternary weights ({-1, 0, +1}) can be represented in base-3, allowing
multiple ternary values to be packed into a single byte or integer.
This provides significant memory savings over storing each value as a float32.
Theoretical packing:
- 1 ternary value requires log2(3) ≈ 1.58 bits
- 5 ternary values fit in 1 byte (3^5 = 243 < 256), i.e. 8 / 5 = 1.6 bits per value
- Compression ratio: 32 bits (float) → ~1.6 bits (packed) = 20x compression
"""
import torch
from typing import Tuple
def pack_ternary_base3(W_ternary: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, ...]]:
    """
    Pack ternary weights into base-3 representation for memory efficiency.

    Packs multiple ternary values ({-1, 0, +1}) into uint8 storage using base-3
    encoding. This achieves near-optimal compression for ternary data.

    Encoding scheme:
        -1 → 0 (base 3)
         0 → 1 (base 3)
        +1 → 2 (base 3)
    Then pack 5 base-3 digits into one byte:
        packed_byte = d0 + d1*3 + d2*9 + d3*27 + d4*81
    The largest possible byte is 2 * (1 + 3 + 9 + 27 + 81) = 242, so every
    group of 5 digits fits in a uint8.

    Args:
        W_ternary: Ternary weight tensor with values in {-1, 0, +1}.
            Shape: [out_features, in_features] or [k, out_features, in_features]
            (any shape works; the tensor is flattened in row-major order).

    Returns:
        packed: 1-D uint8 tensor of length ceil(numel / 5) (5 values per byte)
        original_shape: Shape of original tensor for unpacking

    Raises:
        ValueError: If W_ternary contains a value outside {-1, 0, +1}.

    Notes:
        - 5 ternary values per byte (3^5 = 243 < 256)
        - The flattened digits are padded with digit 0 (i.e. ternary -1) up to a
          multiple of 5; unpacking discards the padding using original_shape.
        - This is the primary memory optimization for ternary weights
    """
    original_shape = tuple(W_ternary.shape)

    # Non-ternary values would either be truncated by the uint8 cast or yield a
    # base-3 "digit" >= 3 that bleeds into the neighbouring digit, silently
    # corrupting the packed byte. Fail loudly instead.
    is_ternary = (W_ternary == -1) | (W_ternary == 0) | (W_ternary == 1)
    if not bool(is_ternary.all()):
        raise ValueError("W_ternary must contain only values in {-1, 0, +1}")

    # Map {-1, 0, 1} to base-3 digits {0, 1, 2}
    digits = (W_ternary + 1).flatten().to(torch.uint8)

    # Pad to a multiple of 5 so the digits split evenly into bytes
    numel = digits.numel()
    pad_size = (5 - numel % 5) % 5
    if pad_size > 0:
        padding = torch.zeros(pad_size, dtype=torch.uint8, device=digits.device)
        digits = torch.cat([digits, padding])

    # One row per output byte: d0 + d1*3 + d2*9 + d3*27 + d4*81
    groups = digits.view(-1, 5)
    powers_of_3 = torch.tensor([1, 3, 9, 27, 81], dtype=torch.uint8, device=digits.device)
    # torch.sum promotes integer inputs to int64 by default; cast back so the
    # result really is one byte per 5 weights (max value 242 fits in uint8).
    packed = (groups * powers_of_3).sum(dim=1).to(torch.uint8)
    return packed, original_shape
def unpack_ternary_base3(
    packed: torch.Tensor,
    original_shape: Tuple[int, ...],
) -> torch.Tensor:
    """
    Unpack base-3 encoded ternary weights back to full representation.

    Reverses the base-3 packing (5 digits per byte, least-significant digit
    first) to recover ternary weights.

    Args:
        packed: Packed uint8 tensor (5 values per byte). It is flattened in
            row-major order before decoding, so any shape is accepted.
        original_shape: Original shape of the ternary tensor. Decoded digits
            beyond prod(original_shape) are treated as padding and dropped.

    Returns:
        W_ternary: float32 tensor of shape original_shape with values in {-1, 0, +1}

    Raises:
        ValueError: If packed encodes fewer than prod(original_shape) values.
    """
    # Flatten first: stacking digits on dim=1 of a multi-dimensional tensor
    # would interleave bytes and decode values in the wrong order.
    flat = packed.reshape(-1)

    # Extract the 5 base-3 digits of each byte, least significant first
    digits = torch.stack([(flat // power) % 3 for power in (1, 3, 9, 27, 81)], dim=1).flatten()

    # Number of real (non-padding) values; an empty shape () means one scalar
    numel = torch.Size(original_shape).numel()
    if numel > digits.numel():
        raise ValueError(
            f"packed holds {digits.numel()} values but original_shape "
            f"{tuple(original_shape)} needs {numel}"
        )

    # Drop padding, then map {0, 1, 2} back to {-1, 0, +1}
    W_ternary = digits[:numel].to(torch.float32) - 1
    return W_ternary.view(original_shape)
def compute_compression_ratio(
    original_size: int,
    packed_size: int,
) -> float:
    """
    Return how many times smaller the packed weights are than the originals.

    Args:
        original_size: Size in bytes of original float32 weights
        packed_size: Size in bytes of packed ternary weights

    Returns:
        original_size / packed_size (e.g. 20.0 means 20x compression), or 0.0
        when packed_size is not positive, so callers never divide by zero.

    Examples:
        >>> # 512 x 512 float32 weights = 512*512*4 bytes = 1,048,576 bytes
        >>> # Packed: 512*512 ternary values / 5 per byte ≈ 52,429 bytes
        >>> ratio = compute_compression_ratio(1048576, 52429)
        >>> print(f"Compression: {ratio:.1f}x")
        Compression: 20.0x
    """
    if packed_size > 0:
        return original_size / packed_size
    # Degenerate packed size: report "no meaningful ratio" instead of raising
    return 0.0
def estimate_memory_savings(
    in_features: int,
    out_features: int,
    num_layers: int = 1,
) -> dict:
    """
    Estimate memory savings from ternary packing for a given layer configuration.

    The packed size is rounded up to whole bytes per layer (as if each layer
    were packed separately) before being scaled by ``num_layers``.

    Args:
        in_features: Input dimension (must be non-negative)
        out_features: Output dimension (must be non-negative)
        num_layers: Number of layers (for cumulative savings; must be non-negative)

    Returns:
        Dictionary with memory statistics:
            - float32_bytes: Memory for float32 weights
            - packed_bytes: Memory for packed ternary weights
            - savings_bytes: Absolute memory saved
            - compression_ratio: Ratio of compression (0.0 if nothing is packed)

    Raises:
        ValueError: If any argument is negative.

    Examples:
        >>> stats = estimate_memory_savings(768, 3072, num_layers=12)
        >>> print(f"Total savings: {stats['savings_bytes'] / 1e6:.1f} MB")
        Total savings: 107.6 MB
    """
    # Negative sizes would silently yield negative byte counts
    for arg_name, arg_value in (
        ("in_features", in_features),
        ("out_features", out_features),
        ("num_layers", num_layers),
    ):
        if arg_value < 0:
            raise ValueError(f"{arg_name} must be non-negative, got {arg_value}")

    # Calculate float32 weight size
    weights_per_layer = in_features * out_features
    float32_bytes_per_layer = weights_per_layer * 4  # 4 bytes per float32
    # Packed size: 5 ternary values per byte, rounded up per layer
    packed_bytes_per_layer = (weights_per_layer + 4) // 5  # ceiling division

    # Scale by number of layers
    float32_bytes = float32_bytes_per_layer * num_layers
    packed_bytes = packed_bytes_per_layer * num_layers

    # Calculate savings
    savings_bytes = float32_bytes - packed_bytes
    compression_ratio = compute_compression_ratio(float32_bytes, packed_bytes)
    return {
        'float32_bytes': float32_bytes,
        'packed_bytes': packed_bytes,
        'savings_bytes': savings_bytes,
        'compression_ratio': compression_ratio,
    }
# Advanced packing schemes (future optimizations, to be implemented later)
def pack_ternary_bitwise(W_ternary: torch.Tensor) -> torch.Tensor:
    """
    Alternative packing that spends 2 bits on each ternary value (not yet implemented).

    Planned encoding:
        -1 → 00
         0 → 01
        +1 → 10

    Four values fit in one byte, i.e. 2 bits per value instead of the optimal
    log2(3) ≈ 1.58 bits. This is simpler to implement than base-3 packing,
    but less compact (base-3 packing stores 1.6 bits per value, about 20%
    less than this scheme).

    TODO:
        - Implement 2-bit packing scheme
        - Compare with base-3 for speed vs. compression trade-off

    Raises:
        NotImplementedError: Always, until the scheme is implemented.
    """
    # Placeholder: the 2-bit scheme is a planned future optimization.
    message = "pack_ternary_bitwise not yet implemented"
    raise NotImplementedError(message)
def unpack_ternary_bitwise(packed: torch.Tensor, original_shape: Tuple[int, ...]) -> torch.Tensor:
    """
    Unpack 2-bit encoded ternary weights (not yet implemented).

    Intended as the inverse of ``pack_ternary_bitwise``: decode four 2-bit
    values per byte and reshape the result to ``original_shape``.

    TODO:
        - Implement bitwise unpacking

    Raises:
        NotImplementedError: Always, until the scheme is implemented.
    """
    # Placeholder: the 2-bit scheme is a planned future optimization.
    message = "unpack_ternary_bitwise not yet implemented"
    raise NotImplementedError(message)