Model Card: BitLinear
Model Description
BitLinear is a PyTorch implementation of ultra-low-precision (1.58-bit) ternary linear layers that can serve as drop-in replacements for nn.Linear in neural networks, particularly Transformers. It achieves ~19x memory compression while maintaining high output similarity.
Model Details
- Developed by: BitLinear Contributors
- Model type: Quantization / Compression
- Language: Python, C++, CUDA
- License: MIT
- Repository: https://github.com/yourusername/bitlinear
Intended Use
Primary Use Cases
- Edge Deployment: Deploying large models on memory-constrained devices
- Production Inference: Reducing memory footprint for serving large language models
- Research: Exploring ultra-low-precision neural networks
- Cost Optimization: Reducing cloud infrastructure costs through memory savings
Out-of-Scope Use Cases
- Training from scratch (requires quantization-aware training)
- Applications requiring exact numerical precision
- Real-time applications where Python overhead is prohibitive (use C++/CUDA extensions)
How to Use
Basic Usage
```python
import torch
from bitlinear import BitLinear

# Create a BitLinear layer (same interface as nn.Linear)
layer = BitLinear(in_features=512, out_features=1024, bias=True)

# Forward pass
x = torch.randn(32, 128, 512)
output = layer(x)  # Same as nn.Linear
```
Converting Existing Models
```python
import torch
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

# Convert a pre-trained model
model = nn.TransformerEncoderLayer(d_model=512, nhead=8)
model_compressed = convert_linear_to_bitlinear(model, inplace=False)

# Use as normal
x = torch.randn(10, 32, 512)
output = model_compressed(x)
```
Multi-Ternary for Better Accuracy
```python
from bitlinear import MultiTernaryLinear

# Use k=3 components for 75% error reduction
layer = MultiTernaryLinear(in_features=512, out_features=1024, k=3)
```
Performance
Memory Compression
- Average Compression: 19.23x (95% of theoretical 20x)
- GPT-2 Small Example: 324 MB → 16.8 MB (307 MB saved)
| Layer Size | nn.Linear | BitLinear (Packed) | Compression |
|---|---|---|---|
| 512×512 | 1.00 MB | 0.05 MB | 18.6x |
| 1024×1024 | 4.00 MB | 0.21 MB | 19.3x |
| 4096×4096 | 64.02 MB | 3.23 MB | 19.8x |
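These figures are consistent with a simple back-of-the-envelope estimate, sketched below under the assumption of fp32 reference weights, one byte per five packed ternary values, and one fp32 scale per output channel (the implementation's exact overheads may differ slightly):

```python
def approx_sizes_mb(in_features, out_features):
    """Rough size estimate: fp32 nn.Linear vs. base-3 packed ternary weights."""
    n_weights = in_features * out_features
    fp32_bytes = n_weights * 4 + out_features * 4      # fp32 weights + fp32 bias
    packed_bytes = n_weights / 5 + out_features * 4    # 5 ternary values per byte + per-channel fp32 scales
    return fp32_bytes / 2**20, packed_bytes / 2**20

for size in (512, 1024, 4096):
    dense_mb, packed_mb = approx_sizes_mb(size, size)
    print(f"{size}x{size}: {dense_mb:.2f} MB -> {packed_mb:.2f} MB ({dense_mb / packed_mb:.1f}x)")
```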
Accuracy
- Cosine Similarity: > 0.96 (96%+)
- Relative Error: ~0.28 (28%)
- Multi-Ternary (k=3): 75% error reduction vs k=1
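A sketch of how such numbers can be reproduced for your own layers, using only the API shown above; the metric definitions here (cosine similarity of flattened outputs, relative L2 error) are assumptions, not necessarily the exact benchmark procedure:

```python
import copy
import torch
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

torch.manual_seed(0)
reference = nn.Sequential(nn.Linear(512, 1024))
quantized = convert_linear_to_bitlinear(copy.deepcopy(reference), inplace=False)

x = torch.randn(256, 512)
y_ref, y_bit = reference(x), quantized(x)

cos = torch.nn.functional.cosine_similarity(y_ref.flatten(), y_bit.flatten(), dim=0)
rel_err = (y_ref - y_bit).norm() / y_ref.norm()
print(f"cosine similarity: {cos.item():.4f}  relative error: {rel_err.item():.4f}")
```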
Limitations
Known Limitations
- Accuracy Trade-off: Ternary quantization introduces approximation error (~3-5% typical)
- Training: Requires quantization-aware training (QAT) for optimal results
- Speed: Python implementation may be slower than nn.Linear (use C++/CUDA for production)
- Activation Quantization: Currently only weights are quantized (full BitNet includes activation quantization)
Recommendations
- Fine-tune converted models for best accuracy
- Use k≥2 for MultiTernaryLinear when accuracy is critical
- Profile performance on your specific hardware
- Test accuracy on your specific task before deployment
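For the profiling recommendation above, a minimal timing comparison might look like this (a sketch using PyTorch's built-in benchmark utility; adapt shapes and devices to your workload):

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
from bitlinear import BitLinear

x = torch.randn(32, 128, 512)
dense = nn.Linear(512, 1024)
bit = BitLinear(in_features=512, out_features=1024, bias=True)

# Time 100 forward passes for each layer on the current device
for name, layer in [("nn.Linear", dense), ("BitLinear", bit)]:
    timer = benchmark.Timer(stmt="layer(x)", globals={"layer": layer, "x": x})
    print(name, timer.timeit(100))
```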
Training
Quantization-Aware Training (QAT)
For best results, fine-tune models with BitLinear layers:
```python
import torch
from bitlinear import convert_linear_to_bitlinear

# Convert pre-trained model
model_bit = convert_linear_to_bitlinear(pretrained_model)

# Fine-tune with standard training loop
optimizer = torch.optim.AdamW(model_bit.parameters(), lr=1e-4)
# ... train as normal ...
```
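For illustration, a complete (toy) fine-tuning loop could look like the following; the model, data, and loss function are placeholders rather than part of the library:

```python
import torch
import torch.nn as nn
from bitlinear import convert_linear_to_bitlinear

# Placeholder "pre-trained" model; substitute your own checkpoint.
pretrained_model = nn.Sequential(nn.Linear(512, 1024), nn.ReLU(), nn.Linear(1024, 512))
model_bit = convert_linear_to_bitlinear(pretrained_model, inplace=False)

optimizer = torch.optim.AdamW(model_bit.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

# Toy data; replace with your task's dataloader.
batches = [(torch.randn(32, 512), torch.randn(32, 512)) for _ in range(10)]

model_bit.train()
for epoch in range(3):
    for inputs, targets in batches:
        optimizer.zero_grad()
        loss = loss_fn(model_bit(inputs), targets)
        loss.backward()
        optimizer.step()
```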
From Scratch Training
Training from scratch with ternary weights requires:
- Careful initialization
- Straight-through estimators for gradients
- Potentially modified learning rates
See read/IMPLEMENTATION_GUIDE.md for details.
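The straight-through estimator mentioned above is usually implemented by quantizing on the forward pass while letting gradients flow to the full-precision weights. A minimal sketch (illustrative only, not the library's internal code; the exact scaling and rounding rule is an assumption):

```python
import torch

def ternary_quantize_ste(w: torch.Tensor) -> torch.Tensor:
    """Ternarize weights for the forward pass; gradients pass straight through to w."""
    scale = w.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8)  # per-output-channel absmax
    w_ternary = torch.clamp((w / scale).round(), -1, 1) * scale      # values in {-scale, 0, +scale}
    # Straight-through estimator: forward uses w_ternary, backward treats it as identity in w.
    return w + (w_ternary - w).detach()
```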
Technical Specifications
Architecture
- Weight Quantization: Ternary {-1, 0, +1}
- Scaling: Per-output-channel absmax scaling
- Packing: Base-3 encoding (5 values per byte)
- Decomposition: Greedy residual quantization for multi-ternary
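A compact sketch of these pieces (quantization, base-3 packing, and greedy residual decomposition) in plain PyTorch; thresholds and helper names are illustrative, not the library's exact internals:

```python
import torch

def ternarize(w: torch.Tensor):
    """Per-output-channel absmax scaling, then rounding to {-1, 0, +1}."""
    scale = w.abs().max(dim=1, keepdim=True).values.clamp(min=1e-8)
    return torch.clamp((w / scale).round(), -1, 1), scale

def pack_base3(t: torch.Tensor) -> torch.Tensor:
    """Pack 5 ternary values per byte by treating each value + 1 as a base-3 digit."""
    digits = (t.flatten() + 1).to(torch.uint8)                 # values now in {0, 1, 2}
    pad = (-digits.numel()) % 5
    digits = torch.cat([digits, digits.new_zeros(pad)]).view(-1, 5)
    powers = torch.tensor([1, 3, 9, 27, 81], dtype=torch.uint8)
    return (digits * powers).sum(dim=1, dtype=torch.uint8)     # max value 242, fits in one byte

def greedy_multi_ternary(w: torch.Tensor, k: int = 3):
    """Greedy residual decomposition: fit k ternary components one after another."""
    residual, components = w.clone(), []
    for _ in range(k):
        t, scale = ternarize(residual)
        components.append((t, scale))
        residual = residual - t * scale
    return components
```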
Implementation
- Python: Pure PyTorch baseline
- C++: Optimized CPU kernels with PyBind11
- CUDA: GPU kernels with warp-level reductions and shared memory tiling
Requirements
- Python ≥ 3.8
- PyTorch ≥ 2.0.0
- NumPy ≥ 1.20.0
- C++ compiler (for C++ extensions)
- CUDA toolkit (optional, for GPU support)
Evaluation
Benchmarks
Comprehensive benchmarks available in BENCHMARKS.md:
- Memory compression analysis
- Forward pass timing
- Accuracy metrics
- Real-world transformer examples
Validation
All implementations validated against:
- Unit tests (pytest suite)
- Numerical correctness tests
- Integration tests with Transformers
- Cross-implementation consistency (Python vs C++)
Citation
If you use BitLinear in your research, please cite:
```bibtex
@article{jmlr_ternary_2024,
  title   = {Ternary Representations of Neural Networks},
  journal = {Journal of Machine Learning Research},
  volume  = {26},
  year    = {2024},
  url     = {https://jmlr.org/papers/volume26/24-2050/24-2050.pdf}
}

@article{bitnet2023,
  title   = {BitNet: Scaling 1-bit Transformers for Large Language Models},
  author  = {Wang, Hongyu and Ma, Shuming and Dong, Li and Huang, Shaohan and Wang, Huaijie and Ma, Lingxiao and Yang, Fan and Wang, Ruiping and Wu, Yi and Wei, Furu},
  journal = {arXiv preprint arXiv:2310.11453},
  year    = {2023}
}
```
Model Card Contact
For questions or issues, please open an issue on GitHub or contact the maintainers.
Glossary
- Ternary Quantization: Representing weights with only three values {-1, 0, +1}
- Absmax Scaling: Scaling factor computed as max(abs(weights)), applied per output channel in BitLinear
- Base-3 Packing: Encoding ternary values in base-3 for memory efficiency
- Multi-Ternary: Sum of k ternary components for improved approximation
- QAT: Quantization-Aware Training - training with quantization in the loop
More Information
- Documentation: See README.md and the read/ directory
- Examples: See the examples/ directory
- Benchmarks: See BENCHMARKS.md
- Implementation Guide: See read/IMPLEMENTATION_GUIDE.md