DeepSeek Multi-Head Latent Attention

This repository provides a PyTorch implementation of the Multi-Head Latent Attention (MLA) mechanism introduced in the DeepSeek-V2 paper. It is not a trained model but a modular attention implementation that sharply reduces the KV cache required at inference time while preserving model quality, and it can be used as a drop-in attention module in transformer architectures.

Key Features

  • Low-Rank Key-Value Joint Compression: Reduces memory footprint during inference
  • Decoupled Rotary Position Embedding: Enables efficient position-aware attention
  • Optimized Cache Management: Handles both compressed KV states and rotary embeddings
  • Cross-Attention Support: Works for both self-attention and cross-attention scenarios

Installation

Clone this repository:

git clone https://huggingface.co/bird-of-paradise/deepseek-mla

Or download directly from the HuggingFace repository page.

Quick Start

import torch
from src.mla import MultiHeadLatentAttention

# Initialize MLA
mla = MultiHeadLatentAttention(
    d_model=512,     # Model dimension
    num_head=8,      # Number of attention heads
    d_embed=512,     # Embedding dimension
    d_c=64,          # KV compression dimension
    d_c1=64,         # Query compression dimension
    d_rotate=32,     # Rotary embedding dimension
)

# Input sequence
x = torch.randn(2, 10, 512)  # [batch_size, seq_len, d_model]

# Forward pass
output = mla(x)
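
The output keeps the batch and sequence dimensions of the input; with the settings above it should come back as [2, 10, 512] (assuming the module projects back to d_model, as is standard for attention blocks):

print(output.shape)  # expected: torch.Size([2, 10, 512])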

Testing

To run the test suite, execute the following command from the project root directory:

python -m src.tests.test_mla

Architecture Details

MLA Architecture

MLA combines two key innovations:

  1. Low-rank compression pathway for efficient KV caching
  2. Decoupled position-aware pathway using RoPE

For detailed architectural insights, see insights/architecture.md.
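
To make the cache saving concrete, here is a rough per-token comparison using the Quick Start settings (a sketch; the per-head dimension of 64 for the standard multi-head baseline is an assumption for illustration):

num_head, d_head = 8, 64             # assumed per-head dim for a standard MHA baseline
d_c, d_rotate = 64, 32               # MLA compression and rotary dimensions

standard_kv = 2 * num_head * d_head  # full keys + values: 1024 floats per token per layer
mla_kv = d_c + d_rotate              # compressed latent + shared rotary key: 96 floats per token per layer

print(f"standard: {standard_kv}, MLA: {mla_kv}, reduction: {standard_kv / mla_kv:.1f}x")
# standard: 1024, MLA: 96, reduction: 10.7x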

Caching Behavior

During inference, MLA maintains two caches:

cache_kv: [batch, max_len, d_c]    # Compressed KV states
cache_rk: [batch, max_len, d_r]    # Shared rotary key

For detailed insights on attention masking and caching, see insights/attention_mask.md.
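
Here d_r corresponds to the d_rotate constructor argument. As a conceptual sketch only (not the repository's internal code; write_cache is a hypothetical helper), two such buffers could be pre-allocated and updated in place during incremental decoding:

import torch

batch, max_len, d_c, d_r = 2, 1024, 64, 32
cache_kv = torch.zeros(batch, max_len, d_c)   # compressed KV latents
cache_rk = torch.zeros(batch, max_len, d_r)   # rotary keys, shared across heads

def write_cache(c_kv_new, k_rot_new, start_pos):
    # c_kv_new: [batch, seq_len, d_c], k_rot_new: [batch, seq_len, d_r]
    seq_len = c_kv_new.size(1)
    cache_kv[:, start_pos:start_pos + seq_len] = c_kv_new
    cache_rk[:, start_pos:start_pos + seq_len] = k_rot_new
    # attention for the new tokens then reads cache_kv[:, :start_pos + seq_len]
    # and cache_rk[:, :start_pos + seq_len]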

Usage Examples

Basic Attention

# Standard self-attention
output = mla(sequence)

# Cross-attention
output = mla(query, key_value_states=context)
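
In the cross-attention case the query and the context may have different sequence lengths; a small sketch, assuming only the feature dimension has to match the module configuration:

query = torch.randn(2, 5, 512)     # [batch, tgt_len, d_model]
context = torch.randn(2, 20, 512)  # [batch, src_len, d_model]
output = mla(query, key_value_states=context)  # expected shape: [2, 5, 512]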

Cached Generation

# Initial forward pass over the full prompt
# prompt: [batch, prompt_len, d_model] hidden states
output = mla(prompt, use_cache=True, start_pos=0)

# Generate tokens one at a time, reusing the cached compressed states
for i in range(max_new_tokens):
    # next_token: [batch, 1, d_model] hidden state of the most recent token
    output = mla(next_token, use_cache=True, start_pos=prompt_len + i)

Implementation Details

The implementation closely follows the formulation in the DeepSeek-V2 paper:

MLA Formulas

Key aspects:

  • Separate compression pathways for queries and key-values
  • Position encoding through decoupled RoPE pathway
  • Efficient cache management for both pathways
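
For reference, the core projections from the paper can be summarized as follows (notation follows DeepSeek-V2: h_t is the input hidden state, W^{DKV} and W^{DQ} are down-projections, W^{UK}, W^{UV}, W^{UQ} are up-projections, and i indexes attention heads; the rotary key k_t^{R} is shared across heads):

c_t^{KV} = W^{DKV} h_t                               # compressed KV latent (dim d_c) -- cached
k_t^{C} = W^{UK} c_t^{KV},    v_t^{C} = W^{UV} c_t^{KV}
k_t^{R} = RoPE(W^{KR} h_t)                           # decoupled rotary key (dim d_rotate) -- cached
c_t^{Q} = W^{DQ} h_t                                 # compressed query latent (dim d_c1)
q_t^{C} = W^{UQ} c_t^{Q},     q_t^{R} = RoPE(W^{QR} c_t^{Q})
q_{t,i} = [q_{t,i}^{C}; q_{t,i}^{R}],    k_{t,i} = [k_{t,i}^{C}; k_t^{R}]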

Contributing

Contributions are welcome! Feel free to:

  • Report bugs and issues
  • Submit pull requests for improvements
  • Add additional test cases
  • Provide documentation clarifications

Please ensure all tests pass before submitting pull requests.

Citation

@misc{deepseek2024,
    title={DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model},
    author={DeepSeek-AI},
    year={2024},
    eprint={2405.04434},
    archivePrefix={arXiv}
}

License

MIT License

