Model Card: BitLinear

Model Description

BitLinear is a PyTorch implementation of ultra-low-precision (1.58-bit) ternary linear layers that can serve as drop-in replacements for nn.Linear in neural networks, particularly Transformers. It achieves ~19x memory compression while maintaining high output similarity.

Intended Use

Primary Use Cases

  • Edge Deployment: Deploying large models on memory-constrained devices
  • Production Inference: Reducing memory footprint for serving large language models
  • Research: Exploring ultra-low-precision neural networks
  • Cost Optimization: Reducing cloud infrastructure costs through memory savings

Out-of-Scope Use Cases

  • Training from scratch (requires quantization-aware training)
  • Applications requiring exact numerical precision
  • Real-time applications where Python overhead is prohibitive (use C++/CUDA extensions)

How to Use

Basic Usage

import torch
from bitlinear import BitLinear

# Create a BitLinear layer (same interface as nn.Linear)
layer = BitLinear(in_features=512, out_features=1024, bias=True)

# Forward pass
x = torch.randn(32, 128, 512)
output = layer(x)  # same interface as nn.Linear; output shape (32, 128, 1024)

Converting Existing Models

import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

# Convert a pre-trained model
model = nn.TransformerEncoderLayer(d_model=512, nhead=8)
model_compressed = convert_linear_to_bitlinear(model, inplace=False)

# Use as normal
x = torch.randn(10, 32, 512)
output = model_compressed(x)

Multi-Ternary for Better Accuracy

from bitlinear import MultiTernaryLinear

# Use k=3 components for 75% error reduction
layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3)
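
Each additional component stores its own packed ternary matrix and scales, so memory grows roughly linearly with k: the compression ratio at k=3 is roughly a third of the k=1 figure, in exchange for the 75% error reduction.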

Performance

Memory Compression

  • Average Compression: 19.23x (~96% of the theoretical 20x)
  • GPT-2 Small Example: 324 MB → 16.8 MB (307 MB saved)

Layer Size   nn.Linear   BitLinear (Packed)   Compression
512×512      1.00 MB     0.05 MB              18.6x
1024×1024    4.00 MB     0.21 MB              19.3x
4096×4096    64.02 MB    3.23 MB              19.8x
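
The theoretical 20x follows directly from the packing format: FP32 uses 32 bits per weight, while base-3 packing stores 5 ternary values per byte, i.e. 1.6 bits per weight; per-channel scales and padding account for most of the remaining gap. A back-of-envelope check (illustrative arithmetic only; the actual on-disk layout may differ):

# Rough packed size: 5 ternary weights per byte, plus one FP32 scale per output channel
def packed_bytes(out_f, in_f):
    ternary = (out_f * in_f + 4) // 5   # ceil(num_weights / 5) bytes
    scales = out_f * 4                  # assumed FP32 scale per output channel
    return ternary + scales

fp32_bytes = 4096 * 4096 * 4            # dense FP32 weight matrix: 64 MB
print(fp32_bytes / packed_bytes(4096, 4096))  # ~19.9x, near the 19.8x measured above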

Accuracy

  • Cosine Similarity: > 0.96 (96%+)
  • Relative Error: ~0.28 (28%)
  • Multi-Ternary (k=3): 75% error reduction vs k=1
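
The figures above can be sanity-checked at the layer level with the documented converter; the snippet below is a minimal sketch (it assumes conversion quantizes the layer's existing weights, which is the converter's intended use for pre-trained models, and exact values vary with random initialization):

import torch
import torch.nn as nn
import torch.nn.functional as F
from bitlinear import convert_linear_to_bitlinear

# Compare a full-precision layer against its BitLinear conversion
model = nn.Sequential(nn.Linear(512, 1024))
model_bit = convert_linear_to_bitlinear(model, inplace=False)

x = torch.randn(256, 512)
y_fp, y_bit = model(x), model_bit(x)

cos = F.cosine_similarity(y_fp.flatten(), y_bit.flatten(), dim=0).item()
rel = ((y_fp - y_bit).norm() / y_fp.norm()).item()
print(f"cosine similarity: {cos:.4f}, relative error: {rel:.4f}")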

Limitations

Known Limitations

  1. Accuracy Trade-off: Ternary quantization introduces approximation error (~3-5% typical)
  2. Training: Requires quantization-aware training (QAT) for optimal results
  3. Speed: Python implementation may be slower than nn.Linear (use C++/CUDA for production)
  4. Activation Quantization: Currently only weights are quantized (full BitNet includes activation quantization)

Recommendations

  • Fine-tune converted models for best accuracy
  • Use k≥2 for MultiTernaryLinear when accuracy is critical
  • Profile performance on your specific hardware (a simple timing harness is sketched after this list)
  • Test accuracy on your specific task before deployment
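
To act on the profiling recommendation, here is a minimal timing harness (illustrative CPU wall-clock timing, not a rigorous benchmark; on GPU, call torch.cuda.synchronize() before reading the clock):

import time
import torch
import torch.nn as nn
from bitlinear import BitLinear

def bench(layer, x, iters=100):
    # Average wall-clock time per forward pass, after one warm-up call
    with torch.no_grad():
        layer(x)
        t0 = time.perf_counter()
        for _ in range(iters):
            layer(x)
    return (time.perf_counter() - t0) / iters

x = torch.randn(32, 512)
print("nn.Linear:", bench(nn.Linear(512, 1024), x))
print("BitLinear:", bench(BitLinear(in_features=512, out_features=1024), x))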

Training

Quantization-Aware Training (QAT)

For best results, fine-tune models with BitLinear layers:

import torch
from bitlinear import convert_linear_to_bitlinear

# Convert a pre-trained model (pretrained_model is your existing nn.Module)
model_bit = convert_linear_to_bitlinear(pretrained_model)

# Fine-tune with a standard training loop
optimizer = torch.optim.AdamW(model_bit.parameters(), lr=1e-4)
# ... train as normal ...
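
Gradients typically reach the underlying full-precision weights through a straight-through estimator (sketched under From Scratch Training below), which is why a standard optimizer such as AdamW can be used unchanged.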

From Scratch Training

Training from scratch with ternary weights requires:

  • Careful initialization
  • Straight-through estimators for gradients
  • Potentially modified learning rates

See read/IMPLEMENTATION_GUIDE.md for details.
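
As an illustration of the straight-through estimator point, here is a minimal, generic sketch (mean-absolute scaling and an identity gradient; not the library's internal implementation):

import torch

class TernarySTE(torch.autograd.Function):
    # Forward: quantize weights to {-scale, 0, +scale}.
    # Backward: pass gradients through unchanged (straight-through).
    @staticmethod
    def forward(ctx, w):
        scale = w.abs().mean() + 1e-8          # one simple scale choice (illustrative)
        return torch.round(w / scale).clamp(-1, 1) * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output                     # identity gradient

# In a layer's forward pass, quantize the full-precision latent weight:
# w_q = TernarySTE.apply(self.weight)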

Technical Specifications

Architecture

  • Weight Quantization: Ternary {-1, 0, +1}
  • Scaling: Per-output-channel absmax scaling
  • Packing: Base-3 encoding (5 values per byte; sketched after this list)
  • Decomposition: Greedy residual quantization for multi-ternary
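
A minimal sketch of the quantization and packing steps above (illustrative only; the library's exact thresholds and memory layout may differ):

import torch

def ternarize_absmax(W):
    # Per-output-channel absmax scaling: W ≈ diag(scale) @ T, with T in {-1, 0, +1}
    scale = W.abs().amax(dim=1, keepdim=True) + 1e-8
    T = torch.round(W / scale).clamp(-1, 1)   # one simple rounding rule (illustrative)
    return T.to(torch.int8), scale.squeeze(1)

def pack_base3(T):
    # Base-3 packing: 5 ternary digits per byte, since 3**5 = 243 <= 256
    digits = T.flatten().to(torch.int32) + 1                  # map {-1,0,1} -> {0,1,2}
    pad = (-digits.numel()) % 5
    digits = torch.cat([digits, digits.new_zeros(pad)]).view(-1, 5)
    powers = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int32)
    return (digits * powers).sum(dim=1).to(torch.uint8)

def unpack_base3(packed, n):
    # Inverse of pack_base3: recover the first n ternary values
    b = packed.to(torch.int32)
    digits = torch.stack([(b // p) % 3 for p in (1, 3, 9, 27, 81)], dim=1)
    return (digits.flatten()[:n] - 1).to(torch.int8)          # map back to {-1,0,1}

W = torch.randn(1024, 512)
T, scale = ternarize_absmax(W)
assert torch.equal(unpack_base3(pack_base3(T), T.numel()), T.flatten())

For the multi-ternary case, the greedy residual decomposition applies the same ternarization repeatedly to the remaining residual, once per component.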

Implementation

  • Python: Pure PyTorch baseline
  • C++: Optimized CPU kernels with PyBind11
  • CUDA: GPU kernels with warp-level reductions and shared memory tiling

Requirements

  • Python ≥ 3.8
  • PyTorch ≥ 2.0.0
  • NumPy ≥ 1.20.0
  • C++ compiler (for C++ extensions)
  • CUDA toolkit (optional, for GPU support)

Evaluation

Benchmarks

Comprehensive benchmarks are available in BENCHMARKS.md:

  • Memory compression analysis
  • Forward pass timing
  • Accuracy metrics
  • Real-world transformer examples

Validation

All implementations are validated against:

  • Unit tests (pytest suite)
  • Numerical correctness tests
  • Integration tests with Transformers
  • Cross-implementation consistency (Python vs C++)

Citation

If you use BitLinear in your research, please cite:

@article{jmlr_ternary_2024,
  title={Ternary Representations of Neural Networks},
  journal={Journal of Machine Learning Research},
  volume={26},
  year={2024},
  url={https://jmlr.org/papers/volume26/24-2050/24-2050.pdf}
}

@article{bitnet2023,
  title={BitNet: Scaling 1-bit Transformers for Large Language Models},
  author={Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Wang, Huaijie and Ma, Lingxiao and Yang, Fan and Wang, Ruiping and Wu, Yi and Wei, Furu},
  journal={arXiv preprint arXiv:2310.11453},
  year={2023}
}

Model Card Contact

For questions or issues, please open an issue on GitHub or contact the maintainers.

Glossary

  • Ternary Quantization: Representing weights with only three values {-1, 0, +1}
  • Absmax Scaling: Scaling factor computed per output channel as max(abs(weights))
  • Base-3 Packing: Encoding ternary values in base-3 for memory efficiency
  • Multi-Ternary: Sum of k ternary components for improved approximation
  • QAT: Quantization-Aware Training - training with quantization in the loop

More Information

  • Documentation: See README.md and read/ directory
  • Examples: See examples/ directory
  • Benchmarks: See BENCHMARKS.md
  • Implementation Guide: See read/IMPLEMENTATION_GUIDE.md