Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
Instructions to use smithblack-0/SHRAM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use smithblack-0/SHRAM with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="smithblack-0/SHRAM", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("smithblack-0/SHRAM", trust_remote_code=True, dtype="auto")

- Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use smithblack-0/SHRAM with vLLM:
Install from pip and serve the model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "smithblack-0/SHRAM"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'

Use Docker
docker model run hf.co/smithblack-0/SHRAM
- SGLang
How to use smithblack-0/SHRAM with SGLang:
Install from pip and serve the model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "smithblack-0/SHRAM" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'

Use Docker images

docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
  --model-path "smithblack-0/SHRAM" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{
    "model": "smithblack-0/SHRAM",
    "prompt": "Once upon a time,",
    "max_tokens": 512,
    "temperature": 0.5
  }'

- Docker Model Runner
How to use smithblack-0/SHRAM with Docker Model Runner:
docker model run hf.co/smithblack-0/SHRAM
Update architecture and tokenizer
- README.md +103 -0
- __attention__bottlenecked_ensemble_attention.py +252 -0
- __attention__expert_packing.py +335 -0
- __attention__load_balance_loss.py +88 -0
- __attention__mosrah.py +140 -0
- __attention__positions_converter.py +105 -0
- __attention__router.py +162 -0
- __attention__shram.py +116 -0
- __attention__sliding_window_attention.py +233 -0
- __cache__mosrah_cache.py +359 -0
- __cache__shram_cache.py +141 -0
- __cache__shram_layer_cache.py +233 -0
- __cache__sliding_window_cache.py +289 -0
- __cache__slow_mosrah_cache.py +321 -0
- __init__.py +21 -0
- config.json +28 -0
- configuration.py +199 -0
- decoder_layer.py +87 -0
- huggingface.py +518 -0
- mlp.py +52 -0
- model.py +132 -0
- rope.py +291 -0
- tokenizer.json +0 -0
- tokenizer_config.json +13 -0
README.md
ADDED
---
language:
- en
license: mit
library_name: transformers
pipeline_tag: text-generation
tags:
- pytorch
- research
- sparse-attention
- mixture-of-experts
---

# SHRAM — Sparse Hybrid Token Routed Attention Mixture

A research baseline implementing the SHRAM architecture from "An Examination of Sparse
Attention for Long Context Purposes." No pretrained weights — pull the architecture from
the Hub and instantiate a freshly initialised model from config. Every parameter is
overridable at instantiation time via kwargs.

> **Important:** `trust_remote_code=True` is required. It downloads the architecture
> source files from the Hub and imports them into your Python process. Review the
> source at [smithblack-0/SHRAM](https://huggingface.co/smithblack-0/SHRAM) before use.

## Architecture

SHRAM replaces every standard attention layer with a hybrid layer `H(x) = h_l(x) + h_s(x)`:

- **h_l** — local sliding-window causal attention path.
- **h_s** — MoSRAH sparse routed path. Each token selects K of L available expert heads
  via token-choice routing. Bottlenecked Ensemble Attention (BEA) is applied per head.

All other components follow the Llama 3 baseline (RMSNorm, SwiGLU FFN, RoPE).

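The sketch below makes the hybrid combination concrete. It is illustrative only, not repository code: the sparse path's call signature mirrors `MoSRAHLayer.forward` from this repository, while the local path's signature and the way the decoder layer surfaces the auxiliary outputs are assumptions.

```python
from torch import nn


class HybridAttentionSketch(nn.Module):
    """Illustrative sketch of H(x) = h_l(x) + h_s(x); not the repository's decoder layer."""

    def __init__(self, local_path: nn.Module, sparse_path: nn.Module) -> None:
        super().__init__()
        self.local_path = local_path    # h_l: sliding-window causal attention
        self.sparse_path = sparse_path  # h_s: routed MoSRAH sparse attention

    def forward(self, hidden_states, position_ids, active_mask, cache=None):
        # Assumed signature for the local sliding-window path.
        local_out = self.local_path(hidden_states, position_ids)
        # Mirrors MoSRAHLayer.forward: the sparse path also returns the
        # load-balance loss and the MaxVio imbalance scalar.
        sparse_out, load_balance_loss, max_vio = self.sparse_path(
            hidden_states, position_ids, active_mask, cache
        )
        return local_out + sparse_out, load_balance_loss, max_vio
```
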
## Usage

This repository contains no pretrained weights. The intended workflow is: pull the
architecture config from the Hub, instantiate a model with fresh random weights, then
train it yourself.

```python
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

# Step 1: pull the architecture config from the Hub.
# AutoConfig.from_pretrained downloads config.json only — no weights are loaded.
# Override any parameter via kwargs.
config = AutoConfig.from_pretrained(
    "smithblack-0/SHRAM",
    trust_remote_code=True,
    num_hidden_layers=16,  # example override
    num_mosrah_heads=32,   # example override
)

# Step 2: instantiate with fresh random weights.
# from_config never loads a checkpoint — it always produces a randomly initialised model.
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

# Step 3: load the tokenizer.
tokenizer = AutoTokenizer.from_pretrained("smithblack-0/SHRAM")
```

After training your own checkpoint, save and reload it in the standard way:

```python
model.save_pretrained("./my-checkpoint")
model = AutoModelForCausalLM.from_pretrained("./my-checkpoint", trust_remote_code=True)
```

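For a quick smoke test of a freshly initialised model, run a short forward pass on random token ids. This is a minimal sketch and assumes the custom model follows the standard causal-LM interface and returns `logits` of shape `(batch, sequence, vocab_size)`:

```python
import torch

input_ids = torch.randint(0, config.vocab_size, (1, 32))
with torch.no_grad():
    outputs = model(input_ids)
print(outputs.logits.shape)  # torch.Size([1, 32, 50277]) with the default vocab_size
```
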
## Constructor Defaults

The values below are the defaults you get if you call `AutoConfig.from_pretrained` with
no overrides. They are not the parameters of a pretrained model — this repository
contains no weights. All values are overridable via kwargs.

| Parameter | Default |
|-----------|---------|
| `alpha` | 1.0 |
| `attention_dropout` | 0.0 |
| `beta` | 32.0 |
| `dtype` | None |
| `head_dim` | 16 |
| `hidden_size` | 512 |
| `inference_sequence_length` | 1024 |
| `intermediate_size` | 1366 |
| `local_rope_theta` | 10000.0 |
| `mosrah_rope_theta` | 10000.0 |
| `num_hidden_layers` | 12 |
| `num_mosrah_heads` | 16 |
| `num_selected_heads` | 16 |
| `num_sliding_window_heads` | 16 |
| `output_hidden_states` | False |
| `rms_norm_eps` | 1e-05 |
| `rope_mode` | main_sequence |
| `tie_word_embeddings` | False |
| `training_sequence_length` | 1024 |
| `use_cache` | True |
| `vocab_size` | 50277 |
| `window_size` | 128 |

## License

MIT. Clean-room synthesis informed by the reference paper. Tokenizer is GPT-NeoX
(`EleutherAI/gpt-neox-20b`, Apache 2.0).
__attention__bottlenecked_ensemble_attention.py
ADDED
"""Bottlenecked Ensemble Attention (BEA) for the MoSRAH sparse path.

BEA is the packed expert-choice attention operator over the MoSRAH sparse path.
It consumes packed expert-choice tensors, a supplied position tensor, an active
token mask, and an optional layer-local MoSRAH cache. It returns outputs in the
same packed expert-choice space expected by later unpacking.

BEA does not compute positions and does not choose packed-position semantics.
Those are supplied by the caller. If caching is used, BEA stores post-RoPE keys
(K̃) and raw values (V) into the sparse cache and attends against the
accumulated cached state returned by that cache.
"""

import math

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from .configuration import ShramConfig
from .__cache__mosrah_cache import MoSRAHCache
from .rope import RotaryEmbedding


class BottleneckedEnsembleAttention(nn.Module):
    """
    Packed expert-choice attention operator for the MoSRAH sparse path.
    Operates per-head independently on an ensemble of tokens.
    FlexAttention saves flops on dead tokens.

    Architectural properties:
    - consumes packed expert-choice tensors of shape (B, L, T, d)
    - uses independent per-head Q/K/V/O projection parameters
    - applies YaRN-capable RoPE using supplied position_ids
    - stores post-RoPE K̃ and raw V in MoSRAHCache when caching is enabled
    - uses a fast fused attention path
    - returns outputs in the same packed expert-choice space (B, L, T, d)

    Args:
        config: SHRAM config. Must expose `hidden_size`, `num_mosrah_heads`,
            `head_dim`, `mosrah_rope_theta`, `training_sequence_length`,
            `inference_sequence_length`, `alpha`, and `beta`.
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_mosrah_heads
        self.head_dim = config.head_dim

        # Independent per-head projections. No cross-head parameter sharing.
        self.q_proj = nn.Parameter(
            torch.empty(self.num_heads, self.hidden_size, self.head_dim)
        )
        self.k_proj = nn.Parameter(
            torch.empty(self.num_heads, self.hidden_size, self.head_dim)
        )
        self.v_proj = nn.Parameter(
            torch.empty(self.num_heads, self.hidden_size, self.head_dim)
        )
        self.o_proj = nn.Parameter(
            torch.empty(self.num_heads, self.head_dim, self.hidden_size)
        )

        self._reset_parameters()

        # BEA uses the YaRN-capable RoPE path. The caller supplies the position tensor;
        # this unit only consumes it. In training modes, dilation will be 1.0 and so
        # no YaRN dilation occurs.
        self.rope = RotaryEmbedding(
            mode="yarn",
            head_dim=self.head_dim,
            theta=config.mosrah_rope_theta,
            initial_seq_length=config.training_sequence_length,
            dilation=config.scale,
            alpha=config.alpha,
            beta=config.beta,
        )

    def forward(
        self,
        packed_embeddings: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: MoSRAHCache | None = None,
    ) -> torch.Tensor:
        """Apply BEA to packed expert-choice tensors.

        Args:
            packed_embeddings: Packed expert-choice hidden states of shape (B, L, T, d).
            position_ids: Supplied packed positions of shape (B, L, T).
            active_mask: Boolean active-token mask of shape (B, L, T).
            cache: Optional layer-local MoSRAH cache.

        Returns:
            Packed expert-choice output tensor of shape (B, L, T, d).
        """
        batch_size, _, query_length, _ = packed_embeddings.shape
        self._validate_tensor_shape(packed_embeddings)
        self._validate_position_shape(packed_embeddings, position_ids)
        self._validate_active_mask_shape(packed_embeddings, active_mask)

        # Independent per-head projections:
        # (B, L, T, d) x (L, d, u) -> (B, L, T, u)
        query_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.q_proj)
        key_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.k_proj)
        value_states = torch.einsum("bltd,ldu->bltu", packed_embeddings, self.v_proj)

        rotated_query_states, rotated_key_states, attention_scaling = self.rope(
            query_states,
            key_states,
            position_ids,
        )

        if cache is not None:
            # In cached execution, the current query tensor uses local tensor rows
            # 0..Q-1, but the key tensor returned by the cache is the full accumulated
            # packed sequence for each (batch, head) slot. The only additional data
            # needed to align those two views is the pre-update cached prefix length,
            # which indicates how many queries were processed before now.
            num_tokens_processed = cache.get_heads_lengths().clone()
            key_states, value_states, key_active_mask = cache.update(
                rotated_key_states,
                value_states,
                active_mask,
            )
        else:
            num_tokens_processed = torch.zeros(
                batch_size,
                self.num_heads,
                dtype=torch.long,
                device=packed_embeddings.device,
            )
            key_states = rotated_key_states
            key_active_mask = active_mask

        block_mask = self._make_block_mask(
            query_active_mask=active_mask,
            key_active_mask=key_active_mask,
            num_tokens_processed=num_tokens_processed,
            query_length=query_length,
            key_length=key_states.shape[2],
            device=packed_embeddings.device,
        )
        attended_states = flex_attention(
            rotated_query_states,
            key_states,
            value_states,
            block_mask=block_mask,
            scale=attention_scaling / math.sqrt(self.head_dim),
        )

        # Project back to model width:
        # (B, L, T, u) x (L, u, d) -> (B, L, T, d)
        return torch.einsum("bltu,lud->bltd", attended_states, self.o_proj)

    def _reset_parameters(self) -> None:
        """Initialize per-head projection weights."""
        for weight in (self.q_proj, self.k_proj, self.v_proj, self.o_proj):
            nn.init.xavier_uniform_(weight)

    def _validate_tensor_shape(self, packed_embeddings: torch.Tensor) -> None:
        """Validate the local packed-embedding shape contract required by BEA."""
        if packed_embeddings.shape[1] != self.num_heads:
            raise ValueError(
                f"Expected packed_embeddings.shape[1] == num_mosrah_heads={self.num_heads}, "
                f"got {packed_embeddings.shape[1]}."
            )

        if packed_embeddings.shape[-1] != self.hidden_size:
            raise ValueError(
                f"Expected packed_embeddings last dim == hidden_size={self.hidden_size}, "
                f"got {packed_embeddings.shape[-1]}."
            )

    def _validate_position_shape(
        self,
        packed_embeddings: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> None:
        """Validate the supplied packed-position tensor shape."""
        if position_ids.shape != packed_embeddings.shape[:3]:
            raise ValueError(
                f"position_ids must have shape {tuple(packed_embeddings.shape[:3])}, "
                f"got {tuple(position_ids.shape)}."
            )

    def _validate_active_mask_shape(
        self,
        packed_embeddings: torch.Tensor,
        active_mask: torch.Tensor,
    ) -> None:
        """Validate the supplied active-token mask shape."""
        if active_mask.shape != packed_embeddings.shape[:3]:
            raise ValueError(
                f"active_mask must have shape {tuple(packed_embeddings.shape[:3])}, "
                f"got {tuple(active_mask.shape)}."
            )

    def _make_block_mask(
        self,
        query_active_mask: torch.Tensor,
        key_active_mask: torch.Tensor,
        num_tokens_processed: torch.Tensor,
        query_length: int,
        key_length: int,
        device: torch.device,
    ):
        """Create the packed-sequence causal mask for FlexAttention.

        At the root, causality is still triangular. The only nuance is cached
        execution: query rows are indexed locally as 0..Q-1 inside the current
        query tensor, but the key tensor may already contain a cached prefix for
        that (batch, head) slot. The causal horizon for query tensor row q is
        therefore:

            cached_prefix_lengths[b, h] + q

        Query and key activity masks are then composed with that triangular rule
        so FlexAttention can skip padded query rows and ignore inactive key slots.
        """
        batch_size, num_heads, _ = query_active_mask.shape

        # Build the per-(batch, head, query_row) triangular horizon from a simple
        # arange over query rows plus the cached prefix lengths for each slot.
        relative_query_positions = torch.arange(
            query_length,
            device=device,
            dtype=torch.long,
        ).view(1, 1, query_length)
        causal_query_positions = num_tokens_processed.unsqueeze(-1) + relative_query_positions

        def packed_causal_mask(
            batch_idx: torch.Tensor,
            head_idx: torch.Tensor,
            query_idx: torch.Tensor,
            key_idx: torch.Tensor,
        ) -> torch.Tensor:
            query_is_active = query_active_mask[batch_idx, head_idx, query_idx]
            key_is_active = key_active_mask[batch_idx, head_idx, key_idx]
            is_causal = key_idx <= causal_query_positions[batch_idx, head_idx, query_idx]
            return query_is_active & key_is_active & is_causal

        return create_block_mask(
            packed_causal_mask,
            B=batch_size,
            H=num_heads,
            Q_LEN=query_length,
            KV_LEN=key_length,
            device=device,
        )
__attention__expert_packing.py
ADDED
"""Expert packing and unpacking for the MoSRAH path.

This module implements the low-level token-choice -> expert-choice -> token-choice
conversion boundary specified in the paper. The externally visible behavior is fixed:

- setup_packing() prepares the auxiliary ordering data.
- pack_experts() converts routed token-choice state into packed expert-choice state.
- unpack_experts() restores token-choice ordering afterward.

Stable sort is a correctness requirement. It preserves causal ordering inside each
expert bucket, which is the foundation on which BEA's later triangular causal mask
is correct.

pack_experts() returns two distinct masks that serve different roles and must not
be interchanged:

- unpacking_mask: marks every packed slot that contains a routed token copy,
  live or dead. Always has exactly B*N*K True entries. Required by unpack_experts
  so its reshape invariant holds regardless of outer token liveness.
- active_mask: marks only the packed slots whose source token was semantically
  live. This is what BEA consumes for attention gating. Dead outer tokens must
  not influence sparse attention outputs.
"""

import torch


# ---------------------------------------------------------------------------
# Setup
# ---------------------------------------------------------------------------

def setup_packing(
    selected_heads: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Prepare the auxiliary ordering data used by pack/unpack.

    Routing produces token-choice state I of shape (B, N, K): for each token, which
    K experts were selected. Packing needs the same routed token copies reordered into
    expert-major order so each expert bucket becomes contiguous.

    The paper's setup step does this by flattening (N, K) into one axis to produce
    H in token-major order, then computing a stable argsort permutation Pi over the
    expert indices stored in H. Applying Pi reorders the flattened routed copies into
    expert-major order while preserving their original token order *within* each expert
    bucket. That preservation is why stable sort is required for causality.

    Args:
        selected_heads: Routed token-choice head selections I of shape (B, N, K).

    Returns:
        Tuple of:
        - flattened_selected_heads: H of shape (B, N*K)
        - permutation: stable expert-major permutation Pi of shape (B, N*K)
        - inverse_permutation: inverse permutation Pi^{-1} of shape (B, N*K)
    """
    batch_size, sequence_length, num_selected_heads = selected_heads.shape
    flattened_selected_heads = selected_heads.reshape(
        batch_size,
        sequence_length * num_selected_heads,
    )

    permutation = torch.argsort(flattened_selected_heads, dim=-1, stable=True)
    inverse_permutation = torch.argsort(permutation, dim=-1)

    return flattened_selected_heads, permutation, inverse_permutation


# ---------------------------------------------------------------------------
# Packing
# ---------------------------------------------------------------------------

def pack_experts(
    hidden_states: torch.Tensor,
    position_ids: torch.Tensor,
    selected_heads: torch.Tensor,
    num_experts: int,
    flattened_selected_heads: torch.Tensor,
    permutation: torch.Tensor,
    outer_active_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pack token-choice hidden states into expert-choice padded form.

    The paper's packing path has two jobs:

    1. Convert routed token-choice copies into expert-major order.
    2. Materialize that expert-major order into a padded tensor layout BEA can consume.

    The routed hidden-state copies are not stored explicitly in token-choice form.
    Instead, the same token hidden state is conceptually copied once per selected expert.
    The packing step reconstructs those copies by expanding local source-token indices,
    reordering those indices with Pi, then gathering hidden states, positions, and outer
    liveness in that packed order. All three are carried through the same expert-major
    rearrangement so they remain aligned in the packed frame.

    Packed positions are sourced from the authoritative upstream position_ids tensor
    rather than synthesized locally from arange(N). This preserves advanced positions
    correctly during cached inference while leaving training/full-sequence behavior
    unchanged when position_ids is the ordinary sequential token positions.

    Args:
        hidden_states: Token-choice hidden states x of shape (B, N, d).
        position_ids: Authoritative upstream token positions J of shape (B, N).
        selected_heads: Routed head selections I of shape (B, N, K).
        num_experts: Total number of experts L.
        flattened_selected_heads: H from setup_packing(), shape (B, N*K).
        permutation: Pi from setup_packing(), shape (B, N*K).
        outer_active_mask: Current-chunk active mask of shape (B, N), where True
            means the token is semantically live. Dead tokens do not become
            semantically active in the packed sparse representation.

    Returns:
        Tuple of:
        - packed_hidden_states: x' of shape (B, L, T, d)
        - packed_positions: J' of shape (B, L, T)
        - unpacking_mask: of shape (B, L, T). True where a slot contains any
          routed token copy, live or dead. Always has exactly B*N*K True entries.
          Pass this to unpack_experts — not active_mask.
        - active_mask: of shape (B, L, T). True only where a slot contains a
          copy of a live outer token. Pass this to BEA for attention gating.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    _, _, num_selected_heads = selected_heads.shape

    # -----------------------------------------------------------------------
    # Reconstruct routed local source-token indices in token-choice order.
    #
    # The internal arange(N) is no longer the packed position tensor. It is only
    # the local source-row index object used to gather from the current chunk
    # tensor x. Flattening this object gives a (B, N*K) tensor aligned with H's
    # token-major routed-copy order.
    # -----------------------------------------------------------------------
    source_token_indices = torch.arange(
        sequence_length,
        device=hidden_states.device,
        dtype=torch.long,
    ).view(1, sequence_length, 1).expand(
        batch_size,
        sequence_length,
        num_selected_heads,
    )
    flattened_source_indices = source_token_indices.reshape(
        batch_size,
        sequence_length * num_selected_heads,
    )

    # -----------------------------------------------------------------------
    # Reorder source-token indices into expert-major order.
    #
    # Applying Pi yields the local source-token rows in the packed expert-major
    # order required by the paper. Those same reordered source indices are then
    # used to gather hidden states, authoritative upstream positions, and outer
    # liveness so all three remain aligned under the exact same packing
    # transformation.
    # -----------------------------------------------------------------------
    sorted_source_indices = flattened_source_indices.gather(
        dim=1,
        index=permutation,
    )
    sorted_hidden_states = hidden_states.gather(
        dim=1,
        index=sorted_source_indices.unsqueeze(-1).expand(-1, -1, hidden_dim),
    )
    sorted_positions = position_ids.gather(
        dim=1,
        index=sorted_source_indices,
    )
    sorted_active_mask = outer_active_mask.gather(
        dim=1,
        index=sorted_source_indices,
    )

    # -----------------------------------------------------------------------
    # Count how many routed copies land in each expert bucket.
    #
    # S[b, l] is the number of routed token copies assigned to expert l in batch b.
    # T is the maximum such count across all batches and experts; it determines the
    # padded expert-length dimension of the packed representation.
    # -----------------------------------------------------------------------
    tokens_per_expert = _bincount_rows(
        values=flattened_selected_heads,
        num_bins=num_experts,
    )
    max_tokens_per_expert = int(tokens_per_expert.max().item())

    # -----------------------------------------------------------------------
    # Construct the active-token mask M.
    #
    # Each expert bucket is left-justified: if S[b, l] = s, then slots
    # t = 0, ..., s-1 are active and all later slots are padding. The resulting
    # mask therefore both identifies real packed tokens and enforces left-justified
    # packing. This is the unpacking_mask — it marks slot occupancy regardless of
    # outer token liveness, and always has exactly B*N*K True entries.
    # -----------------------------------------------------------------------
    time_axis = torch.arange(
        max_tokens_per_expert,
        device=hidden_states.device,
        dtype=torch.long,
    ).view(1, 1, max_tokens_per_expert)
    unpacking_mask = time_axis < tokens_per_expert.unsqueeze(-1)

    # -----------------------------------------------------------------------
    # Materialize the padded packed tensors.
    #
    # The packed hidden states x', packed original-token positions J', and packed
    # active-token mask are allocated as zero-filled tensors. Active entries are
    # then written into those buffers in the expert-major order established above.
    # Padding remains zero / inactive.
    # -----------------------------------------------------------------------
    packed_hidden_states = hidden_states.new_zeros(
        batch_size,
        num_experts,
        max_tokens_per_expert,
        hidden_dim,
    )
    packed_positions = position_ids.new_zeros(
        batch_size,
        num_experts,
        max_tokens_per_expert,
    )
    active_mask = torch.zeros(
        batch_size,
        num_experts,
        max_tokens_per_expert,
        dtype=torch.bool,
        device=hidden_states.device,
    )

    packed_hidden_states[unpacking_mask] = sorted_hidden_states.reshape(-1, hidden_dim)
    packed_positions[unpacking_mask] = sorted_positions.reshape(-1)
    active_mask[unpacking_mask] = sorted_active_mask.reshape(-1)

    return packed_hidden_states, packed_positions, unpacking_mask, active_mask


# ---------------------------------------------------------------------------
# Unpacking
# ---------------------------------------------------------------------------

def unpack_experts(
    expert_outputs: torch.Tensor,
    selected_heads: torch.Tensor,
    unpacking_mask: torch.Tensor,
    inverse_permutation: torch.Tensor,
) -> torch.Tensor:
    """Restore token-choice ordering from BEA expert-choice output.

    Unpacking inverts the packing path only on occupied entries. Padding does not
    participate: the output tensor is first filtered by unpacking_mask to recover
    only the real routed-token copies in expert-major order, then Pi^{-1} restores
    the original token-choice ordering, and finally the tensor is reshaped back to
    (B, N, K, d).

    The unpacking_mask — not active_mask — must be used here. Even copies of dead
    outer tokens occupy slots and must be un-scattered correctly for the inverse
    permutation to hold. The total True entry count in unpacking_mask is always
    B*N*K, which is exactly what the reshape to (B, N*K, d) requires.

    Args:
        expert_outputs: Expert-choice BEA output y of shape (B, L, T, d).
        selected_heads: Routed head selections I of shape (B, N, K).
        unpacking_mask: From pack_experts(), shape (B, L, T). Identifies all
            occupied packed slots regardless of outer token liveness.
        inverse_permutation: Pi^{-1} from setup_packing(), shape (B, N*K).

    Returns:
        Restored token-choice tensor y_tilde of shape (B, N, K, d).
    """
    batch_size, sequence_length, num_selected_heads = selected_heads.shape
    hidden_dim = expert_outputs.shape[-1]

    active_outputs = expert_outputs[unpacking_mask]
    sorted_token_choice_outputs = active_outputs.reshape(
        batch_size,
        sequence_length * num_selected_heads,
        hidden_dim,
    )
    restored_outputs = sorted_token_choice_outputs.gather(
        dim=1,
        index=inverse_permutation.unsqueeze(-1).expand(-1, -1, hidden_dim),
    )

    return restored_outputs.reshape(
        batch_size,
        sequence_length,
        num_selected_heads,
        hidden_dim,
    )


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _bincount_rows(
    values: torch.Tensor,
    num_bins: int,
) -> torch.Tensor:
    """Count per-row integer occurrences for a 2D tensor.

    torch.bincount operates on a flat 1D vector, but the packing algorithm needs
    one bincount per batch row. The trick used here is to shift each row into its
    own disjoint bin range before flattening:

        row 0 uses bins [0, ..., num_bins - 1]
        row 1 uses bins [num_bins, ..., 2*num_bins - 1]
        row 2 uses bins [2*num_bins, ..., 3*num_bins - 1]
        ...

    After that shift, one global torch.bincount produces all row-local counts at
    once. Reshaping the result back to (B, num_bins) recovers the per-row bincount.

    This is a vectorized implementation detail only; externally visible behavior
    remains exactly the paper's S tensor of per-batch per-expert token counts.

    Args:
        values: Integer tensor of shape (B, M) with entries in [0, num_bins).
        num_bins: Number of bins.

    Returns:
        Counts tensor of shape (B, num_bins).
    """
    batch_size = values.shape[0]

    row_offsets = torch.arange(
        batch_size,
        device=values.device,
        dtype=values.dtype,
    ).unsqueeze(1) * num_bins
    shifted_values = values + row_offsets

    counts = torch.bincount(
        shifted_values.reshape(-1),
        minlength=batch_size * num_bins,
    )
    return counts.reshape(batch_size, num_bins)
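To illustrate the pack/unpack contract documented above: the round-trip below routes a toy batch through setup_packing, pack_experts, and unpack_experts, and checks that an identity map over the packed states returns each token's hidden state in every one of its K routed slots. This is a usage sketch, not repository code; the flat import path is hypothetical and stands in for the package-relative import used in the repository.

```python
import torch

# Hypothetical flat import; in the repository these functions live in
# __attention__expert_packing.py and are imported package-relatively.
from expert_packing import pack_experts, setup_packing, unpack_experts

B, N, K, L, d = 2, 5, 2, 4, 8
hidden_states = torch.randn(B, N, d)
position_ids = torch.arange(N).expand(B, N)              # ordinary sequential positions
selected_heads = torch.randint(0, L, (B, N, K))          # token-choice routing I
outer_active_mask = torch.ones(B, N, dtype=torch.bool)   # every token live

H, permutation, inverse_permutation = setup_packing(selected_heads)
packed_x, packed_pos, unpacking_mask, active_mask = pack_experts(
    hidden_states=hidden_states,
    position_ids=position_ids,
    selected_heads=selected_heads,
    num_experts=L,
    flattened_selected_heads=H,
    permutation=permutation,
    outer_active_mask=outer_active_mask,
)

# Identity "attention": unpacking the packed states should reproduce each
# token's hidden state in all K of its routed slots, in token-choice order.
restored = unpack_experts(
    expert_outputs=packed_x,
    selected_heads=selected_heads,
    unpacking_mask=unpacking_mask,
    inverse_permutation=inverse_permutation,
)
assert torch.allclose(restored, hidden_states.unsqueeze(2).expand(B, N, K, d))
```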
__attention__load_balance_loss.py
ADDED
"""Auxiliary-loss-free load balancing operator for MoSRAH routing.

This module implements the custom autograd Function H(b, f) described in the paper's
Implementation Concerns section. The operator bridges two requirements that are in
tension: it must behave like a standard auxiliary loss (scalar output, scalable via
multiplication) so that existing training loops remain compatible, while simultaneously
implementing DeepSeek-style bias correction rather than the usual auxiliary-loss gradient
path through the router weights.

The resolution is a custom backward pass. The forward emits the load balance imbalance
as a scalar loss. The backward, instead of differentiating that scalar with respect to
its inputs, writes a bias-correction gradient directly to expert_bias. This gradient is
then consumed by the main AdamW optimizer in the normal way, achieving DeepSeek-style
correction without a standalone SGD update step.

Paper ref: Appendix A.Implementation Concerns.
"""

import torch


class LoadBalanceLoss(torch.autograd.Function):
    """Custom autograd operator for DeepSeek-style auxiliary-loss-free load balancing.

    Forward computes the load balance imbalance:

        L_load_balance = H(b, f) = sum_l | f_l - 1/L |

    Backward emits a bias-correction gradient to expert_bias:

        grad_b = L_grad * sign(f_l - 1/L)

    expert_bias (b) is included as a forward input so PyTorch registers it as a node
    in the computation graph and routes gradients through it. routing_freqs (f) receives
    no gradient — its origin is the discrete TopK operation which has no gradient, so
    defining a gradient for f here would be mathematically incorrect.

    Paper ref: Appendix A.Implementation Concerns.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        expert_bias: torch.Tensor,
        routing_freqs: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the load balance loss.

        Args:
            ctx: Autograd context for saving state needed in backward.
            expert_bias: Learned per-head bias b, shape (L,). Included as an input so
                PyTorch tracks it as a computation graph node needing a gradient.
            routing_freqs: Realized routing frequency f_l per head, shape (L,). Computed
                from the discrete TopK selection — not differentiable.

        Returns:
            Scalar loss equal to sum_l |f_l - 1/L|.
        """
        L = expert_bias.shape[0]
        # imbalance = f_l - 1/L for each head: positive means overloaded, negative means
        # underloaded. Saved for backward where sign(imbalance) determines the direction
        # of the bias-correction update.
        imbalance = routing_freqs - 1.0 / L
        ctx.save_for_backward(imbalance)
        return imbalance.abs().sum()

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_output: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Emit the DeepSeek-style bias-correction gradient.

        Args:
            ctx: Autograd context carrying imbalance saved in forward.
            grad_output: Incoming gradient L_grad (scalar). Any rescaling of the loss
                by the training loop arrives here and is propagated to grad_b, so the
                correction magnitude is proportional to the loss weight chosen by the
                consumer.

        Returns:
            Gradient for expert_bias: L_grad * sign(f_l - 1/L), shape (L,).
            None for routing_freqs: no gradient is defined for the discrete routing
                frequency.
        """
        (imbalance,) = ctx.saved_tensors
        grad_expert_bias = grad_output * imbalance.sign()
        return grad_expert_bias, None
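A minimal sketch of how a training loop might consume this operator, assuming the operator is importable and that the router exposes its learned bias b and the realised routing frequencies f: the scalar is weighted and summed with the LM loss like any auxiliary loss, while backward deposits the sign-based correction directly into expert_bias.grad for the main optimizer to apply.

```python
import torch

# Hypothetical flat import; in the repository the operator lives in
# __attention__load_balance_loss.py.
from load_balance_loss import LoadBalanceLoss

num_heads = 16
expert_bias = torch.zeros(num_heads, requires_grad=True)   # router's learned per-head bias b
routing_freqs = torch.full((num_heads,), 1.0 / num_heads)  # realised routing frequencies f_l
routing_freqs[0] += 0.05                                    # head 0 overloaded
routing_freqs[1] -= 0.05                                    # head 1 underloaded

lm_loss = torch.tensor(2.0)                                 # stand-in for the LM loss
balance_weight = 0.01

loss = lm_loss + balance_weight * LoadBalanceLoss.apply(expert_bias, routing_freqs)
loss.backward()

# expert_bias.grad is balance_weight * sign(f_l - 1/L): positive for overloaded
# heads, negative for underloaded heads, zero where routing is already balanced.
print(expert_bias.grad)
```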
__attention__mosrah.py
ADDED
"""Full MoSRAH sparse path for SHRAM.

This module coordinates the routed sparse attention path used inside the SHRAM
hybrid attention layer. The underlying mechanics already live in verified
subunits. The responsibility here is to connect those subunits without
corrupting their bridge contracts.

In particular, this path must preserve three architectural distinctions:

- selected head indices are not routing probabilities
- packed position semantics are chosen before BEA, not inside it
- weighted reduction must consume the router's unbiased renormalized
  probabilities after token-choice order has been restored
"""

import torch
from torch import nn

from .__cache__mosrah_cache import MoSRAHCache
from .configuration import ShramConfig
from .__attention__bottlenecked_ensemble_attention import BottleneckedEnsembleAttention
from .__attention__expert_packing import (
    pack_experts,
    setup_packing,
    unpack_experts,
)
from .__attention__router import MoSRAHRouter
from .__attention__positions_converter import SparseMoSRAHPositions


class MoSRAHLayer(nn.Module):
    """Full routed sparse attention path for SHRAM.

    The MoSRAH path consumes model-space hidden states together with
    authoritative per-token positions and returns the model-space sparse-path
    contribution, the router's load-balance loss, and the router's MaxVio
    routing-imbalance scalar.
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.num_experts = config.num_mosrah_heads

        self.router = MoSRAHRouter(config)
        self.positions = SparseMoSRAHPositions(config)
        self.bea = BottleneckedEnsembleAttention(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: MoSRAHCache | None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the full MoSRAH sparse path.

        Args:
            hidden_states: Model-space hidden states x of shape (B, N, d).
            position_ids: Authoritative per-token positions of shape (B, N).
            active_mask: Current-chunk active mask of shape (B, N), where True
                means the token is semantically live. Forwarded to the router
                so dead tokens are excluded from routing statistics, and to
                pack_experts so dead outer tokens do not become semantically
                active packed entries.
            cache: Optional layer-local MoSRAH cache. Pass None for uncached
                execution and the layer-local cache instance for cached execution.

        Returns:
            sparse_output: Model-space sparse-path output of shape (B, N, d).
            load_balance_loss: Scalar router load-balance loss.
            max_vio: Detached scalar routing-imbalance summary. Passed through
                unchanged from the router; see MoSRAHRouter for semantics.
        """

        # -------------------------------------------------------------------
        # The first transition moves from model-space token-choice input into
        # the packed expert-choice sparse-attention state. Routing decides both
        # which experts each token uses and which unbiased probabilities must be
        # reserved for the final reduction. The active mask is forwarded to the
        # router so dead tokens are excluded from routing statistics, and to
        # pack_experts so outer liveness is faithfully carried into the packed
        # frame. Packing returns both the unpacking mask (slot occupancy, always
        # B*N*K True entries) and the packed active mask (live slots only);
        # active_mask is rebound to the packed form after this point.
        # -------------------------------------------------------------------
        selected_heads, routing_probs, load_balance_loss, max_vio = self.router(
            hidden_states, active_mask
        )

        flattened_selected_heads, permutation, inverse_permutation = setup_packing(
            selected_heads
        )
        packed_hidden_states, packed_positions, unpacking_mask, active_mask = pack_experts(
            hidden_states=hidden_states,
            position_ids=position_ids,
            selected_heads=selected_heads,
            num_experts=self.num_experts,
            flattened_selected_heads=flattened_selected_heads,
            permutation=permutation,
            outer_active_mask=active_mask,
        )

        # -------------------------------------------------------------------
        # Sparse attention runs entirely in the packed expert-choice frame, so
        # the RoPE position semantics must also be chosen in that frame. The
        # position layer therefore decides whether BEA should see packed
        # original-token positions or packed local-slot positions. BEA then
        # consumes that packed position tensor together with the packed hidden
        # states and the layer-local sparse cache, which it owns directly.
        # -------------------------------------------------------------------
        bea_positions = self.positions(
            packed_positions=packed_positions,
            cache=cache,
        )
        packed_outputs = self.bea(
            packed_embeddings=packed_hidden_states,
            position_ids=bea_positions,
            active_mask=active_mask,
            cache=cache,
        )

        # -------------------------------------------------------------------
        # The final transition restores token-choice meaning and only then
        # collapses the K routed copies back into model space. This ordering is
        # required because routing_probs live in token-choice space, whereas BEA
        # returns expert-choice packed outputs. The reduction must therefore
        # happen after unpacking, and it must use the router's unbiased
        # renormalized probabilities rather than any biased selection scores.
        # -------------------------------------------------------------------
        token_choice_outputs = unpack_experts(
            expert_outputs=packed_outputs,
            selected_heads=selected_heads,
            unpacking_mask=unpacking_mask,
            inverse_permutation=inverse_permutation,
        )
        final_output = (
            token_choice_outputs * routing_probs.unsqueeze(-1)
        ).sum(dim=2)

        return final_output, load_balance_loss, max_vio
__attention__positions_converter.py
ADDED
| 1 |
+
"""Position computation for the MoSRAH sparse path.
|
| 2 |
+
|
| 3 |
+
This layer computes the packed position tensor P consumed by BEA.
|
| 4 |
+
|
| 5 |
+
- In main-sequence mode, P is the packed original-token position tensor from the
|
| 6 |
+
packing path.
|
| 7 |
+
- In semantic-sequence mode, P is a per-expert local sequence over the packed
|
| 8 |
+
expert-choice layout, optionally offset by the current sparse-cache occupancies
|
| 9 |
+
during cached inference.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from torch import nn
|
| 14 |
+
|
| 15 |
+
from .configuration import ShramConfig
|
| 16 |
+
from .__cache__mosrah_cache import MoSRAHCache
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SparseMoSRAHPositions(nn.Module):
|
| 20 |
+
"""Compute the packed RoPE position tensor for the MoSRAH sparse path.
|
| 21 |
+
|
| 22 |
+
This layer operates in the packed expert-choice frame used by BEA. The input
|
| 23 |
+
packed_positions tensor is always the packed original-token position tensor
|
| 24 |
+
produced by the packing path. The configured rope_mode determines whether that
|
| 25 |
+
tensor is forwarded directly or replaced by a semantic local-slot sequence.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, config: ShramConfig) -> None:
|
| 29 |
+
super().__init__()
|
| 30 |
+
self.rope_mode = config.rope_mode
|
| 31 |
+
|
| 32 |
+
def forward(
|
| 33 |
+
self,
|
| 34 |
+
packed_positions: torch.Tensor,
|
| 35 |
+
cache: MoSRAHCache | None,
|
| 36 |
+
) -> torch.Tensor:
|
| 37 |
+
"""Compute the packed position tensor P consumed by BEA.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
packed_positions: Packed original-token positions J' of shape (B, L, T).
|
| 41 |
+
cache: Optional layer-local MoSRAH cache. When present in semantic-sequence
|
| 42 |
+
mode, the current per-head occupancies offset the local packed sequence.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
Packed position tensor P of shape (B, L, T).
|
| 46 |
+
"""
|
| 47 |
+
if self.rope_mode == "main_sequence":
|
| 48 |
+
return self._main_sequence_positions(packed_positions)
|
| 49 |
+
|
| 50 |
+
if self.rope_mode == "semantic_sequence":
|
| 51 |
+
return self._semantic_sequence_positions(packed_positions, cache)
|
| 52 |
+
|
| 53 |
+
raise NotImplementedError(
|
| 54 |
+
f"Unsupported MoSRAH rope_mode '{self.rope_mode}'."
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
def _main_sequence_positions(
|
| 58 |
+
self,
|
| 59 |
+
packed_positions: torch.Tensor,
|
| 60 |
+
) -> torch.Tensor:
|
| 61 |
+
"""Forward packed original-token positions unchanged."""
|
| 62 |
+
return packed_positions
|
| 63 |
+
|
| 64 |
+
def _semantic_sequence_positions(
|
| 65 |
+
self,
|
| 66 |
+
packed_positions: torch.Tensor,
|
| 67 |
+
cache: MoSRAHCache | None,
|
| 68 |
+
) -> torch.Tensor:
|
| 69 |
+
"""Compute semantic-sequence packed positions in expert-choice space.
|
| 70 |
+
|
| 71 |
+
Without a sparse cache, semantic positions are the local packed sequence
|
| 72 |
+
0, 1, 2, ... over the expert-local T dimension. With a sparse cache, that
|
| 73 |
+
same local sequence is offset by the current per-(batch, expert) occupancies
|
| 74 |
+
returned by get_heads_lengths().
|
| 75 |
+
"""
|
| 76 |
+
batch_size, num_experts, packed_length = packed_positions.shape
|
| 77 |
+
|
| 78 |
+
# -------------------------------------------------------------------
|
| 79 |
+
# Construct the local packed sequence 0, 1, 2, ... over the expert-local
|
| 80 |
+
# sequence dimension T. This is then broadcast across batch and experts.
|
| 81 |
+
# -------------------------------------------------------------------
|
| 82 |
+
local_positions = torch.arange(
|
| 83 |
+
packed_length,
|
| 84 |
+
device=packed_positions.device,
|
| 85 |
+
dtype=packed_positions.dtype,
|
| 86 |
+
).view(1, 1, packed_length).expand(
|
| 87 |
+
batch_size,
|
| 88 |
+
num_experts,
|
| 89 |
+
packed_length,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# -------------------------------------------------------------------
|
| 93 |
+
# In cached semantic-sequence mode, positions continue from the current
|
| 94 |
+
# sparse-cache occupancies rather than restarting at zero for the local
|
| 95 |
+
# chunk.
|
| 96 |
+
# -------------------------------------------------------------------
|
| 97 |
+
if cache is None:
|
| 98 |
+
return local_positions
|
| 99 |
+
|
| 100 |
+
cached_lengths = cache.get_heads_lengths().to(
|
| 101 |
+
device=packed_positions.device,
|
| 102 |
+
dtype=packed_positions.dtype,
|
| 103 |
+
).unsqueeze(-1)
|
| 104 |
+
|
| 105 |
+
return local_positions + cached_lengths
|
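A minimal, illustrative-only sketch of the two rope modes above. It assumes the repository modules are importable; SimpleNamespace stands in for ShramConfig purely so the example is self-contained, and the position values are arbitrary.

from types import SimpleNamespace
import torch

# semantic_sequence mode without a cache: positions restart at 0 per expert slot.
layer = SparseMoSRAHPositions(SimpleNamespace(rope_mode="semantic_sequence"))
packed_positions = torch.tensor([[[3, 5, 9, 0], [2, 7, 0, 0]]])  # (B=1, L=2, T=4)
print(layer(packed_positions, cache=None))
# tensor([[[0, 1, 2, 3], [0, 1, 2, 3]]])
# With a populated MoSRAHCache, each row would instead be offset by that
# (batch, expert) slot's current occupancy from get_heads_lengths().

# main_sequence mode simply forwards the packed original-token positions unchanged.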
__attention__router.py
ADDED
@@ -0,0 +1,162 @@
"""Token-choice router for the MoSRAH sparse attention path.

This module implements the routing mechanism described in Appendix A.Routing of the
paper. Given an input hidden state x, the router produces two outputs used downstream:

- selected_heads (I): which K of the L available expert heads each token routes to,
  determined by TopK over biased routing scores.
- routing_probs (P): the weights used for the weighted output reduction, gathered from
  *unbiased* routing scores at the selected indices and renormalized. The learned expert
  bias b must not influence P.

This separation is architecturally critical: expert_bias drives selection (and thus load
balancing) but does not corrupt the gradient path from the output through routing_probs
back to the routing projection weights.

The router also computes and returns the load balance loss via the LoadBalanceLoss custom
autograd operator (see load_balance_loss.py). This loss is a scalar that the training
loop can weight and add to the language modeling loss.

The router additionally computes and returns MaxVio, a detached scalar summarising
routing imbalance for the current forward pass:

    MaxVio = L · max_l(f_l − 1/L)

where f_l is the realised routing frequency of head l and 1/L is the perfectly balanced
target. MaxVio is a monitoring quantity only; it never contributes gradients.

Paper ref: Appendix A.Routing, Appendix A.Load Balancing, §MaxVio.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .configuration import ShramConfig
from .__attention__load_balance_loss import LoadBalanceLoss


class MoSRAHRouter(nn.Module):
    """Token-choice router for MoSRAH sparse attention.

    Each input token independently selects K of the L available expert heads. Selection
    is driven by biased routing scores to enable load balancing, but the routing
    probabilities used for output reduction are computed from unbiased scores so that
    the expert bias does not interfere with the gradient path to the router weights.

    The routing projection W_r has no bias term — the paper specifies xW_r with no
    additional projection bias. The only bias-like parameter is expert_bias (b), which
    has an entirely separate role and update mechanism.

    Args:
        config: Model configuration. Must expose ``hidden_size``, ``num_mosrah_heads``
            (L), and ``num_selected_heads`` (K).
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.num_mosrah_heads = config.num_mosrah_heads
        self.num_selected_heads = config.num_selected_heads

        # W_r: routing projection, no bias (paper specifies xW_r, no additional term).
        self.routing_projection = nn.Linear(
            config.hidden_size, config.num_mosrah_heads, bias=False
        )

        # b: learned per-head bias for load balancing. Initialized to zero so that all
        # heads start with equal selection probability. Updated by the main optimizer
        # via the LoadBalanceLoss custom backward.
        self.expert_bias = nn.Parameter(torch.zeros(config.num_mosrah_heads))

    def forward(
        self, x: torch.Tensor, active_mask: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route input tokens to K expert heads each and compute routing probabilities.

        Args:
            x: Input hidden states of shape (batch, seq_len, hidden_size).
            active_mask: Current-chunk active mask of shape (batch, seq_len), where
                True means the token is semantically live. Dead tokens do not
                contribute to routing frequencies, load_balance_loss, or max_vio.

        Returns:
            selected_heads: Head indices I of shape (batch, seq_len, num_selected_heads).
                Each token's K selected head indices, determined by TopK on biased scores.
            routing_probs: Routing probabilities P of shape (batch, seq_len,
                num_selected_heads). Gathered from unbiased scores at selected_heads
                indices and renormalized to sum to 1 per token.
            load_balance_loss: Scalar load balance imbalance loss for this forward pass.
                Training loop scales this by a weight and adds it to the main loss.
            max_vio: Detached scalar routing-imbalance summary for this forward pass.
                Equal to L · max_l(f_l − 1/L). Zero means perfect balance. Not a loss;
                never contributes gradients.
        """
        B, N, _ = x.shape
        L = self.num_mosrah_heads
        K = self.num_selected_heads

        # Unbiased routing scores R = Softmax(xW_r). These are the scores used to
        # compute routing_probs — expert_bias must not influence them.
        logits = self.routing_projection(x)  # (B, N, L)
        routing_scores = F.softmax(logits, dim=-1)  # R, (B, N, L)

        # Biased routing scores R̂ = Softmax(xW_r + b). Used only for TopK head
        # selection. expert_bias is added to logits before softmax so that the bias
        # shifts selection probability without rescaling the unbiased distribution.
        biased_routing_scores = F.softmax(  # R̂, (B, N, L)
            logits + self.expert_bias, dim=-1
        )

        # selected_heads I = TopK(R̂): K head indices per token, shape (B, N, K).
        selected_heads = biased_routing_scores.topk(K, dim=-1).indices

        # Routing probabilities P: gathered from unbiased R at selected_heads indices,
        # then renormalized so they sum to 1 per token. Gathering from routing_scores
        # (not biased_routing_scores) is the invariant that keeps the gradient path from
        # the output back to the router weights free of expert_bias influence.
        gathered = routing_scores.gather(dim=-1, index=selected_heads)  # V, (B, N, K)
        routing_probs = gathered / gathered.sum(dim=-1, keepdim=True)  # P, (B, N, K)

        # Routing frequency f_l: fraction of active (batch, token, head_slot) triples
        # assigned to each head. Dead tokens are excluded by zeroing their rows in the
        # assignment mask before reduction. Normalization uses the active assignment
        # count so frequencies remain properly scaled regardless of how many tokens
        # are live in this chunk.
        assignment_mask = torch.zeros(B, N, L, device=x.device, dtype=x.dtype)
        assignment_mask.scatter_(-1, selected_heads, 1.0)
        active_assignments = assignment_mask * active_mask.unsqueeze(-1)
        num_active_assignments = active_mask.sum() * K
        routing_freqs = active_assignments.sum(dim=(0, 1)) / num_active_assignments  # f, (L,)

        # Load balance loss via custom autograd. expert_bias is an input so PyTorch
        # registers it as a graph node; the custom backward writes the DeepSeek-style
        # correction gradient to expert_bias.grad for the optimizer to consume.
        load_balance_loss = LoadBalanceLoss.apply(self.expert_bias, routing_freqs)

        # MaxVio is a detached monitoring scalar derived from routing_freqs. It must
        # not contribute gradients under any circumstance, so it is detached at the
        # point of computation rather than left to callers to detach.
        max_vio = self._compute_max_vio(routing_freqs, L)

        return selected_heads, routing_probs, load_balance_loss, max_vio

    @staticmethod
    def _compute_max_vio(routing_freqs: torch.Tensor, num_heads: int) -> torch.Tensor:
        """Compute the MaxVio routing-imbalance scalar.

        MaxVio = L · max_l(f_l − 1/L), where f_l is the realised routing frequency of
        head l and 1/L is the perfectly balanced target. A value of zero indicates
        perfect balance; a value of 1 means the most overloaded head received exactly
        double its fair share.

        The result is detached from the autograd graph — MaxVio is a monitoring scalar
        and must never contribute gradients to any parameter.

        Args:
            routing_freqs: Per-head routing frequencies of shape (L,). Sums to 1.
            num_heads: Total number of MoSRAH heads L.

        Returns:
            Detached scalar MaxVio tensor.
        """
        return (num_heads * (routing_freqs - 1.0 / num_heads).max()).detach()
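A short, illustrative-only sketch of the router contract above. It assumes the repository modules (including LoadBalanceLoss) are importable; SimpleNamespace stands in for ShramConfig, and all dimensions are arbitrary example values.

from types import SimpleNamespace
import torch

router = MoSRAHRouter(
    SimpleNamespace(hidden_size=64, num_mosrah_heads=8, num_selected_heads=2)
)
x = torch.randn(2, 16, 64)                          # (B, N, hidden_size)
active_mask = torch.ones(2, 16, dtype=torch.bool)   # every token is live

selected_heads, routing_probs, lb_loss, max_vio = router(x, active_mask)
assert selected_heads.shape == (2, 16, 2)           # K head indices per token
assert torch.allclose(routing_probs.sum(-1), torch.ones(2, 16))  # P renormalized
# Hypothetical training-loop use: total_loss = lm_loss + lb_weight * lb_loss,
# while max_vio is only logged, never backpropagated.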
__attention__shram.py
ADDED
@@ -0,0 +1,116 @@
"""SHRAM hybrid attention layer.

This module implements the hybrid attention construction H(x) = h_l(x) + h_s(x)
used at one decoder attention slot in SHRAM.

The local sliding-window path and the MoSRAH sparse path are already verified
independently. The responsibility here is therefore not to introduce new
attention logic, but to preserve the bridge contracts between them: both paths
must consume the same input hidden state, each path must receive the sub-cache
it actually owns, the two model-space outputs must be summed directly, and the
sparse-path load-balance loss must remain visible to the caller.
"""

import torch
from torch import nn

from .__cache__shram_layer_cache import ShramLayerCache
from .configuration import ShramConfig
from .__attention__sliding_window_attention import SlidingWindowAttention
from .__attention__mosrah import MoSRAHLayer


class SHRAMHybridLayer(nn.Module):
    """Hybrid attention layer H(x) = h_l(x) + h_s(x) for one decoder slot.

    The local path preserves nearby-token behavior through sliding-window causal
    attention. The sparse path is the theorem-facing MoSRAH routed attention
    path. Both operate over the same model-space hidden state and return
    model-space outputs, so the hybrid composition is a direct sum in model
    space.
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.local_attention = SlidingWindowAttention(config)
        self.sparse_attention = MoSRAHLayer(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: ShramLayerCache | None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the SHRAM hybrid attention layer.

        Args:
            hidden_states: Input hidden states of shape (B, N, d).
            position_ids: Authoritative token positions of shape (B, N).
            active_mask: Current-chunk active mask of shape (B, N), where True
                means the token is semantically live. Forwarded unchanged to
                both the local path and the sparse path.
            cache: Optional per-layer SHRAM cache. When provided, the owned
                sliding-window and MoSRAH sub-caches are dispatched directly to
                their corresponding attention paths.

        Returns:
            hybrid_output: Model-space hybrid attention output of shape (B, N, d).
            load_balance_loss: Scalar sparse-path load-balance loss.
            max_vio: Detached scalar routing-imbalance summary. Passed through
                unchanged from MoSRAHLayer; see MoSRAHRouter for semantics.
        """
        # ------------------------------------------------
        # Due to how BEA constructs its block mask, the model cannot process
        # a sequence that does not start at position zero unless a cache is
        # available to track the per-head offsets.
        # ------------------------------------------------

        if cache is None and torch.any(position_ids[:, 0] != 0):
            raise ValueError(
                "Uncached SHRAMHybridLayer does not support nonzero starting positions. "
                "Either provide a matching ShramLayerCache populated by the prefix for "
                "continued decoding, or rebase the uncached sequence to start at 0."
            )

        # -------------------------------------------------------------------
        # The hybrid layer's first responsibility is cache dispatch. The layer
        # cache already owns the concrete sub-cache objects required by each
        # path, so this unit should forward those exact references rather than
        # reinterpret cache ownership or invent a composite update protocol here.
        # -------------------------------------------------------------------
        if cache is None:
            sliding_window_cache = None
            mosrah_cache = None
        else:
            sliding_window_cache = cache.sliding_window_cache
            mosrah_cache = cache.mosrah_cache

        # -------------------------------------------------------------------
        # Both attention paths must see the same model-space hidden state for
        # the current decoder layer. The local path preserves short-range
        # structure, while the sparse path provides the routed long-range
        # contribution and emits the load-balance signal used by training.
        # -------------------------------------------------------------------
        local_output = self.local_attention(
            x=hidden_states,
            position_ids=position_ids,
            active_mask=active_mask,
            cache=sliding_window_cache,
        )
        sparse_output, load_balance_loss, max_vio = self.sparse_attention(
            hidden_states=hidden_states,
            position_ids=position_ids,
            active_mask=active_mask,
            cache=mosrah_cache,
        )

        # -------------------------------------------------------------------
        # The composition rule is intentionally simple at this boundary. Both
        # sublayers already return model-space tensors of matching shape, so the
        # correct hybrid behavior is their direct sum with no additional mixing
        # logic introduced here.
        # -------------------------------------------------------------------
        hybrid_output = local_output + sparse_output

        return hybrid_output, load_balance_loss, max_vio
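An illustrative call pattern for one decoder slot, following the shapes documented above. The variable names (hybrid_layer, layer_cache, and the input tensors) are hypothetical and assumed to come from an already-constructed model and cache.

hybrid_output, lb_loss, max_vio = hybrid_layer(
    hidden_states=hidden_states,   # (B, N, d)
    position_ids=position_ids,     # (B, N); must start at 0 when cache is None
    active_mask=active_mask,       # (B, N) bool, True for semantically live tokens
    cache=layer_cache,             # ShramLayerCache or None
)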
__attention__sliding_window_attention.py
ADDED
@@ -0,0 +1,233 @@
# src/shram/model/attention/sliding_window_attention.py

"""Local sliding-window attention path for SHRAM.

This file defines `SlidingWindowAttention`, the local short-range attention path
used inside the SHRAM hybrid layer.

In the masked-continuation variant, the local cache no longer returns a
semantically dense visible frame. Instead, `LocalSlidingWindowLayerCache`
returns:

- the retained local window memory concatenated with the current chunk
- an aligned active mask over that returned frame

This module consumes that returned frame directly and constructs effective local
causal/window visibility from the mask. It does not own cache retention policy;
it owns only local attention semantics.
"""

import math
from typing import Any

import torch
import torch.nn as nn
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache
from .configuration import ShramConfig
from .rope import RotaryEmbedding


class SlidingWindowAttention(nn.Module):
    """Causal local sliding-window attention for one SHRAM layer.

    Args:
        config: SHRAM config. Must expose `hidden_size`,
            `num_sliding_window_heads`, `head_dim`, `window_size`,
            `attention_dropout`, and `local_rope_theta`.

    Raises:
        NotImplementedError: If `attention_dropout != 0.0`.
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_sliding_window_heads
        self.head_dim = config.head_dim
        self.window_size = config.window_size
        self.attention_dropout = config.attention_dropout

        if self.attention_dropout != 0.0:
            raise NotImplementedError(
                "SlidingWindowAttention currently supports only "
                "attention_dropout == 0.0."
            )

        self.inner_dim = self.num_heads * self.head_dim

        # Standard MHA projections for the local path.
        self.q_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.inner_dim, bias=False)
        self.o_proj = nn.Linear(self.inner_dim, self.hidden_size, bias=False)

        # The local path always uses default-mode RoPE with its own theta.
        self.rope = RotaryEmbedding(
            mode="default",
            head_dim=self.head_dim,
            theta=config.local_rope_theta,
        )

    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: LocalSlidingWindowLayerCache | None = None,
    ) -> torch.Tensor:
        """Apply local causal sliding-window attention.

        Args:
            x: Input tensor of shape `(B, N, hidden_size)`.
            position_ids: Position tensor of shape `(B, N)`.
            active_mask: Current-chunk active mask of shape `(B, N)`, where
                `True` means active.
            cache: Optional `LocalSlidingWindowLayerCache`.

        Returns:
            Output tensor of shape `(B, N, hidden_size)`.
        """
        batch_size, query_len, _ = x.shape

        self._validate_position_shape(x, position_ids)
        self._validate_active_mask_shape(x, active_mask)

        # (B, N, H*D) -> (B, H, N, D)
        q = self.q_proj(x).view(
            batch_size,
            query_len,
            self.num_heads,
            self.head_dim,
        ).transpose(1, 2)
        k = self.k_proj(x).view(
            batch_size,
            query_len,
            self.num_heads,
            self.head_dim,
        ).transpose(1, 2)
        v = self.v_proj(x).view(
            batch_size,
            query_len,
            self.num_heads,
            self.head_dim,
        ).transpose(1, 2)

        q, k, attention_scaling = self.rope(q, k, position_ids)

        # The cache returns the current-step visible local frame, not merely the
        # retained next-step cache buffer.
        if cache is not None:
            k_full, v_full, full_active_mask = cache.update(k, v, active_mask)
        else:
            k_full, v_full, full_active_mask = k, v, active_mask

        block_mask = self._make_block_mask(
            active_mask=full_active_mask,
            batch_size=batch_size,
            num_heads=self.num_heads,
            query_len=query_len,
            kv_len=k_full.shape[-2],
            window_size=self.window_size,
            device=x.device,
        )

        attn_output = flex_attention(
            q,
            k_full,
            v_full,
            block_mask=block_mask,
            scale=attention_scaling / math.sqrt(self.head_dim),
        )

        # (B, H, N, D) -> (B, N, H*D) -> (B, N, hidden_size)
        attn_output = (
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, query_len, self.inner_dim)
        )

        return self.o_proj(attn_output)

    def _validate_position_shape(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> None:
        """Validate the position tensor shape expected by local RoPE."""
        if position_ids.shape != x.shape[:2]:
            raise ValueError(
                f"position_ids must have shape {tuple(x.shape[:2])}, "
                f"got {tuple(position_ids.shape)}."
            )

    def _validate_active_mask_shape(
        self,
        x: torch.Tensor,
        active_mask: torch.Tensor,
    ) -> None:
        """Validate the current-chunk active-mask contract."""
        if active_mask.shape != x.shape[:2]:
            raise ValueError(
                f"active_mask must have shape {tuple(x.shape[:2])}, "
                f"got {tuple(active_mask.shape)}."
            )
        if active_mask.dtype != torch.bool:
            raise ValueError(
                f"active_mask must have dtype torch.bool, got {active_mask.dtype}."
            )

    def _make_block_mask(
        self,
        active_mask: torch.Tensor,
        batch_size: int,
        num_heads: int,
        query_len: int,
        kv_len: int,
        window_size: int,
        device: torch.device,
    ) -> Any:
        """Create the FlexAttention block mask for masked local continuation.

        The returned local frame is chronological in raw buffer order, but dead
        positions may remain inside it. Effective local order is therefore
        recovered from the active mask itself by taking a cumulative count over
        active positions.

        Queries still occupy the tail of the returned frame, so raw buffer order
        is used to locate query rows. Semantic active-token positions are then
        used to decide causality and sliding-window distance.
        """
        query_offset = kv_len - query_len
        semantic_positions = active_mask.long().cumsum(dim=-1) - 1

        def sliding_window_mask(
            batch_idx: torch.Tensor,
            head_idx: torch.Tensor,
            q_idx: torch.Tensor,
            kv_idx: torch.Tensor,
        ) -> torch.Tensor:

            q_abs = query_offset + q_idx

            query_is_active = active_mask[batch_idx, q_abs]
            key_is_active = active_mask[batch_idx, kv_idx]

            q_sem = semantic_positions[batch_idx, q_abs]
            k_sem = semantic_positions[batch_idx, kv_idx]

            is_causal = k_sem <= q_sem
            in_window = (q_sem - k_sem) < window_size

            return query_is_active & key_is_active & is_causal & in_window

        return create_block_mask(
            sliding_window_mask,
            B=batch_size,
            H=num_heads,
            Q_LEN=query_len,
            KV_LEN=kv_len,
            device=device,
        )
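A standalone, illustrative-only snippet (not repository code) showing how the mask predicate above treats dead slots: semantic positions count only active entries, so causality and window distance are measured over live tokens rather than raw buffer offsets.

import torch

active_mask = torch.tensor([[True, False, True, True]])    # one dead slot at index 1
semantic_positions = active_mask.long().cumsum(dim=-1) - 1  # [[0, 0, 1, 2]]

q_sem = semantic_positions[0, 3]   # query at raw index 3 -> semantic position 2
k_sem = semantic_positions[0, 0]   # key at raw index 0  -> semantic position 0
window_size = 2

# Causal, but two live tokens apart, so outside a window of 2: masked out.
print(bool((k_sem <= q_sem) & ((q_sem - k_sem) < window_size)))  # False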
__cache__mosrah_cache.py
ADDED
@@ -0,0 +1,359 @@
"""MoSRAH sparse KV cache — single-layer implementation.

MoSRAH routes each token to K of L available expert heads, so its KV cache is indexed
by head rather than by sequence position. The routing is dynamic and produces a ragged
distribution of token counts across (batch, head) slots — different batch items may
route different numbers of tokens to the same head, and different heads accumulate at
different rates. DynamicCache cannot represent this correctly: it concatenates along
the sequence dimension and assumes uniform token counts across the batch. MoSRAHCache
therefore uses a custom buffer design.

Keys and values are stored in the CacheLayerMixin-standard self.keys and self.values
attributes as (B, L, T, u) tensors, where B is batch size, L is the number of expert
heads (num_mosrah_heads), T is the current buffer capacity, and u is the bottlenecked
head embedding width (head_dim). A (B, L) integer count tensor _counts tracks the
valid occupancy of each (batch, head) slot. Buffer capacity is exposed as the
buffer_capacity property and is derived directly from self.keys rather than tracked
as a separate variable.

The primary interface is update(key_states, value_states, active_mask), which accepts
expert-choice layout, stores only active entries in causal order, and returns the full
accumulated (keys, values, active_mask) for immediate use by BEA. The returned
active_mask identifies valid cached positions; everything beyond each slot's count is
junk data that downstream attention must exclude.

BEA applies RoPE and calls update() with post-RoPE keys (K̃). The occupancy counts
exposed by get_heads_lengths() must be read before update() if the caller needs the
pre-update occupancy for position computation (Unit 10.A). update() increments counts
in-place and the pre-update values are not recoverable afterward.

All buffers are allocated at construction time. MoSRAHCache is constructed by
ShramLayerCache, which has access to batch size, device, and all model config parameters
needed to fully specify the storage layout upfront.
"""

import torch
from transformers.cache_utils import CacheLayerMixin


class MoSRAHCache(CacheLayerMixin):
    """KV cache for the MoSRAH sparse attention path — single decoder layer.

    Subclasses CacheLayerMixin to satisfy the HuggingFace per-layer cache role.
    Stores keys and values in the mixin-standard self.keys and self.values attributes
    using a custom (B, L, T, u) layout rather than delegating to DynamicCache,
    which cannot represent MoSRAH's ragged per-(batch, head) token counts correctly.

    All storage is allocated at construction time and is_initialized is True
    immediately. The caller (ShramLayerCache) provides batch size, device, and model
    config parameters so no lazy allocation is needed.

    Input is expected in expert-choice layout: (B, L, T, u) key/value tensors with a
    (B, L, T) boolean active_mask. Only positions where active_mask is True are written.
    This matches the packed representation produced by expert packing in the MoSRAH
    forward pass, where BEA has already applied RoPE before calling update().

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
            second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first dimension
            of all storage tensors.
        device: Device on which to allocate all tensors. Should match the model device.
        initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
            when any slot overflows. Defaults to 64 to avoid repeated reallocation
            during prompt processing.
    """

    is_compileable = False
    is_sliding = False

    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device

        # Allocate primary storage into the mixin-standard self.keys / self.values so
        # that inherited methods (offload, prefetch) operate on real tensors. _counts
        # tracks valid occupancy per (batch, head) slot.
        self.keys: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )

        # Storage is fully allocated at construction — the cache is initialized.
        self.is_initialized = True

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------

    @property
    def buffer_capacity(self) -> int:
        """Current number of slots allocated per (batch, head) pair.

        Derived directly from self.keys rather than tracked separately, so it is
        always consistent with the actual buffer after expansion.
        """
        return self.keys.shape[2]

    # ---------------------------------------------------------------------------
    # Primary API
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states into the buffer and return the full cache state.

        Accepts expert-choice layout: key_states and value_states are (B, L, T, u);
        active_mask is (B, L, T) bool with True marking real tokens. Only active
        positions are written; inactive positions are ignored.

        Uses a cumsum construction to derive the absolute buffer position for each
        active token without any Python loops. For a given (batch, head) slot,
        positions are assigned in the order tokens appear along the T dimension,
        preserving causal ordering.

        Returns the full accumulated (keys, values, active_mask) across the cached
        sparse sequence. The returned active_mask is True exactly for slots t <
        counts[b, l]; everything beyond is junk data that BEA must exclude.

        Note: get_heads_lengths() must be called before update() if the caller needs
        the pre-update occupancy for position computation (Unit 10.A). update()
        increments counts in-place and the pre-update values are not recoverable.

        Args:
            key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
            value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
            active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.

        Returns:
            Tuple of (keys, values, active_mask):
                keys: (B, L, T, u) float — full key buffer including junk slots.
                values: (B, L, T, u) float — full value buffer including junk slots.
                active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
        """
        incoming_delta = active_mask.long().sum(dim=2)  # (B, L)

        if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
            self._expand()

        # Cumulative count of active positions along T for each (b, l) slot. Entry
        # [b, l, t] is the 1-based rank of position t among all active positions in
        # that slot. Subtract 1 for a zero-based within-slot index. Inactive positions
        # produce a negative value, which is excluded by the mask gate below.
        within_slot = active_mask.long().cumsum(dim=2) - 1  # (B, L, T)

        # Add the pre-update count to get the absolute buffer position for each
        # active token.
        abs_pos = within_slot + self._counts.unsqueeze(-1)  # (B, L, T)

        # Scatter key and value vectors at all active positions.
        b_idx, l_idx, t_idx = torch.where(active_mask)
        self.keys[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
            key_states[b_idx, l_idx, t_idx]
        )
        self.values[b_idx, l_idx, abs_pos[b_idx, l_idx, t_idx]] = (
            value_states[b_idx, l_idx, t_idx]
        )

        self._counts += incoming_delta

        return self.keys, self.values, self._make_active_mask()

    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor consumed by BEA for attention
        masking and by position computation (Unit 10.A) for semantic-sequence
        position computation.

        Note: in the MoSRAH forward pass, this must be called before update() if the
        caller needs the pre-update occupancy. update() increments these counts in-place.

        Returns:
            Integer tensor of shape (B, L) where entry [b, h] is the number of valid
            tokens stored in the (b, h) slot. Zero for slots with no writes yet.
        """
        return self._counts

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — overridden coordination methods
    # ---------------------------------------------------------------------------

    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
        and is_initialized remains True — only the contents are cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam search
        must reorder all three together or the occupancy counts and buffer contents
        will correspond to different beam hypotheses.

        Overrides the parent because the parent's implementation calls get_seq_length(),
        which is not supported for this cache.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by the beam
                search algorithm.
        """
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self._counts = self._counts[beam_idx]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry repeats times.

        Used at beam search initialisation to expand the cache from batch size B to
        B * repeats, matching the expanded beam candidate batch. Applied atomically
        across keys, values, and _counts; batch_size is updated to reflect the new size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate entries.
        Applied atomically across keys, values, and _counts; batch_size is updated
        to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self._counts = self._counts[indices]
        self.batch_size = indices.shape[0]

    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not know
        about. All three tensors are moved atomically so device state remains consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not know
        about. _counts is synced to self.keys.device after the parent moves keys and
        values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op — storage is fully allocated at construction time."""
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — unsupported abstract methods
    # ---------------------------------------------------------------------------

    def get_seq_length(self) -> int:  # type: ignore[override]
        """Not supported — no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently; (batch, head) slots have different
        lengths depending on routing history. There is no meaningful scalar summary.
        Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "MoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Not supported — MoSRAHCache is dynamic and unbounded."""
        raise NotImplementedError(
            "MoSRAHCache is unbounded; get_max_cache_shape() is not supported."
        )

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported — MoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "MoSRAHCache does not support get_mask_sizes()."
        )

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------

    def _make_active_mask(self) -> torch.Tensor:
        """Construct the (B, L, T) active mask from current counts.

        Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
        has been written. Positions at or beyond the count are junk and must be
        excluded by downstream attention.
        """
        cap = self.buffer_capacity
        return (
            torch.arange(cap, device=self.keys.device)
            .expand(self.batch_size, self.num_mosrah_heads, cap)
            < self._counts.unsqueeze(-1)
        )

    def _expand(self) -> None:
        """Double the buffer capacity, preserving existing data.

        Called by update() when an incoming batch of tokens would cause any
        (batch, head) slot to exceed the current buffer capacity. All existing
        key and value data is copied into the low half of the new buffer; the
        high half is zero-initialised and will be filled by subsequent writes.
        After reassignment, buffer_capacity reflects the new size automatically.
        """
        old_cap = self.buffer_capacity
        new_cap = old_cap * 2
        dev = self.keys.device
        new_keys = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
        )
        new_values = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
        )
        new_keys[:, :, :old_cap, :] = self.keys
        new_values[:, :, :old_cap, :] = self.values
        self.keys = new_keys
        self.values = new_values
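A minimal sketch of the expert-choice update contract above, using the constructor and update() signatures as defined; all dimension values are arbitrary example choices. Only positions where the incoming active mask is True are written, and the returned mask marks exactly the slots that now hold real tokens.

import torch

cache = MoSRAHCache(num_mosrah_heads=4, head_dim=8, batch_size=1,
                    device=torch.device("cpu"), initial_buffer_size=4)

k = torch.randn(1, 4, 3, 8)                    # (B, L, T, u) post-RoPE keys
v = torch.randn(1, 4, 3, 8)                    # (B, L, T, u) values
active = torch.zeros(1, 4, 3, dtype=torch.bool)
active[0, 0, :2] = True                        # head 0 receives two tokens
active[0, 2, 1] = True                         # head 2 receives one token

keys, values, mask = cache.update(k, v, active)
print(cache.get_heads_lengths())               # tensor([[2, 0, 1, 0]])
print(mask.long().sum(-1))                     # matches the per-head occupancy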
__cache__shram_cache.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SHRAM top-level cache — model-wide owner for the full SHRAM decoder stack.
|
| 2 |
+
|
| 3 |
+
The HuggingFace Cache protocol expects a single top-level Cache object that owns one
|
| 4 |
+
CacheLayerMixin per decoder layer. The actual SHRAM caching responsibilities live one level
|
| 5 |
+
lower in ShramLayerCache — each of which owns a LocalSlidingWindowLayerCache and a MoSRAHCache.
|
| 6 |
+
ShramCache bridges those two levels: it constructs one ShramLayerCache per decoder layer,
|
| 7 |
+
presents them through the Cache interface, and transparently forwards model-wide operations
|
| 8 |
+
across all of them.
|
| 9 |
+
|
| 10 |
+
ShramCache does not define a composite update() interface. The two attention paths inside each
|
| 11 |
+
SHRAM layer have different update semantics, and neither the layer-level boundary (Unit 6.B)
|
| 12 |
+
nor the model-level boundary here can meaningfully unify them. Callers must reach down to the
|
| 13 |
+
relevant sub-cache directly. ShramCache's role is ownership, construction, and model-wide
|
| 14 |
+
coordination of the layer caches — not routing attention inputs.
|
| 15 |
+
|
| 16 |
+
Sequence length is reported by delegating to the local sliding-window sub-cache of the
|
| 17 |
+
specified layer, which tracks the cumulative count of token positions processed. This is
|
| 18 |
+
what HuggingFace generation reads through get_seq_length().
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
from transformers.cache_utils import Cache
|
| 23 |
+
|
| 24 |
+
from .__cache__shram_layer_cache import ShramLayerCache
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ShramCache(Cache):
|
| 28 |
+
"""Top-level cache for the full SHRAM model.
|
| 29 |
+
|
| 30 |
+
Owns one ShramLayerCache per decoder layer. Satisfies the HuggingFace top-level Cache
|
| 31 |
+
role and transparently forwards reset, reorder, and sequence-length queries across all
|
| 32 |
+
owned layer caches.
|
| 33 |
+
|
| 34 |
+
No composite update() interface is provided. The two attention paths inside each SHRAM
|
| 35 |
+
layer have materially different update semantics; callers must update sub-caches directly
|
| 36 |
+
via cache.layers[layer_idx].sliding_window_cache or cache.layers[layer_idx].mosrah_cache.
|
| 37 |
+
|
| 38 |
+
Args:
|
| 39 |
+
num_hidden_layers: Number of SHRAM decoder layers. Determines how many
|
| 40 |
+
ShramLayerCache objects are constructed.
|
| 41 |
+
sliding_window: Token window size passed to each layer's LocalSlidingWindowLayerCache.
|
| 42 |
+
num_local_heads: Number of local attention heads per layer.
|
| 43 |
+
local_head_dim: Per-head embedding width for the local path.
|
| 44 |
+
num_mosrah_heads: Total number of MoSRAH expert heads (L) per layer.
|
| 45 |
+
mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
|
| 46 |
+
batch_size: Number of sequences in the batch.
|
| 47 |
+
device: Device on which to allocate cache tensors.
|
| 48 |
+
initial_buffer_size: Initial per-(batch, head) capacity for each MoSRAHCache.
|
| 49 |
+
Doubled when any slot overflows. Defaults to 64 to avoid repeated reallocation
|
| 50 |
+
during prompt processing.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
    def __init__(
        self,
        num_hidden_layers: int,
        sliding_window: int,
        num_local_heads: int,
        local_head_dim: int,
        num_mosrah_heads: int,
        mosrah_head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        layers = [
            ShramLayerCache(
                sliding_window=sliding_window,
                num_local_heads=num_local_heads,
                local_head_dim=local_head_dim,
                num_mosrah_heads=num_mosrah_heads,
                mosrah_head_dim=mosrah_head_dim,
                batch_size=batch_size,
                device=device,
                initial_buffer_size=initial_buffer_size,
            )
            for _ in range(num_hidden_layers)
        ]
        super().__init__(layers=layers)

    # ---------------------------------------------------------------------------
    # Cache — composite-meaningful methods
    # ---------------------------------------------------------------------------
    #
    # reset(): Inherited. Iterates all layer caches and calls reset() on each.
    #
    # reorder_cache(beam_idx): Inherited. Iterates all layer caches and reorders each.
    #
    # is_initialized: Inherited property. True iff all layer caches are initialized.
    # Since ShramLayerCache.is_initialized is True from construction, this is True
    # immediately after ShramCache.__init__ returns.

    def get_seq_length(self, layer_idx: int = 0) -> int:  # type: ignore[override]
        """Return the cumulative sequence length for the specified layer.

        Delegates to the layer cache at layer_idx, which in turn delegates to the
        local sliding-window sub-cache. That sub-cache is authoritative for sequence
        progress: it sees every token presented to the layer and accumulates a truthful
        total count. Defaults to layer 0, which is sufficient for HuggingFace generation.
        """
        return self.layers[layer_idx].get_seq_length()

    # ---------------------------------------------------------------------------
    # Cache — unsupported methods
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Not supported — ShramCache has no composite update interface.

        The two attention paths inside each SHRAM layer have different update semantics.
        Callers must update sub-caches directly:
            cache.layers[layer_idx].sliding_window_cache.update(key_states, value_states)
            cache.layers[layer_idx].mosrah_cache.update(key_states, value_states, active_mask)
        """
        raise NotImplementedError(
            "ShramCache has no composite update interface. "
            "Update sliding_window_cache or mosrah_cache on the relevant layer directly."
        )

    def crop(self, max_length: int) -> None:
        """Not supported — ShramCache layers do not implement crop()."""
        raise NotImplementedError("ShramCache does not support crop().")

    @property
    def max_batch_size(self) -> int:
        """Not supported — ShramCache does not track a uniform batch size across layers."""
        raise NotImplementedError("ShramCache does not expose max_batch_size.")

    @property
    def max_cache_len(self) -> int:
        """Not supported — ShramCache has no single maximum cache length.

        The sliding-window side is bounded by sliding_window; the MoSRAH side is unbounded.
        No truthful scalar maximum represents the composite.
        """
        raise NotImplementedError("ShramCache does not expose max_cache_len.")
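A minimal usage sketch (not part of the file above), assuming the constructor and sub-cache signatures shown in this diff; the dimensions are illustrative:

import torch

cache = ShramCache(
    num_hidden_layers=2,
    sliding_window=8,
    num_local_heads=4,
    local_head_dim=16,
    num_mosrah_heads=4,
    mosrah_head_dim=16,
    batch_size=1,
    device=torch.device("cpu"),
)

# One chunk of 5 token positions for the local path of layer 0.
k = torch.randn(1, 4, 5, 16)
v = torch.randn(1, 4, 5, 16)
active = torch.ones(1, 5, dtype=torch.bool)

# The composite cache.update(...) raises NotImplementedError by design; each
# attention path updates the sub-cache it owns on the relevant layer instead.
keys, values, mask = cache.layers[0].sliding_window_cache.update(k, v, active)
print(cache.get_seq_length())  # 5, delegated to the local sliding-window sub-cache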
__cache__shram_layer_cache.py
ADDED
@@ -0,0 +1,233 @@
"""SHRAM per-layer cache — composite owner for one SHRAM decoder layer.

A SHRAM decoder layer contains two distinct attention pathways at one attention slot: the
local sliding-window path and the MoSRAH sparse path. Each path has its own cache with
different semantics and a different downstream consumer. ShramLayerCache owns both, satisfies
the HuggingFace per-layer cache role, and exposes each sub-cache directly so its attention
path can interact with it without indirection.

ShramLayerCache does not define a composite update() interface. The two paths have materially
different update semantics — the local side uses chunk-local key/value/mask concatenation
while the MoSRAH side uses expert-choice scatter with an active mask — and merging these
behind a single update() would hide those differences behind a misleading abstraction. Instead,
each attention path calls update() on the sub-cache it owns. ShramLayerCache acts as the
ownership, coordination, and reset/reorder boundary for one decoder layer.

Sequence length at this boundary is reported by delegating to the local sliding-window
sub-cache, which tracks the cumulative count of token positions processed. This is the
quantity HuggingFace generation reads through get_seq_length().
"""

import torch
from transformers.cache_utils import CacheLayerMixin

from .__cache__mosrah_cache import MoSRAHCache
from .__cache__sliding_window_cache import LocalSlidingWindowLayerCache


class ShramLayerCache(CacheLayerMixin):
    """Cache subsystem for one SHRAM decoder layer.

    Owns and coordinates two sub-caches:
        - sliding_window_cache: LocalSlidingWindowLayerCache for the local sliding-window path.
        - mosrah_cache: MoSRAHCache for the MoSRAH sparse attention path.

    Satisfies the HuggingFace per-layer cache role (CacheLayerMixin). The two sub-caches are
    exposed directly for their downstream attention paths — no composite update() interface is
    provided, because the two paths have materially different update semantics.

    Sequence length is reported by delegating to the local sliding-window sub-cache, which
    tracks the cumulative count of token positions processed across all update() calls.

    Args:
        sliding_window: Number of tokens retained by the local sliding-window cache.
        num_local_heads: Number of local attention heads.
        local_head_dim: Per-head embedding width for the local path.
        num_mosrah_heads: Total number of MoSRAH expert heads (L).
        mosrah_head_dim: Bottlenecked head embedding width (u) for the MoSRAH path.
        batch_size: Number of sequences in the batch.
        device: Device on which to allocate cache tensors.
        initial_buffer_size: Initial per-(batch, head) capacity for MoSRAHCache. Doubled
            when any slot overflows. Defaults to 64 to avoid repeated reallocation during
            prompt processing.
    """

    is_compileable = False
    is_sliding = False

    def __init__(
        self,
        sliding_window: int,
        num_local_heads: int,
        local_head_dim: int,
        num_mosrah_heads: int,
        mosrah_head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        super().__init__()
        self.sliding_window_cache = LocalSlidingWindowLayerCache(
            sliding_window=sliding_window,
            num_heads=num_local_heads,
            head_dim=local_head_dim,
            batch_size=batch_size,
            device=device,
        )
        self.mosrah_cache = MoSRAHCache(
            num_mosrah_heads=num_mosrah_heads,
            head_dim=mosrah_head_dim,
            batch_size=batch_size,
            device=device,
            initial_buffer_size=initial_buffer_size,
        )

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------

    @property
    def is_initialized(self) -> bool:
        """True iff both sub-caches have allocated their storage.

        Both LocalSlidingWindowLayerCache and MoSRAHCache pre-allocate at construction,
        so this is True immediately after ShramLayerCache.__init__ returns.
        """
        return self.sliding_window_cache.is_initialized and self.mosrah_cache.is_initialized

    @is_initialized.setter
    def is_initialized(self, value: bool) -> None:
        # CacheLayerMixin.__init__ assigns self.is_initialized = False as an instance
        # attribute. Since property is a data descriptor it takes precedence, but Python
        # still routes the assignment through __set__. Absorb it silently — state is
        # derived from sub-caches, not stored here.
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — composite-meaningful methods
    # ---------------------------------------------------------------------------

    def get_seq_length(self) -> int:  # type: ignore[override]
        """Return the cumulative sequence length from the local sliding-window path.

        The local path is authoritative for sequence progress: it sees every token
        presented to this layer and accumulates a truthful total. Delegates to
        sliding_window_cache.get_seq_length().
        """
        return self.sliding_window_cache.get_seq_length()

    def reset(self) -> None:
        """Clear both sub-caches.

        Delegates reset to each sub-cache. Both are cleared atomically so the sliding-window
        state and MoSRAH sparse state remain consistent.
        """
        self.sliding_window_cache.reset()
        self.mosrah_cache.reset()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of both sub-caches for beam search.

        Delegates to each sub-cache. Both are reordered atomically so the sliding-window
        and MoSRAH state correspond to the same beam hypotheses after reordering.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by beam search.
        """
        self.sliding_window_cache.reorder_cache(beam_idx)
        self.mosrah_cache.reorder_cache(beam_idx)

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension of both sub-caches for beam search initialisation.

        Delegates atomically to each sub-cache. Both must be expanded together so the
        sliding-window and MoSRAH state correspond to the same beam candidates.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.sliding_window_cache.batch_repeat_interleave(repeats)
        self.mosrah_cache.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries in both sub-caches for contrastive search.

        Delegates atomically to each sub-cache. Both must be trimmed together so the
        sliding-window and MoSRAH state remain consistent.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.sliding_window_cache.batch_select_indices(indices)
        self.mosrah_cache.batch_select_indices(indices)

    def offload(self) -> None:
        """Offload both sub-caches to CPU.

        Delegates to each sub-cache's offload method. Does not call super() — ShramLayerCache
        does not own self.keys/self.values directly; all cached data lives in the sub-caches.
        """
        self.sliding_window_cache.offload()
        self.mosrah_cache.offload()

    def prefetch(self) -> None:
        """Move both sub-caches back to their model device ahead of time.

        Delegates to each sub-cache's prefetch method. Does not call super() — ShramLayerCache
        does not own self.keys/self.values directly; all cached data lives in the sub-caches.
        """
        self.sliding_window_cache.prefetch()
        self.mosrah_cache.prefetch()

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op — both sub-caches handle their own initialization."""
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — unsupported abstract methods
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Not supported — ShramLayerCache has no composite update interface.

        The two sub-caches have materially different update semantics: the sliding-window
        side uses standard key/value concatenation while the MoSRAH side uses expert-choice
        scatter with an active mask. Callers must update each sub-cache directly via
        sliding_window_cache.update() or mosrah_cache.update().
        """
        raise NotImplementedError(
            "ShramLayerCache has no composite update interface. "
            "Update sliding_window_cache or mosrah_cache directly."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Not supported — the composite cache has no single maximum shape.

        The sliding-window cache is bounded by sliding_window; the MoSRAH cache is
        unbounded. No truthful scalar maximum represents the composite.
        """
        raise NotImplementedError(
            "ShramLayerCache has no single maximum cache shape. "
            "Query sliding_window_cache or mosrah_cache directly."
        )

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported — ShramLayerCache does not participate in HF mask construction.

        The two sub-caches have different mask semantics and their respective attention
        paths handle masking directly.
        """
        raise NotImplementedError(
            "ShramLayerCache does not support get_mask_sizes(). "
            "Query sliding_window_cache or mosrah_cache directly."
        )
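A short sketch (not part of the file above) of the beam-search coordination the docstrings describe; the permutation and dimensions are hypothetical, the constructor arguments follow this file:

import torch

layer = ShramLayerCache(
    sliding_window=8,
    num_local_heads=2,
    local_head_dim=16,
    num_mosrah_heads=2,
    mosrah_head_dim=16,
    batch_size=4,
    device=torch.device("cpu"),
)

beam_idx = torch.tensor([2, 2, 0, 1])  # hypothetical permutation from beam search
layer.reorder_cache(beam_idx)          # both sub-caches are reordered in one call
print(layer.is_initialized)            # True: derived from the two sub-caches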
__cache__sliding_window_cache.py
ADDED
@@ -0,0 +1,289 @@
# src/shram/model/cache/sliding_window_cache.py

"""Local sliding-window cache for the SHRAM local attention path.

This file defines `LocalSlidingWindowLayerCache`, the local sub-cache owned by
`ShramLayerCache` and consumed by `SlidingWindowAttention`.

Its job is narrow:

- accept the current chunk's local key/value tensors and active mask
- return the current-step local frame consumed by local attention
- separately retain the next-step sliding-window cache state

It does not decide local causal visibility. That is owned by
`SlidingWindowAttention`, which consumes the returned key/value/mask frame and
constructs the effective local attention mask from it.
"""

import torch
from transformers.cache_utils import CacheLayerMixin


class LocalSlidingWindowLayerCache(CacheLayerMixin):
    """Fixed-width local cache for one SHRAM decoder layer.

    The cache keeps a retained local sliding-window buffer and an aligned active
    mask. On update, it returns the current-step local frame formed by
    concatenating retained cache state with the new chunk, then remembers only
    the last `sliding_window` positions for the next step.

    Dead positions are allowed to remain in both the returned frame and the
    retained cache. Correctness is carried by the aligned active mask.

    Args:
        sliding_window: Width of the retained local sliding-window buffer.
        num_heads: Number of local attention heads.
        head_dim: Per-head embedding width for the local path.
        batch_size: Number of sequences in the batch.
        device: Device on which to allocate cache storage.
    """

    is_compileable = False
    is_sliding = True

    def __init__(
        self,
        sliding_window: int,
        num_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
    ) -> None:
        super().__init__()

        if sliding_window < 1:
            raise ValueError(
                f"sliding_window must be >= 1, got {sliding_window}."
            )
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}.")
        if head_dim < 1:
            raise ValueError(f"head_dim must be >= 1, got {head_dim}.")
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}.")

        self.sliding_window = sliding_window
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device

        # Retained next-step local cache state. Storage is fixed-width from the
        # start; semantic validity is carried by `active_mask`.
        self.keys = torch.zeros(
            batch_size,
            num_heads,
            sliding_window,
            head_dim,
            device=device,
        )
        self.values = torch.zeros(
            batch_size,
            num_heads,
            sliding_window,
            head_dim,
            device=device,
        )
        self.active_mask = torch.zeros(
            batch_size,
            sliding_window,
            dtype=torch.bool,
            device=device,
        )

        self.is_initialized = True

        # Cumulative count of all token positions presented through update() for
        # this cache instance. This is the quantity HuggingFace generation reads
        # through get_seq_length() to track how far along the sequence we are.
        self._total_processed: int = 0

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the current-step local frame and retain the next-step window.

        Args:
            key_states: Shape `(B, H, T_new, D)` local key vectors for the
                current chunk.
            value_states: Shape `(B, H, T_new, D)` local value vectors for the
                current chunk.
            active_mask: Shape `(B, T_new)` bool. `True` means the
                corresponding token position in the current chunk is active.
            cache_kwargs: Present only to satisfy the `CacheLayerMixin`
                interface. Unused by this cache.

        Returns:
            Tuple of:
                - visible_keys: `(B, H, sliding_window + T_new, D)`
                - visible_values: `(B, H, sliding_window + T_new, D)`
                - visible_active_mask: `(B, sliding_window + T_new)`

            These are the tensors the local attention path should consume
            directly for the current step.
        """
        self._ensure_state_compatibility(
            key_states=key_states,
            value_states=value_states,
        )

        # The current-step local frame is just retained cache state followed by
        # the current chunk in chronological order.
        composite_keys, composite_values, composite_mask = self._make_composite_frame(
            key_states=key_states,
            value_states=value_states,
            active_mask=active_mask,
        )

        # The cache remembers only the last raw sliding-window positions of that
        # composite frame for the next step. Dead positions are allowed to
        # survive; downstream local attention will ignore them using the mask.
        self._retain_next_window(
            composite_keys=composite_keys,
            composite_values=composite_values,
            composite_mask=composite_mask,
        )

        self._total_processed += key_states.shape[2]

        return composite_keys, composite_values, composite_mask

    def _ensure_state_compatibility(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ) -> None:
        """Keep retained cache buffers compatible with the incoming update tensors.

        The cache is allocated eagerly for simplicity. If later updates arrive on
        a different device or in a different floating dtype, move the retained
        state to match while preserving its contents.
        """
        if self.keys.dtype != key_states.dtype or self.keys.device != key_states.device:
            self.keys = self.keys.to(
                device=key_states.device,
                dtype=key_states.dtype,
            )

        if (
            self.values.dtype != value_states.dtype
            or self.values.device != value_states.device
        ):
            self.values = self.values.to(
                device=value_states.device,
                dtype=value_states.dtype,
            )

        if self.active_mask.device != key_states.device:
            self.active_mask = self.active_mask.to(
                key_states.device,
                non_blocking=True,
            )

    def _make_composite_frame(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build the current-step local frame in chronological order."""
        return (
            torch.cat([self.keys, key_states], dim=-2),
            torch.cat([self.values, value_states], dim=-2),
            torch.cat([self.active_mask, active_mask], dim=-1),
        )

    def _retain_next_window(
        self,
        composite_keys: torch.Tensor,
        composite_values: torch.Tensor,
        composite_mask: torch.Tensor,
    ) -> None:
        """Remember the next-step retained local state.

        This is a raw positional trim to the last `sliding_window` positions, not
        a semantic live-token trim.
        """
        self.keys = composite_keys[:, :, -self.sliding_window :, :]
        self.values = composite_values[:, :, -self.sliding_window :, :]
        self.active_mask = composite_mask[:, -self.sliding_window :]

    def get_seq_length(self) -> int:
        """Return the cumulative number of token positions processed by this cache.

        This is the total count of token positions presented across all update()
        calls since construction or the last reset(). It is the quantity HuggingFace
        generation reads to track sequence progress and is not the same as active-token
        count or current window occupancy.
        """
        return self._total_processed

    def get_max_cache_shape(self) -> int:
        return self.sliding_window

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        raise NotImplementedError(
            "LocalSlidingWindowLayerCache does not support get_mask_sizes()."
        )

    def reset(self) -> None:
        """Restore fresh-cache behavior."""
        self.keys.zero_()
        self.values.zero_()
        self.active_mask.zero_()
        self._total_processed = 0

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension for beam search."""
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self.active_mask = self.active_mask[beam_idx]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension for beam-search initialisation."""
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self.active_mask = self.active_mask.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries for contrastive search."""
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self.active_mask = self.active_mask[indices]
        self.batch_size = int(indices.shape[0])

    def offload(self) -> None:
        """Offload cache tensors to CPU."""
        super().offload()
        self.active_mask = self.active_mask.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move cache tensors back to the model device ahead of time."""
        super().prefetch()
        if self.active_mask.device != self.keys.device:
            self.active_mask = self.active_mask.to(
                self.keys.device,
                non_blocking=True,
            )

    def crop(self, max_length: int) -> None:
        raise NotImplementedError(
            "LocalSlidingWindowLayerCache does not support crop()."
        )

    def lazy_initialization(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
    ) -> None:
        """No-op — this cache allocates its fixed buffers at construction time."""
        return
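A minimal sketch (not part of the file above) of the frame/retention split that update() implements, using the shapes from its docstring; the numbers are illustrative:

import torch

cache = LocalSlidingWindowLayerCache(
    sliding_window=4, num_heads=2, head_dim=8, batch_size=1, device=torch.device("cpu")
)

def step(t_new: int):
    # Feed one chunk of t_new fully active positions through the cache.
    k = torch.randn(1, 2, t_new, 8)
    v = torch.randn(1, 2, t_new, 8)
    active = torch.ones(1, t_new, dtype=torch.bool)
    return cache.update(k, v, active)

keys, values, mask = step(3)
print(keys.shape)              # torch.Size([1, 2, 7, 8]): sliding_window + T_new
print(cache.keys.shape)        # torch.Size([1, 2, 4, 8]): only the window is retained
print(cache.get_seq_length())  # 3

keys, values, mask = step(2)
print(keys.shape)              # torch.Size([1, 2, 6, 8])
print(cache.get_seq_length())  # 5: cumulative positions, not window occupancy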
__cache__slow_mosrah_cache.py
ADDED
@@ -0,0 +1,321 @@
"""Unvectorized reference implementation of the MoSRAH sparse KV cache.

This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same
interface and storage layout as MoSRAHCache but uses an explicit Python loop over
(b, l, t) triples in update(). The loop is obviously correct by inspection: each active
position's key and value are written to the next available slot for that (batch, head)
pair, in the order positions appear along the T dimension, which directly enforces
causal ordering without any index arithmetic to verify.

SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
trusted ground truth against which the vectorized MoSRAHCache.update() is validated in
Unit 6.A tests, and as a reference for the Unit 10.A position decoder. Because the
vectorized implementation is validated by asserting exact agreement with this one on all
test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own test suite
(test_slow_mosrah_cache.py) must establish it is trustworthy before it can be used as
an oracle.
"""

import torch
from transformers.cache_utils import CacheLayerMixin


class SlowMoSRAHCache(CacheLayerMixin):
    """Unvectorized reference implementation of the MoSRAH KV cache.

    Identical storage layout to MoSRAHCache: (B, L, T, u) tensors in the
    mixin-standard self.keys and self.values attributes, plus a (B, L) _counts tensor,
    with the same constructor signature and the same CacheLayerMixin protocol methods.
    The sole difference is update(), which uses an explicit Python loop over (b, l, t)
    triples rather than vectorized index arithmetic.

    This class is not used in the model path. It exists so that MoSRAHCache.update()
    can be validated by asserting exact agreement with this implementation on all test
    inputs. See module docstring for the trust chain this enables.

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
            second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first dimension
            of all storage tensors.
        device: Device on which to allocate all tensors. Should match the model device.
        initial_buffer_size: Initial sequence capacity per (batch, head) slot. Doubled
            when any slot overflows. Defaults to 64 to avoid repeated reallocation
            during prompt processing.
    """

    is_compileable = False
    is_sliding = False

    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device

        # Allocate primary storage into the mixin-standard self.keys / self.values so
        # that inherited methods (offload, prefetch) operate on real tensors. _counts
        # tracks valid occupancy per (batch, head) slot.
        self.keys: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )

        # Storage is fully allocated at construction — the cache is initialized.
        self.is_initialized = True

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------

    @property
    def buffer_capacity(self) -> int:
        """Current number of slots allocated per (batch, head) pair.

        Derived directly from self.keys rather than tracked separately, so it is
        always consistent with the actual buffer after expansion.
        """
        return self.keys.shape[2]

    # ---------------------------------------------------------------------------
    # Primary API
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states using an explicit loop; return full cache state.

        Iterates over every (b, l, t) triple. For each position where active_mask is
        True, the key and value are written to the next available slot for that
        (batch, head) pair and the count is incremented. Causal ordering is guaranteed
        because the t dimension is traversed from 0 to T-1 and counts are updated
        immediately after each write.

        Buffer expansion (doubling buffer_capacity) is triggered before any writes if
        the incoming tokens would cause any slot to overflow the current capacity.

        Args:
            key_states: Shape (B, L, T, u) — post-RoPE key vectors in expert-choice layout.
            value_states: Shape (B, L, T, u) — value vectors in expert-choice layout.
            active_mask: Shape (B, L, T) bool — True for real tokens, False for padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.

        Returns:
            Tuple of (keys, values, active_mask):
                keys: (B, L, T, u) float — full key buffer including junk slots.
                values: (B, L, T, u) float — full value buffer including junk slots.
                active_mask: (B, L, T) bool — True iff slot (b, l, t) has been written.
        """
        B, L, T = active_mask.shape

        # Expansion check uses the total active tokens per slot, same as the
        # vectorized implementation, so both expand under identical conditions.
        incoming_delta = active_mask.long().sum(dim=2)  # (B, L)
        if (self._counts + incoming_delta).max().item() > self.buffer_capacity:
            self._expand()

        # Write each active position into the next available slot for its (batch, head)
        # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
        for b in range(B):
            for l in range(L):
                for t in range(T):
                    if active_mask[b, l, t]:
                        pos = self._counts[b, l].item()
                        self.keys[b, l, pos, :] = key_states[b, l, t, :]
                        self.values[b, l, pos, :] = value_states[b, l, t, :]
                        self._counts[b, l] += 1

        return self.keys, self.values, self._make_active_mask()

    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor consumed by BEA for attention
        masking and by position computation (Unit 10.A) for semantic-sequence
        position computation.

        Returns:
            Integer tensor of shape (B, L) where entry [b, h] is the number of valid
            tokens stored in the (b, h) slot. Zero for slots with no writes yet.
        """
        return self._counts

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — overridden coordination methods
    # ---------------------------------------------------------------------------

    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
        and is_initialized remains True — only the contents are cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam search
        must reorder all three together or the occupancy counts and buffer contents
        will correspond to different beam hypotheses.

        Overrides the parent because the parent's implementation calls get_seq_length(),
        which is not supported for this cache.

        Args:
            beam_idx: Permutation indices of shape (batch,) produced by the beam
                search algorithm.
        """
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self._counts = self._counts[beam_idx]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry repeats times.

        Used at beam search initialisation to expand the cache from batch size B to
        B * repeats, matching the expanded beam candidate batch. Applied atomically
        across keys, values, and _counts; batch_size is updated to reflect the new size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate entries.
        Applied atomically across keys, values, and _counts; batch_size is updated
        to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self._counts = self._counts[indices]
        self.batch_size = indices.shape[0]

    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not know
        about. All three tensors are moved atomically so device state remains consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not know
        about. _counts is synced to self.keys.device after the parent moves keys and
        values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op — storage is fully allocated at construction time."""
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin — unsupported abstract methods
    # ---------------------------------------------------------------------------

    def get_seq_length(self) -> int:  # type: ignore[override]
        """Not supported — no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently; (batch, head) slots have different
        lengths depending on routing history. There is no meaningful scalar summary.
        Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "SlowMoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Not supported — SlowMoSRAHCache is dynamic and unbounded."""
        raise NotImplementedError(
            "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
        )

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported — SlowMoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "SlowMoSRAHCache does not support get_mask_sizes()."
        )

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------

    def _make_active_mask(self) -> torch.Tensor:
        """Construct the (B, L, T) active mask from current counts.

        Returns True at position [b, l, t] iff t < _counts[b, l], i.e. the slot
        has been written. Positions at or beyond the count are junk and must be
        excluded by downstream attention.
        """
        cap = self.buffer_capacity
        return (
            torch.arange(cap, device=self.keys.device)
            .expand(self.batch_size, self.num_mosrah_heads, cap)
            < self._counts.unsqueeze(-1)
        )

    def _expand(self) -> None:
        """Double the buffer capacity, preserving existing data.

        Called by update() when an incoming batch of tokens would cause any
        (batch, head) slot to exceed the current buffer capacity. All existing
        key and value data is copied into the low half of the new buffer; the
        high half is zero-initialised and will be filled by subsequent writes.
        After reassignment, buffer_capacity reflects the new size automatically.
        """
        old_cap = self.buffer_capacity
        new_cap = old_cap * 2
        dev = self.keys.device
        new_keys = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
        )
        new_values = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim, device=dev
        )
        new_keys[:, :, :old_cap, :] = self.keys
        new_values[:, :, :old_cap, :] = self.values
        self.keys = new_keys
        self.values = new_values
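A sketch (not part of the file above) of the oracle comparison the module docstring describes, assuming MoSRAHCache shares this constructor and update() signature as stated; shapes and the routing pattern are illustrative:

import torch

B, L, T, u = 2, 4, 6, 16
kwargs = dict(num_mosrah_heads=L, head_dim=u, batch_size=B, device=torch.device("cpu"))
slow, fast = SlowMoSRAHCache(**kwargs), MoSRAHCache(**kwargs)

keys = torch.randn(B, L, T, u)
values = torch.randn(B, L, T, u)
active = torch.rand(B, L, T) < 0.5  # arbitrary expert-choice routing pattern

# Identical inputs must produce exactly matching buffers, masks, and occupancy.
slow_out = slow.update(keys, values, active)
fast_out = fast.update(keys, values, active)
for s, f in zip(slow_out, fast_out):
    torch.testing.assert_close(s, f)
assert torch.equal(slow.get_heads_lengths(), fast.get_heads_lengths())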
__init__.py
ADDED
@@ -0,0 +1,21 @@
from .configuration import ShramConfig
from .decoder_layer import DecoderLayer
from .huggingface import ShramForCausalLM
from .__attention__load_balance_loss import LoadBalanceLoss
from .mlp import SwiGLUMLP
from .model import ShramModel
from .rope import RotaryEmbedding
from .__attention__router import MoSRAHRouter
from .__cache__mosrah_cache import MoSRAHCache

__all__ = [
    "DecoderLayer",
    "LoadBalanceLoss",
    "MoSRAHCache",
    "MoSRAHRouter",
    "ShramConfig",
    "ShramForCausalLM",
    "ShramModel",
    "RotaryEmbedding",
    "SwiGLUMLP",
]
config.json
ADDED
@@ -0,0 +1,28 @@
{
  "alpha": 1.0,
  "attention_dropout": 0.0,
  "auto_map": {
    "AutoConfig": "configuration.ShramConfig",
    "AutoModelForCausalLM": "huggingface.ShramForCausalLM"
  },
  "beta": 32.0,
  "head_dim": 16,
  "hidden_size": 512,
  "inference_sequence_length": 1024,
  "intermediate_size": 1366,
  "local_rope_theta": 10000.0,
  "model_type": "shram",
  "mosrah_rope_theta": 10000.0,
  "num_hidden_layers": 12,
  "num_mosrah_heads": 16,
  "num_selected_heads": 16,
  "num_sliding_window_heads": 16,
  "rms_norm_eps": 1e-05,
  "rope_mode": "main_sequence",
  "tie_word_embeddings": false,
  "training_sequence_length": 1024,
  "transformers_version": "5.3.0",
  "use_cache": true,
  "vocab_size": 50277,
  "window_size": 128
}
configuration.py
ADDED
@@ -0,0 +1,199 @@
| 1 |
+
"""Configuration for the SHRAM transformer.
|
| 2 |
+
|
| 3 |
+
All architectural parameters that vary across model scales or are meaningful research
|
| 4 |
+
variables are expressed here. Architectural constants (no bias in linear layers,
|
| 5 |
+
SwiGLU activation with SiLU gate) are implemented in the relevant modules and
|
| 6 |
+
documented at the point of use — they are not config parameters because they do not
|
| 7 |
+
vary and changing them produces a different architecture, not a different scale.
|
| 8 |
+
|
| 9 |
+
RoPE configuration is owned entirely by this config. Each attention path reads its
|
| 10 |
+
parameters directly and constructs its own RotaryEmbedding instance explicitly — no
|
| 11 |
+
HuggingFace rope infrastructure is used. See Unit 5.A design decisions in plan.md.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from transformers import PretrainedConfig
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ShramConfig(PretrainedConfig):
|
| 18 |
+
"""Configuration class for the SHRAM decoder-only transformer.
|
| 19 |
+
|
| 20 |
+
SHRAM (Sparse Hybrid Token Routed Attention Mixture) replaces every standard
|
| 21 |
+
attention layer with a hybrid layer H(x) = h_l(x) + h_s(x), where h_l is a
|
| 22 |
+
local sliding-window causal attention path and h_s is the MoSRAH sparse routed
|
| 23 |
+
path. All other components follow the Llama 3 baseline.
|
| 24 |
+
|
| 25 |
+
This config is the single source of truth for every architectural dimension of the
|
| 26 |
+
model. Nothing in the architecture may use a literal number that belongs here.
|
| 27 |
+
|
| 28 |
+
Two independent RoPE configurations exist — one per attention path:
|
| 29 |
+
|
| 30 |
+
- h_l always uses standard RoPE with ``local_rope_theta``.
|
| 31 |
+
- BEA always uses YaRN with ``mosrah_rope_theta``, ``training_sequence_length``,
|
| 32 |
+
``inference_sequence_length``, ``alpha``, and ``beta``. When
|
| 33 |
+
``inference_sequence_length == training_sequence_length`` the YaRN scale factor
|
| 34 |
+
``s = 1`` and YaRN reduces exactly to standard RoPE — this is the default state
|
| 35 |
+
and the correct setting for experiments that do not require context extension.
|
| 36 |
+
|
| 37 |
+
Registered with HuggingFace AutoClass via ``auto_map``. Instantiate from the Hub::
|
| 38 |
+
|
| 39 |
+
config = AutoConfig.from_pretrained(
|
| 40 |
+
"your-namespace/advanced-transformers-lib",
|
| 41 |
+
trust_remote_code=True,
|
| 42 |
+
num_hidden_layers=12,
|
| 43 |
+
)
|
| 44 |
+
model = AutoModelForCausalLM.from_config(config)
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
vocab_size: Vocabulary size. Controls the embedding table and output logits
|
| 48 |
+
dimension. Must match the tokenizer.
|
| 49 |
+
embedding_width: Model width ``d``. The dimension of the residual stream.
|
| 50 |
+
mlp_width: FFN hidden dimension.
|
| 51 |
+
num_decoder_layers: Number of transformer blocks stacked in sequence.
|
| 52 |
+
num_sliding_window_heads: Number of heads in the local sliding-window path h_l.
|
| 53 |
+
num_mosrah_heads: Total MoSRAH expert heads available ``L``.
|
| 54 |
+
num_selected_heads: MoSRAH heads each token selects ``K``.
|
| 55 |
+
head_dim: Per-head dimension, shared by both attention paths. Must be even
|
| 56 |
+
(RoPE rotates dimensions in pairs). Paper uses 16.
|
| 57 |
+
window_size: Sliding window size for h_l. Paper uses 128.
|
| 58 |
+
rope_mode: RoPE position encoding mode for BEA. ``"main_sequence"`` supplies
|
| 59 |
+
original sequence positions; ``"semantic_sequence"`` supplies local slot
|
| 60 |
+
indices. Both are required; experimentally correct mode is undetermined
|
| 61 |
+
(paper §4). Default ``"main_sequence"``.
|
| 62 |
+
rms_norm_eps: Epsilon for RMSNorm layers.
|
| 63 |
+
local_rope_theta: RoPE base frequency ``b`` for the local attention path h_l.
|
| 64 |
+
Paper uses b=10000.
|
| 65 |
+
mosrah_rope_theta: RoPE base frequency ``b`` for the BEA path. Paper uses
|
| 66 |
+
b=10000.
|
| 67 |
+
training_sequence_length: Context length ``C_train`` the model was or will be
|
| 68 |
+
trained at. Used to compute the YaRN scale factor for BEA.
|
| 69 |
+
inference_sequence_length: Context length ``C_target`` the model must support
|
| 70 |
+
            at inference. When equal to ``training_sequence_length``, scale ``s=1``
            and YaRN reduces to standard RoPE.
        alpha: YaRN ramp lower boundary α (paper §A.2). Frequency dimensions with
            ``r(d) < alpha`` are fully interpolated by scale s. Paper value: 1.0.
        beta: YaRN ramp upper boundary β (paper §A.2). Frequency dimensions with
            ``r(d) > beta`` are left unscaled. Paper value: 32.0.
        attention_dropout: Dropout probability on attention weights. Default 0.0.
        use_cache: Whether to return past_key_values for KV caching.
        output_hidden_states: Whether to return hidden states after each layer.
        tie_word_embeddings: Whether input embedding and LM head share weights.
    """

    model_type = "shram"

    auto_map = {
        "AutoConfig": "configuration.ShramConfig",
        "AutoModelForCausalLM": "huggingface.ShramForCausalLM",
    }

    def __init__(
        self,
        vocab_size: int = 50277,
        embedding_width: int = 512,
        mlp_width: int = 1366,
        num_decoder_layers: int = 12,
        num_sliding_window_heads: int = 16,
        num_mosrah_heads: int = 16,
        num_selected_heads: int = 16,
        head_dim: int = 16,
        window_size: int = 128,
        rope_mode: str = "main_sequence",
        rms_norm_eps: float = 1e-5,
        local_rope_theta: float = 10000.0,
        mosrah_rope_theta: float = 10000.0,
        training_sequence_length: int = 1024,
        alpha: float = 1.0,
        beta: float = 32.0,
        attention_dropout: float = 0.0,
        use_cache: bool = True,
        output_hidden_states: bool = False,
        tie_word_embeddings: bool = False,
        **kwargs,
    ):
        if head_dim % 2 != 0:
            raise ValueError(
                f"head_dim must be even (RoPE rotates dimensions in pairs). "
                f"Got head_dim={head_dim}."
            )

        if rope_mode not in {"main_sequence", "semantic_sequence"}:
            raise ValueError(
                f"rope_mode must be 'main_sequence' or 'semantic_sequence', "
                f"got '{rope_mode}'."
            )

        if training_sequence_length <= 0:
            raise ValueError(
                f"training_sequence_length must be positive, "
                f"got {training_sequence_length}."
            )

        # inference_sequence_length is not a constructor parameter. It defaults to
        # training_sequence_length (scale=1.0, standard RoPE). If a saved config
        # carries the field through kwargs (after set_inference_context() was called
        # before saving), restore it here with validation.
        saved_inference_length = kwargs.pop("inference_sequence_length", training_sequence_length)
        if saved_inference_length <= 0:
            raise ValueError(
                f"inference_sequence_length must be positive, "
                f"got {saved_inference_length}."
            )

        self.vocab_size = vocab_size
        self.hidden_size = embedding_width
        self.intermediate_size = mlp_width
        self.num_hidden_layers = num_decoder_layers
        self.num_sliding_window_heads = num_sliding_window_heads
        self.num_mosrah_heads = num_mosrah_heads
        self.num_selected_heads = num_selected_heads
        self.head_dim = head_dim
        self.window_size = window_size
        self.rope_mode = rope_mode
        self.rms_norm_eps = rms_norm_eps
        self.local_rope_theta = local_rope_theta
        self.mosrah_rope_theta = mosrah_rope_theta
        self.training_sequence_length = training_sequence_length
        self.inference_sequence_length = saved_inference_length
        self.alpha = alpha
        self.beta = beta
        self.attention_dropout = attention_dropout
        self.use_cache = use_cache

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )

        # Promote auto_map to an instance attribute so PretrainedConfig.to_dict()
        # serialises it into config.json.
        self.auto_map = type(self).auto_map

    @property
    def scale(self) -> float:
        """YaRN context extension scale factor s = inference_sequence_length / training_sequence_length.

        When scale == 1.0, YaRN reduces exactly to standard RoPE — all frequency
        adjustments cancel and A_rope = 1. This is the default state.
        """
        return self.inference_sequence_length / self.training_sequence_length

    def set_inference_context(self, inference_sequence_length: int) -> None:
        """Set the inference context length for YaRN context extension.

        This is the only supported way to set inference_sequence_length. At construction
        the inference context defaults to training_sequence_length (scale=1.0, standard
        RoPE). Call this method to configure a longer inference context, which causes
        YaRN to interpolate frequencies and extend the effective context window.

        Args:
            inference_sequence_length: Target inference context length. Must be positive.
                Values equal to training_sequence_length produce scale=1.0 (standard RoPE).
                Values greater than training_sequence_length enable YaRN extrapolation.
        """
        if inference_sequence_length <= 0:
            raise ValueError(
                f"inference_sequence_length must be positive, "
                f"got {inference_sequence_length}."
            )
        self.inference_sequence_length = inference_sequence_length
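The YaRN contract above is easiest to see end to end: the sketch below constructs a config and drives the scale through set_inference_context(). It is a minimal sketch, assuming a local checkout of this repository so that configuration.py is importable (the import path is illustrative); only ShramConfig, scale, and set_inference_context() come from the code above.

# Minimal sketch of the YaRN context-extension flow exposed by ShramConfig.
from configuration import ShramConfig   # assumed local import path

config = ShramConfig(training_sequence_length=1024)
print(config.scale)                 # 1.0 -> standard RoPE, no frequency interpolation

config.set_inference_context(4096)
print(config.scale)                 # 4.0 -> YaRN interpolates frequencies with s = 4

try:
    config.set_inference_context(0)  # rejected: must be positive
except ValueError as err:
    print(err)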
decoder_layer.py ADDED
@@ -0,0 +1,87 @@
"""Decoder layer — a single transformer block.

Each block applies pre-norm hybrid attention followed by pre-norm MLP, with
residual connections around both sublayers:

    normed_attn = RMSNorm(x)
    attn_out, load_balance_loss, max_vio = SHRAMHybridLayer(normed_attn, ...)
    h = x + attn_out

    normed_mlp = RMSNorm(h)
    mlp_out = SwiGLUMLP(normed_mlp)
    out = h + mlp_out

Pre-norm keeps the residual stream unnormalised. Gradients flow more cleanly
through unnormalised residuals at depth, and each sublayer receives a stable,
normalised view of the signal.

Two independent RMSNorm instances are used — one before attention, one before
MLP. They learn different scalings because they precede layers with different
dynamic ranges. Sharing them would be wrong.

torch.nn.RMSNorm is used directly (available from PyTorch 2.4+). It omits mean
subtraction, is faster than LayerNorm, and proved more stable at scale.
"""

import torch
import torch.nn as nn

from .__attention__shram import SHRAMHybridLayer
from .__cache__shram_layer_cache import ShramLayerCache
from .configuration import ShramConfig
from .mlp import SwiGLUMLP


class DecoderLayer(nn.Module):
    """A single pre-norm SHRAM decoder block.

    Composes SHRAMHybridLayer and SwiGLUMLP with residual connections and
    independent RMSNorm instances on each sublayer input.

    Args:
        config: SHRAM config. Must expose ``hidden_size`` and ``rms_norm_eps``
            in addition to the fields required by SHRAMHybridLayer and
            SwiGLUMLP.
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.attn_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp_norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention = SHRAMHybridLayer(config)
        self.mlp = SwiGLUMLP(config)

    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: ShramLayerCache | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply one decoder block to the input.

        Args:
            x: Input of shape (batch, seq_len, hidden_size).
            position_ids: Authoritative positions of shape (batch, seq_len).
            active_mask: Current-chunk active mask of shape (batch, seq_len),
                where True means the token is semantically live. Forwarded
                unchanged to the hybrid attention layer.
            cache: Optional per-layer SHRAM cache passed through to the hybrid
                attention layer unchanged.

        Returns:
            output: Tensor of shape (batch, seq_len, hidden_size).
            load_balance_loss: Scalar sparse-path load-balance loss propagated
                from SHRAMHybridLayer.
            max_vio: Detached scalar routing-imbalance summary. Passed through
                unchanged from SHRAMHybridLayer; see MoSRAHRouter for semantics.
        """
        attn_out, load_balance_loss, max_vio = self.attention(
            hidden_states=self.attn_norm(x),
            position_ids=position_ids,
            active_mask=active_mask,
            cache=cache,
        )
        hidden_states = x + attn_out
        output = hidden_states + self.mlp(self.mlp_norm(hidden_states))
        return output, load_balance_loss, max_vio
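The wiring above is the standard pre-norm residual pattern. As a self-contained illustration of that pattern only (not the actual SHRAMHybridLayer or SwiGLUMLP, which need the full repository), the sketch below uses stand-in linear sublayers; ToyBlock and its dimensions are hypothetical, and nn.RMSNorm assumes PyTorch 2.4+.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Pre-norm block with two independent RMSNorms and stand-in sublayers."""

    def __init__(self, hidden: int) -> None:
        super().__init__()
        self.attn_norm = nn.RMSNorm(hidden)     # norm before the "attention" sublayer
        self.mlp_norm = nn.RMSNorm(hidden)      # separate norm before the "MLP" sublayer
        self.attn = nn.Linear(hidden, hidden)   # stand-in for SHRAMHybridLayer
        self.mlp = nn.Linear(hidden, hidden)    # stand-in for SwiGLUMLP

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x + self.attn(self.attn_norm(x))    # residual around attention
        return h + self.mlp(self.mlp_norm(h))   # residual around MLP; stream stays unnormalised

print(ToyBlock(16)(torch.randn(2, 8, 16)).shape)  # torch.Size([2, 8, 16])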
huggingface.py ADDED
@@ -0,0 +1,518 @@
"""HuggingFace causal-LM wrapper for SHRAM.

ShramForCausalLM is the HuggingFace-facing language-model boundary for SHRAM.
It owns token embedding lookup, LM-head projection, wrapper-level next-token
cross-entropy loss, config-controlled tied embeddings, and generation/cache
orchestration at the wrapper boundary.

The backbone remains a pure transformer stack. ShramModel accepts pre-embedded
hidden states together with current position IDs, a current active mask, and an
optional ShramCache. It has no knowledge of token IDs, vocabulary projection,
or causal-LM loss.

HuggingFace generation reaches this wrapper with two different tensor
conventions:

- ``position_ids`` is a current-step tensor. GenerationMixin updates the total
  sequence state between steps, then slices position-bearing tensors back down
  before calling ``forward()``.
- ``attention_mask`` is a full 2D mask over the total sequence so far. This
  wrapper slices its recent chunk to produce the current semantic liveness mask
  expected by the backbone.

Generation-created caches are handled in ``_prepare_cache_for_generation``.
That hook ensures HuggingFace generation uses ShramCache rather than a generic
dynamic cache. The direct ``forward()`` path does not silently create caches;
when ``use_cache=True`` it expects a truthful ShramCache to have been supplied.
"""

from dataclasses import dataclass
from typing import Any

import torch
import torch.nn as nn
from transformers import GenerationMixin, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.generation.configuration_utils import GenerationMode
from transformers.modeling_outputs import CausalLMOutputWithPast

from .__cache__shram_cache import ShramCache
from .configuration import ShramConfig
from .model import ShramModel


@dataclass
class ShramCausalLMOutput(CausalLMOutputWithPast):
    """SHRAM causal-LM wrapper output.

    This subclasses HuggingFace's standard ``CausalLMOutputWithPast``.
    Dataclass inheritance is sufficient here: all standard causal-LM fields and
    ModelOutput behavior are inherited from the parent, and this subclass adds
    only the SHRAM-specific wrapper outputs.
    """

    ce_loss: torch.FloatTensor | None = None
    load_balance_loss: torch.FloatTensor | None = None
    max_vio: torch.FloatTensor | None = None


class ShramForCausalLM(PreTrainedModel, GenerationMixin):
    """HuggingFace-facing causal language model wrapper for SHRAM.

    Owns token embeddings, LM-head projection, wrapper-level shifted CE loss,
    tied embedding configuration, and generation/cache boundary behavior.
    Delegates all transformer computation to ``ShramModel``.

    Args:
        config: SHRAM model configuration.
    """

    config_class = ShramConfig
    base_model_prefix = "model"
    _no_split_modules = ["DecoderLayer"]
    supports_gradient_checkpointing = True

    def __init__(self, config: ShramConfig) -> None:
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.model = ShramModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self._configure_tied_embeddings()
        self.post_init()

    def _configure_tied_embeddings(self) -> None:
        """Apply config-controlled tied embedding behavior on this instance."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
            self._tied_weights_keys = {
                "lm_head.weight": "embed_tokens.weight",
            }
        else:
            self._tied_weights_keys = {}

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding matrix."""
        return self.embed_tokens

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token embedding matrix."""
        self.embed_tokens = value
        self._configure_tied_embeddings()

    def get_output_embeddings(self) -> nn.Linear:
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, value: nn.Linear) -> None:
        """Replace the LM head."""
        self.lm_head = value
        self._configure_tied_embeddings()

    def _build_shram_cache(
        self,
        batch_size: int,
        device: torch.device,
    ) -> ShramCache:
        """Construct a fresh top-level SHRAM cache."""
        return ShramCache(
            num_hidden_layers=self.config.num_hidden_layers,
            sliding_window=self.config.window_size,
            num_local_heads=self.config.num_sliding_window_heads,
            local_head_dim=self.config.head_dim,
            num_mosrah_heads=self.config.num_mosrah_heads,
            mosrah_head_dim=self.config.hidden_size // self.config.num_selected_heads,
            batch_size=batch_size,
            device=device,
        )

    def _validate_generation_cache_request(
        self,
        generation_config: Any,
        model_kwargs: dict[str, Any],
        generation_mode: GenerationMode,
    ) -> None:
        """Validate SHRAM's generation-side cache policy."""
        if generation_mode in {
            GenerationMode.ASSISTED_GENERATION,
            GenerationMode.CONTRASTIVE_SEARCH,
        }:
            raise NotImplementedError(
                "ShramForCausalLM does not currently support assisted generation "
                "or contrastive search because ShramCache does not support crop()."
            )

        user_defined_cache = model_kwargs.get("past_key_values")
        if user_defined_cache is not None:
            if generation_config.cache_implementation is not None:
                raise ValueError(
                    "Passing both `cache_implementation` and `past_key_values` "
                    "is unsupported. Please use only one."
                )
            if isinstance(user_defined_cache, tuple):
                raise ValueError(
                    "Passing a tuple of `past_key_values` is not supported. "
                    "Please use a `ShramCache` instance."
                )
            if not isinstance(user_defined_cache, ShramCache):
                raise TypeError(
                    "ShramForCausalLM requires `past_key_values` to be a "
                    "`ShramCache` instance."
                )

        if (
            user_defined_cache is None
            and generation_config.use_cache
            and generation_config.cache_implementation is not None
        ):
            raise ValueError(
                "ShramForCausalLM does not support `cache_implementation`. "
                "Generation-created caches must be `ShramCache` objects."
            )

    def _prepare_cache_for_generation(
        self,
        generation_config: Any,
        model_kwargs: dict[str, Any],
        generation_mode: GenerationMode,
        batch_size: int,
        max_cache_length: int,
    ) -> None:
        """Ensure HuggingFace generation uses ShramCache.

        This is the SHRAM-specific generation hook. The rest of the default
        generation plumbing is kept intact as much as possible.

        Args:
            generation_config: Active generation configuration.
            model_kwargs: Generation kwargs, updated in place.
            generation_mode: HuggingFace generation mode.
            batch_size: Effective generation batch size.
            max_cache_length: Requested cache length. Accepted but unused here.
        """
        self._validate_generation_cache_request(
            generation_config=generation_config,
            model_kwargs=model_kwargs,
            generation_mode=generation_mode,
        )

        if model_kwargs.get("past_key_values") is not None:
            return

        if not generation_config.use_cache:
            return

        num_repeats = max(
            generation_config.num_beams or 1,
            generation_config.num_return_sequences or 1,
        )
        model_kwargs["past_key_values"] = self._build_shram_cache(
            batch_size=batch_size * num_repeats,
            device=self.embed_tokens.weight.device,
        )

    def _reorder_cache(
        self,
        past_key_values: Cache,
        beam_idx: torch.Tensor,
    ) -> Cache:
        """Reorder the cache in place for beam search."""
        past_key_values.reorder_cache(beam_idx)
        return past_key_values

    def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
        """Validate token IDs at the wrapper boundary."""
        if input_ids.ndim != 2:
            raise ValueError("input_ids must have shape (batch, seq_len).")
        if input_ids.shape[1] == 0:
            raise ValueError("input_ids sequence length must be nonzero.")
        if input_ids.dtype != torch.long:
            raise TypeError("input_ids must be a long tensor.")

    def _validate_attention_mask(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None,
    ) -> None:
        """Validate the full-sequence attention mask."""
        if attention_mask is None:
            return
        if attention_mask.ndim != 2:
            raise ValueError("attention_mask must have shape (batch, total_seq_len).")
        if attention_mask.shape[0] != input_ids.shape[0]:
            raise ValueError("attention_mask batch dimension must match input_ids.")
        if attention_mask.shape[1] < input_ids.shape[1]:
            raise ValueError(
                "attention_mask must be at least as long as the current input_ids chunk."
            )

    def _validate_position_ids(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor | None,
    ) -> None:
        """Validate current-step position IDs."""
        if position_ids is None:
            return
        if position_ids.ndim != 2:
            raise ValueError("position_ids must have shape (batch, seq_len).")
        if position_ids.shape != input_ids.shape:
            raise ValueError(
                "position_ids must match the current input_ids shape exactly."
            )
        if position_ids.dtype != torch.long:
            raise TypeError("position_ids must be a long tensor.")

    def _validate_labels(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None,
    ) -> None:
        """Validate label shape at the wrapper boundary."""
        if labels is None:
            return
        if labels.ndim != 2:
            raise ValueError("labels must have shape (batch, seq_len).")
        if labels.shape != input_ids.shape:
            raise ValueError("labels must have the same shape as input_ids.")
        if labels.dtype != torch.long:
            raise TypeError("labels must be a long tensor.")

    def _validate_cache_inputs(
        self,
        use_cache: bool,
        past_key_values: Cache | None,
    ) -> None:
        """Validate cache policy for direct wrapper calls."""
        if use_cache:
            if past_key_values is None:
                raise ValueError(
                    "use_cache=True requires an explicit ShramCache. During "
                    "generate(), HuggingFace should supply this through "
                    "_prepare_cache_for_generation()."
                )
            if not isinstance(past_key_values, ShramCache):
                raise TypeError(
                    "past_key_values must be a ShramCache when use_cache=True."
                )
            return

        if past_key_values is not None:
            raise ValueError("past_key_values was provided while use_cache=False.")

    def _validate_position_sources(
        self,
        use_cache: bool,
        attention_mask: torch.Tensor | None,
        position_ids: torch.Tensor | None,
    ) -> None:
        """Validate that cached forward has a truthful source of positions."""
        if use_cache and attention_mask is None and position_ids is None:
            raise ValueError(
                "Cached forward requires either position_ids or attention_mask."
            )

    def _validate_hf_boundary(
        self,
        output_attentions: bool | None,
        return_dict: bool | None,
        inputs_embeds: torch.Tensor | None,
        cache_position: torch.Tensor | None,
        extra_kwargs: dict[str, Any],
    ) -> None:
        """Validate unsupported HuggingFace-facing wrapper inputs."""
        if output_attentions:
            raise NotImplementedError(
                "ShramForCausalLM does not expose output_attentions."
            )
        if return_dict is False:
            raise ValueError(
                "return_dict=False is not supported. "
                "ShramForCausalLM always returns ShramCausalLMOutput."
            )
        if inputs_embeds is not None:
            raise ValueError(
                "inputs_embeds is not supported at the SHRAM wrapper boundary. "
                "Pass input_ids instead."
            )
        if extra_kwargs:
            unsupported = ", ".join(sorted(extra_kwargs))
            raise TypeError(
                f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
            )

    def _standardize_full_attention_mask(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None,
    ) -> torch.BoolTensor:
        """Return a concrete full-sequence boolean attention mask."""
        if attention_mask is None:
            return torch.ones_like(input_ids, dtype=torch.bool)
        return attention_mask.to(dtype=torch.bool)

    def _resolve_current_position_ids(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor | None,
        full_attention_mask: torch.BoolTensor,
    ) -> torch.LongTensor:
        """Resolve concrete current-step position IDs for the backbone."""
        if position_ids is not None:
            return position_ids.to(dtype=torch.long)

        full_position_ids = full_attention_mask.to(dtype=torch.long).cumsum(dim=-1) - 1
        full_position_ids = full_position_ids.masked_fill(~full_attention_mask, 0)
        current_length = input_ids.shape[1]
        return full_position_ids[:, -current_length:]

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        ce_weight: float = 1.0,
        load_balance_weight: float = 0.01,
        **kwargs: Any,
    ) -> ShramCausalLMOutput:
        """Run the SHRAM causal language model wrapper.

        Args:
            input_ids: Current token IDs of shape ``(batch, seq_len)``.
            attention_mask: Optional full 2D mask of shape
                ``(batch, total_seq_len)``. The wrapper slices its recent chunk
                to produce the current semantic liveness mask expected by the
                backbone.
            position_ids: Optional current-step position IDs of shape
                ``(batch, seq_len)``. In ordinary HuggingFace generation this is
                already the current-step tensor when it reaches ``forward()``.
            past_key_values: Optional SHRAM cache. Required when
                ``use_cache=True``.
            use_cache: Whether to use and return a cache. Defaults to
                ``config.use_cache``.
            output_hidden_states: Whether to return backbone hidden states.
                Defaults to ``config.output_hidden_states``.
            labels: Optional target token IDs of shape ``(batch, seq_len)``.
            return_dict: Must be ``True`` or ``None``.
            ce_weight: Weight applied to the cross-entropy loss when combining with
                the load-balance loss. Default 1.0.
            load_balance_weight: Weight applied to the load-balance auxiliary loss.
                Default 0.01, matching the paper's recommendation.
            **kwargs: Unsupported HuggingFace kwargs fail explicitly.

        Returns:
            ``ShramCausalLMOutput`` with:
            - ``logits`` of shape ``(batch, seq_len, vocab_size)``,
            - ``loss`` = ``ce_weight * ce_loss + load_balance_weight * load_balance_loss``
              when labels are provided (``None`` otherwise),
            - ``ce_loss`` — raw unweighted cross-entropy loss for logging,
            - ``past_key_values`` as the active ``ShramCache`` or ``None``,
            - ``hidden_states`` when requested,
            - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
            - detached ``max_vio`` from the backbone.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )

        inputs_embeds = kwargs.pop("inputs_embeds", None)
        output_attentions = kwargs.pop("output_attentions", None)
        cache_position = kwargs.pop("cache_position", None)

        # ------------------------------------------------------------------
        # Validation zone.
        #
        # The wrapper boundary is where HuggingFace-facing inputs are judged
        # for truthfulness before any internal work begins. These checks are
        # intentionally front-loaded so the core logic below can assume one
        # coherent interpretation of the call rather than defensively checking
        # shapes, cache policy, or unsupported HF knobs at the point of use.
        # This keeps the main sequence readable while ensuring invalid states
        # fail before they can silently contaminate backbone execution.
        # ------------------------------------------------------------------
        self._validate_input_ids(input_ids)
        self._validate_attention_mask(input_ids, attention_mask)
        self._validate_position_ids(input_ids, position_ids)
        self._validate_labels(input_ids, labels)
        self._validate_cache_inputs(use_cache, past_key_values)
        self._validate_position_sources(use_cache, attention_mask, position_ids)
        self._validate_hf_boundary(
            output_attentions=output_attentions,
            return_dict=return_dict,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            extra_kwargs=kwargs,
        )

        # ------------------------------------------------------------------
        # Standardization zone.
        #
        # HuggingFace and SHRAM use different boundary conventions: generation
        # carries a full-sequence 2D attention mask, while the SHRAM backbone
        # wants a current-step active mask and concrete current position IDs.
        # This zone collapses those wrapper-facing conventions into one valid
        # backbone-facing state. After this point the core no longer reasons
        # about optional or ambiguous input forms; it works only with concrete
        # tensors whose semantics are already fixed.
        # ------------------------------------------------------------------
        full_attention_mask: torch.BoolTensor = self._standardize_full_attention_mask(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        current_length: int = input_ids.shape[1]
        current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
        current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
            input_ids=input_ids,
            position_ids=position_ids,
            full_attention_mask=full_attention_mask,
        )
        shram_cache: ShramCache | None = past_key_values if use_cache else None

        # ------------------------------------------------------------------
        # Core wrapper responsibilities.
        #
        # The wrapper's primary job is kept visible here: convert token IDs to
        # embeddings, delegate transformer computation to ShramModel, project
        # hidden states back to vocabulary logits, optionally compute the
        # wrapper-level shifted next-token loss, and return the HuggingFace-
        # facing output object. The backbone remains responsible only for
        # transformer semantics; token/vocabulary/loss concerns stay here.
        # ------------------------------------------------------------------
        token_embeddings: torch.FloatTensor = self.embed_tokens(input_ids)
        backbone_outputs = self.model(
            inputs_embeds=token_embeddings,
            position_ids=current_position_ids,
            active_mask=current_active_mask,
            cache=shram_cache,
            output_hidden_states=output_hidden_states,
        )

        logits: torch.FloatTensor = self.lm_head(backbone_outputs["last_hidden_state"])

        ce_loss: torch.FloatTensor | None = None
        loss: torch.FloatTensor | None = None
        if labels is not None:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            ce_loss = nn.functional.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )
            loss = ce_weight * ce_loss + load_balance_weight * backbone_outputs["load_balance_loss"]

        return ShramCausalLMOutput(
            loss=loss,
            ce_loss=ce_loss,
            logits=logits,
            past_key_values=backbone_outputs["past_key_values"],
            hidden_states=backbone_outputs["hidden_states"],
            load_balance_loss=backbone_outputs["load_balance_loss"],
            max_vio=backbone_outputs["max_vio"],
        )
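A training-style call against the forward() contract above looks roughly like the sketch below. It is a hedged example: constructing the model from a local ShramConfig assumes the full repository is on the import path, and the small num_decoder_layers value is chosen only to keep the illustration cheap. The returned fields follow the docstring.

import torch
from configuration import ShramConfig      # assumed local import path
from huggingface import ShramForCausalLM   # assumed local import path

config = ShramConfig(num_decoder_layers=2)   # small config purely for illustration
model = ShramForCausalLM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))
labels = input_ids.clone()

out = model(input_ids=input_ids, labels=labels, use_cache=False)
print(out.loss)               # ce_weight * ce_loss + load_balance_weight * load_balance_loss
print(out.ce_loss)            # raw cross-entropy, useful for logging
print(out.load_balance_loss)  # raw auxiliary routing loss from the backbone
print(out.max_vio)            # detached routing-imbalance summary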
mlp.py ADDED
@@ -0,0 +1,52 @@
"""SwiGLU feed-forward sublayer.

SwiGLU is a gated linear unit variant that multiplies a SiLU-gated projection
element-wise against a separate up-projection:

    output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))

The gating mechanism gives the network more expressive control over which features
to propagate than a plain two-matrix FFN. It requires three weight matrices instead
of two, which is why intermediate_size in Llama 3 is set lower than the 4× multiplier
typical of two-matrix FFNs — the total parameter count remains comparable.

SiLU is used as the gate activation because Llama 3 committed to SwiGLU specifically
— a fixed architectural choice.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig


class SwiGLUMLP(nn.Module):
    """SwiGLU feed-forward sublayer.

    Implements the three-matrix SwiGLU FFN used in Llama 3:

        output = W_down(SiLU(W_gate(x)) ⊙ W_up(x))

    No bias on any projection. SiLU as the gate activation is an architectural
    constant — it is what defines SwiGLU specifically.

    Args:
        config: Model config. Must expose ``hidden_size`` and ``intermediate_size``.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the SwiGLU feed-forward transformation.

        Args:
            x: Input tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Output tensor of shape (batch, seq_len, hidden_size).
        """
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
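A quick way to sanity-check the shape contract and the parameter-count remark above is the sketch below. Using a SimpleNamespace instead of a real PretrainedConfig is an illustration-only shortcut (the module only reads hidden_size and intermediate_size), and the local import path is an assumption.

import torch
from types import SimpleNamespace
from mlp import SwiGLUMLP   # assumed local import path

cfg = SimpleNamespace(hidden_size=512, intermediate_size=1366)
mlp = SwiGLUMLP(cfg)

x = torch.randn(2, 8, cfg.hidden_size)
print(mlp(x).shape)   # torch.Size([2, 8, 512])

# Three projections: 2 x (512 x 1366) + (1366 x 512) ~= 2.10M parameters,
# close to a two-matrix FFN at 4x width: 2 x (512 x 2048) ~= 2.10M.
print(sum(p.numel() for p in mlp.parameters()))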
model.py ADDED
@@ -0,0 +1,132 @@
"""Transformer backbone for Shram.

ShramModel is a pure PyTorch module: a sequence of DecoderLayer blocks followed
by a final RMSNorm. It accepts pre-embedded hidden states and returns contextual
representations. It has no knowledge of tokens, vocabulary, generation, or the
HuggingFace causal-LM wrapper contract.

Keeping the embedding out of the backbone is the correct convention and makes
the backbone genuinely modality-agnostic. The token interface — embedding lookup,
LM head, weight tying, and generation-facing naming conventions — belongs on the
task wrapper (ShramForCausalLM), which is the only class that knows this
backbone is being used for language modelling.

The final RMSNorm is necessary because the decoder stack uses pre-norm throughout:
each sublayer normalises its own input, leaving the residual stream itself
unnormalised. After many layers of accumulated residuals, that stream arrives at
the top with uncontrolled magnitude. The final norm brings it to a well-scaled
state before any projection. Without it, the LM head would receive signals of
arbitrary scale.

Caching is caller-managed. If a ShramCache is provided, ShramModel threads the
corresponding per-layer ShramLayerCache into each DecoderLayer and returns the
same top-level ShramCache object in the output dict. If None is provided, no
caching occurs.

Returns a plain dict with keys:
- "last_hidden_state": normed backbone output, shape (batch, seq_len, hidden_size)
- "past_key_values": the ShramCache object passed in, or None
- "hidden_states": tuple of per-layer activations if output_hidden_states=True, else None
- "load_balance_loss": scalar sum of per-layer SHRAM load-balance losses
- "max_vio": detached scalar maximum routing-imbalance across all decoder layers
"""

import torch
import torch.nn as nn

from .__cache__shram_cache import ShramCache
from .configuration import ShramConfig
from .decoder_layer import DecoderLayer


class ShramModel(nn.Module):
    """Pure transformer backbone: decoder stack and final normalisation.

    Accepts pre-embedded hidden states of shape (batch, seq_len, hidden_size)
    and returns contextual representations of the same shape. No token embedding,
    vocabulary projection, or causal-LM lifecycle concerns.

    RoPE is applied inside each attention layer. Positional information is
    encoded in the relationship between Q and K, not added to the residual
    stream, so the backbone is agnostic to how positions are represented.

    Args:
        config: Model configuration. Must be a ``ShramConfig`` instance.
    """

    def __init__(self, config: ShramConfig) -> None:
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        active_mask: torch.Tensor,
        cache: ShramCache | None = None,
        output_hidden_states: bool = False,
    ) -> dict:
        """Run the transformer stack over a batch of pre-embedded sequences.

        Args:
            inputs_embeds: Pre-embedded input of shape (batch, seq_len, hidden_size).
            position_ids: Absolute positions of shape (batch, seq_len). Required.
                Must be provided explicitly by the caller — this module does not
                infer positions from cache state.
            active_mask: Current-chunk active mask of shape (batch, seq_len),
                where True means the token is semantically live. Forwarded
                unchanged to every decoder layer.
            cache: Optional top-level ShramCache. When provided, each DecoderLayer
                receives its own layer-local cache via ``cache.layers[layer_idx]``.
                The top-level cache object is updated in place and returned unchanged.
            output_hidden_states: When True, the output dict includes a tuple of
                per-layer hidden states: (inputs_embeds, layer_0_out, ..., layer_N_out),
                collected before the final norm.

        Returns:
            Plain dict with keys:
            - ``"last_hidden_state"``: normed backbone output,
              shape (batch, seq_len, hidden_size).
            - ``"past_key_values"``: the cache object passed in, or None.
            - ``"hidden_states"``: tuple of per-layer activations (including
              inputs_embeds as position 0) if ``output_hidden_states`` is True,
              else None. Collected before the final norm so each entry reflects the
              unnormalised residual stream at that depth.
            - ``"load_balance_loss"``: scalar sum of per-layer SHRAM
              load-balance losses.
            - ``"max_vio"``: detached scalar maximum routing-imbalance across
              all decoder layers. Zero means perfectly balanced routing across
              every layer; higher values identify the worst-case head imbalance.
        """
        hidden_states = inputs_embeds
        all_hidden_states = (hidden_states,) if output_hidden_states else None
        total_load_balance_loss = inputs_embeds.new_zeros(())
        max_vio = inputs_embeds.new_zeros(())

        for layer_idx, layer in enumerate(self.layers):
            layer_cache = None if cache is None else cache.layers[layer_idx]
            hidden_states, layer_load_balance_loss, layer_max_vio = layer(
                hidden_states,
                position_ids,
                active_mask,
                cache=layer_cache,
            )
            total_load_balance_loss = total_load_balance_loss + layer_load_balance_loss
            max_vio = torch.maximum(max_vio, layer_max_vio)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

        hidden_states = self.norm(hidden_states)

        return {
            "last_hidden_state": hidden_states,
            "past_key_values": cache,
            "hidden_states": all_hidden_states,
            "load_balance_loss": total_load_balance_loss,
            "max_vio": max_vio,
        }
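Driving the backbone directly, per the contract above, looks like the hedged sketch below. It assumes the full repository is importable (the import paths are illustrative) and uses example shapes only; the keys of the returned dict come from the docstring.

import torch
from configuration import ShramConfig  # assumed local import path
from model import ShramModel           # assumed local import path

config = ShramConfig(num_decoder_layers=2)   # small config purely for illustration
backbone = ShramModel(config)

batch, seq_len = 2, 16
inputs_embeds = torch.randn(batch, seq_len, config.hidden_size)
position_ids = torch.arange(seq_len).expand(batch, seq_len)
active_mask = torch.ones(batch, seq_len, dtype=torch.bool)

out = backbone(inputs_embeds, position_ids, active_mask, cache=None)
print(out["last_hidden_state"].shape)            # (2, 16, hidden_size)
print(out["load_balance_loss"], out["max_vio"])  # scalar aux loss, detached imbalance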
rope.py ADDED
@@ -0,0 +1,291 @@
"""Rotary Position Embeddings (RoPE).

RoPE encodes position in the *relationship* between query and key vectors. When the
attention dot product Q·Kᵀ is computed, the per-position rotations cancel to produce
a score that depends only on the relative distance — not on absolute positions.

Two modes are supported:

default   Standard RoPE with base frequency b. Each dimension pair d is assigned
          frequency θ_d = b^{-2d/u} where u is the head dimension. The attention
          scaling A_rope = 1.

yarn      YaRN frequency interpolation for long-context extrapolation (Peng et al.,
          "YaRN: Efficient Context Window Extension of Large Language Models", 2023,
          §A.2). Three frequency regimes:
          - Low-frequency dimensions (r < α): fully interpolated by scale s.
            These dimensions have long wavelengths relative to the training window
            and must be compressed to avoid out-of-distribution positions.
          - High-frequency dimensions (r > β): left unchanged. Short-wavelength
            dimensions already encode relative position accurately at any scale.
          - Intermediate dimensions (α ≤ r ≤ β): linearly blended via ramp γ(r).
          Returns A_rope = (0.1·ln(s)+1)². When s = 1, YaRN reduces exactly to
          standard RoPE.

Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit
parameters — no shared instance, no config reading. See Unit 5.A design decisions.

Cache sharing: all instances with identical parameters share one cos/sin table via a
class-level registry. The first instance that needs a particular (parameters, seq_len,
device, dtype) combination builds the table; all subsequent instances reference it
directly. This avoids redundant builds across the num_hidden_layers instances that
share the same parametrisation.
"""

import math

import torch
import torch.nn as nn


# ---------------------------------------------------------------------------
# Rotation helper
# ---------------------------------------------------------------------------

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Apply the 90° rotation used in the RoPE update formula.

    Splits the last dimension into two halves [x1, x2] and returns [-x2, x1].
    Combined with ``x * cos + rotate_half(x) * sin``, this implements a 2D rotation
    on each consecutive pair of dimensions, matching the block-diagonal operator
    R^u_{Θ,p} in the paper.
    """
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([-x2, x1], dim=-1)


# ---------------------------------------------------------------------------
# RotaryEmbedding
# ---------------------------------------------------------------------------

class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings with explicit mode and parameter control.

    Each caller constructs its own instance with the exact parameters it needs.
    h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
    config object is read inside this module.

    The cos/sin cache is built lazily on the first forward call and extended
    automatically when a longer sequence is encountered. Instances with identical
    parameters share one cache via the class-level ``_cache`` registry,
    avoiding redundant computation across decoder layers.

    Args:
        mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
        head_dim: Per-head embedding dimension ``u``. Must be even.
        theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
        initial_seq_length: ``C_train`` — context length the model was trained at.
            Required for ``mode="yarn"``.
        dilation: Scale factor ``s = C_target / C_train`` — how much the context
            window is extended beyond training length. Required for ``mode="yarn"``.
            When ``dilation=1.0``, YaRN reduces to standard RoPE.
        alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are fully
            interpolated. Required for ``mode="yarn"``.
        beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left
            unchanged. Required for ``mode="yarn"``.
        device: Optional device for initial buffer placement.

    Raises:
        NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
        ValueError: If ``mode="yarn"`` and any of ``initial_seq_length``,
            ``dilation``, ``alpha``, ``beta`` are absent.
    """

    # Maps (freq_key, seq_len, device_str, dtype_str) → (cos_table, sin_table).
    # Shared across all RotaryEmbedding instances in the process. Keys include device
    # and dtype so that tables built on different devices or in different precisions
    # are stored independently.
    _cache: dict = {}

    def __init__(
        self,
        mode: str,
        head_dim: int,
        theta: float,
        initial_seq_length: int | None = None,
        dilation: float | None = None,
        alpha: float | None = None,
        beta: float | None = None,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()

        self._validate_mode(mode)
        self._validate_yarn_params(mode, initial_seq_length, dilation, alpha, beta)
        self.mode = mode

        # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
        # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
        # so rotation_freqs has head_dim/2 entries.
        d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        base_freqs = 1.0 / (theta ** (d_index / head_dim))  # θ_d = b^{-2d/u}

        if mode == "default":
            rotation_freqs = base_freqs
            self.attention_scaling: float = 1.0

        else:  # yarn
            s = dilation

            # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
            # function to classify each dimension into one of three regimes.
            normalized_freqs = initial_seq_length * base_freqs / (2.0 * math.pi)

            # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
            # linear blend between α and β.
            blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0)

            # θ_d' = (1 − γ) · θ_d / s + γ · θ_d
            rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs

            # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
            self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2

        # freq_key uniquely identifies the parameter set that produced rotation_freqs.
        # Used as the primary component of the cache registry key.
        if mode == "default":
            self._freq_key: tuple = ("default", head_dim, float(theta))
        else:
            self._freq_key = (
                "yarn", head_dim, float(theta),
                int(initial_seq_length), float(dilation),
                float(alpha), float(beta),
            )

        # rotation_freqs is a non-persistent buffer so it moves with the model across
        # devices via .to() / .cuda() without appearing in saved checkpoints.
        # It is stored per-instance rather than in the shared cache because it is
        # small (head_dim/2 floats) — negligible cost compared to the cos/sin tables
        # it is used to build. The meaningful sharing win is on those tables.
        self.register_buffer("rotation_freqs", rotation_freqs, persistent=False)

        # Cache tensors are plain instance attributes (not registered buffers) so that
        # sharing across identically-parametrised instances survives .to() calls.
        # Registered buffers are copied on device move; plain attributes are aliased,
        # preserving the shared-tensor identity that the cache design depends on.
        self._cos_cached: torch.Tensor | None = None
        self._sin_cached: torch.Tensor | None = None

    # ---------------------------------------------------------------------------
    # Validation helpers
    # ---------------------------------------------------------------------------

    @staticmethod
    def _validate_mode(mode: str) -> None:
        """Raise NotImplementedError if mode is not a supported value."""
        if mode not in {"default", "yarn"}:
            raise NotImplementedError(
                f"RoPE mode '{mode}' is not supported. Supported modes: 'default', 'yarn'."
            )

    @staticmethod
    def _validate_yarn_params(
        mode: str,
        initial_seq_length: int | None,
        dilation: float | None,
        alpha: float | None,
        beta: float | None,
    ) -> None:
        """Raise ValueError if mode='yarn' and any required parameter is absent."""
        if mode != "yarn":
            return
        missing = [
            name for name, val in [
                ("initial_seq_length", initial_seq_length),
                ("dilation", dilation),
                ("alpha", alpha),
                ("beta", beta),
            ]
            if val is None
        ]
        if missing:
            raise ValueError(f"mode='yarn' requires {missing}.")

    # ---------------------------------------------------------------------------
    # Cache management
    # ---------------------------------------------------------------------------

    def _extend_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        """Build the cos/sin table to cover positions [0, seq_len).

        Checks the class-level registry first. If a table already exists for this
        exact (parameters, seq_len, device, dtype) combination it is reused directly;
        otherwise it is computed and stored. The instance attributes are pointed at
        the registry entry so that all layers sharing the same parametrisation
        reference the same tensor.
        """
        cache_key = (self._freq_key, seq_len, str(device), str(dtype))

        if cache_key not in RotaryEmbedding._cache:
            positions = torch.arange(seq_len, device=device, dtype=torch.float32)
            # outer product → (seq_len, head_dim // 2); duplicate to (seq_len, head_dim)
            freqs = torch.outer(
                positions,
                self.rotation_freqs.to(device=device, dtype=torch.float32),
            )
            angle_embedding = torch.cat((freqs, freqs), dim=-1)
            RotaryEmbedding._cache[cache_key] = (
                angle_embedding.cos().to(dtype),
                angle_embedding.sin().to(dtype),
            )

        self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key]

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, float]:
        """Apply rotary embeddings to query and key tensors.

        The cos/sin cache is extended lazily when position_ids reference positions
        beyond its current length, or when the device or dtype has changed.

        ``position_ids`` may be an integer tensor of any shape. Its values are valid
        position indices into the cos/sin cache:

        - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
        - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).

        When q/k have head dimensions absent from position_ids, broadcast dimensions
        are inserted automatically at dim 1.

        Args:
            q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim).
            k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim).
            position_ids: Integer positions of shape (batch, *pos_dims).

        Returns:
            Tuple of (q_rotated, k_rotated, attention_scaling). attention_scaling is
            1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
            apply to attention logits before softmax.
        """
        seq_len = int(position_ids.max().item()) + 1

        # The cache is valid when it exists, covers all positions referenced by
        # position_ids, and matches q's dtype and device. Each condition is named
        # separately so the rebuild trigger is readable rather than a compound predicate.
        cache_missing = self._cos_cached is None
        cache_too_short = not cache_missing and seq_len > self._cos_cached.shape[0]
        wrong_dtype = not cache_missing and self._cos_cached.dtype != q.dtype
        wrong_device = not cache_missing and self._cos_cached.device != q.device

        if cache_missing or cache_too_short or wrong_dtype or wrong_device:
            self._extend_cache(seq_len, device=q.device, dtype=q.dtype)

        cos = self._cos_cached[position_ids]
        sin = self._sin_cached[position_ids]

        # Insert broadcast dimensions for any head axes present in q/k but absent
        # from position_ids. Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once.
|
| 283 |
+
# BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed.
|
| 284 |
+
while cos.ndim < q.ndim:
|
| 285 |
+
cos = cos.unsqueeze(1)
|
| 286 |
+
sin = sin.unsqueeze(1)
|
| 287 |
+
|
| 288 |
+
q_rotated = q * cos + _rotate_half(q) * sin
|
| 289 |
+
k_rotated = k * cos + _rotate_half(k) * sin
|
| 290 |
+
|
| 291 |
+
return q_rotated, k_rotated, self.attention_scaling
|
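A minimal usage sketch for the RotaryEmbedding added above. The constructor keyword arguments shown here are assumptions inferred from the parameters visible in this diff (head_dim, theta, mode, plus the YaRN settings); check rope.py for the exact signature. The point is the forward contract: it returns rotated q/k together with an attention_scaling factor that the caller must fold into the logits before softmax.

# Sketch only — constructor keyword names are assumed, not confirmed by this diff.
import math
import torch
from rope import RotaryEmbedding  # assumes rope.py from this commit is on the import path

head_dim = 64
rope = RotaryEmbedding(head_dim=head_dim, theta=10000.0, mode="default")

B, H, N = 2, 8, 16
q = torch.randn(B, H, N, head_dim)
k = torch.randn(B, H, N, head_dim)
position_ids = torch.arange(N).expand(B, N)  # standard causal layout: (B, N)

q_rot, k_rot, attention_scaling = rope(q, k, position_ids)

# attention_scaling is 1.0 in default mode and (0.1·ln(s)+1)² under YaRN;
# applying it to the logits before softmax is the caller's responsibility.
causal_mask = torch.triu(torch.full((N, N), float("-inf")), diagonal=1)
logits = (q_rot @ k_rot.transpose(-2, -1)) * attention_scaling / math.sqrt(head_dim)
weights = torch.softmax(logits + causal_mask, dim=-1)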
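The class-level registry described in _extend_cache means identically parametrised instances alias one cos/sin table rather than each holding a copy. A short sketch of that behaviour, under the same assumed constructor signature as above:

# Two layers with the same RoPE parameters share one table after their first call.
rope_a = RotaryEmbedding(head_dim=64, theta=10000.0, mode="default")
rope_b = RotaryEmbedding(head_dim=64, theta=10000.0, mode="default")

pos = torch.arange(16).expand(1, 16)
q = torch.randn(1, 4, 16, 64)
rope_a(q, q, pos)
rope_b(q, q, pos)

# Same registry entry, not a copy: the instance attributes alias one tensor object.
assert rope_a._cos_cached is rope_b._cos_cached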
tokenizer.json
ADDED
The diff for this file is too large to render. See raw diff.
tokenizer_config.json
ADDED
@@ -0,0 +1,13 @@
{
  "add_prefix_space": false,
  "backend": "tokenizers",
  "bos_token": "<|endoftext|>",
  "eos_token": "<|endoftext|>",
  "errors": "replace",
  "is_local": false,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": "<|padding|>",
  "tokenizer_class": "GPTNeoXTokenizerFast",
  "trim_offsets": true,
  "unk_token": "<|endoftext|>"
}
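For completeness, a short sketch of loading the tokenizer configured above. This uses the standard transformers API (the config declares GPTNeoXTokenizerFast, with <|endoftext|> as bos/eos/unk and <|padding|> as pad); nothing here is specific to this commit.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("smithblack-0/SHRAM")
batch = tokenizer("Once upon a time,", return_tensors="pt")
print(tokenizer.eos_token, tokenizer.pad_token)  # <|endoftext|> <|padding|>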