Text Generation
Transformers
PyTorch
English
shram
research
sparse-attention
mixture-of-experts
custom_code
How to use smithblack-0/SHRAM with Transformers:

```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="smithblack-0/SHRAM", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("smithblack-0/SHRAM", trust_remote_code=True, dtype="auto")
```
```python
"""Unvectorized reference implementation of the MoSRAH sparse KV cache.

This module exists solely as a correctness oracle. SlowMoSRAHCache implements the same
interface and storage layout as MoSRAHCache but uses an explicit Python loop over
(b, l, t) triples in update(). The loop is obviously correct by inspection: each active
position's key and value are written to the next available slot for that (batch, head)
pair, in the order positions appear along the T dimension, which directly enforces
causal ordering without any index arithmetic to verify.

SlowMoSRAHCache is never instantiated in the model path. Its role is to provide a
trusted ground truth against which the vectorized MoSRAHCache.update() is validated in
the Unit 6.A tests, and to serve as a reference for the Unit 10.A position decoder.
Because the vectorized implementation is validated by asserting exact agreement with
this one on all test inputs, the correctness of SlowMoSRAHCache is load-bearing: its own
test suite (test_slow_mosrah_cache.py) must establish that it is trustworthy before it
can be used as an oracle.
"""

import torch
from transformers.cache_utils import CacheLayerMixin


class SlowMoSRAHCache(CacheLayerMixin):
    """Unvectorized reference implementation of the MoSRAH KV cache.

    Identical storage layout to MoSRAHCache: (B, L, C, u) tensors, where C is the
    current buffer capacity, in the mixin-standard self.keys and self.values
    attributes, plus a (B, L) _counts tensor. It has the same constructor signature
    and the same CacheLayerMixin protocol methods. The sole difference is update(),
    which uses an explicit Python loop over (b, l, t) triples rather than vectorized
    index arithmetic.

    This class is not used in the model path. It exists so that MoSRAHCache.update()
    can be validated by asserting exact agreement with this implementation on all test
    inputs. See the module docstring for the trust chain this enables.

    Args:
        num_mosrah_heads: Total number of MoSRAH expert heads (L). Determines the
            second dimension of all storage tensors.
        head_dim: Bottlenecked head embedding width (u). Determines the fourth
            dimension of all storage tensors.
        batch_size: Number of sequences in the batch. Determines the first dimension
            of all storage tensors.
        device: Device on which to allocate all tensors. Should match the model device.
        initial_buffer_size: Initial sequence capacity (C) per (batch, head) slot.
            Doubled, repeatedly if necessary, whenever an update would overflow a
            slot. Defaults to 64 to avoid repeated reallocation during prompt
            processing.
    """

    is_compileable = False
    is_sliding = False

    def __init__(
        self,
        num_mosrah_heads: int,
        head_dim: int,
        batch_size: int,
        device: torch.device,
        initial_buffer_size: int = 64,
    ) -> None:
        super().__init__()
        self.num_mosrah_heads = num_mosrah_heads
        self.head_dim = head_dim
        self.batch_size = batch_size
        self.device = device

        # Allocate primary storage into the mixin-standard self.keys / self.values so
        # that inherited methods (offload, prefetch) operate on real tensors. _counts
        # tracks valid occupancy per (batch, head) slot.
        self.keys: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self.values: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, initial_buffer_size, head_dim, device=device
        )
        self._counts: torch.Tensor = torch.zeros(
            batch_size, num_mosrah_heads, dtype=torch.long, device=device
        )

        # Storage is fully allocated at construction, so the cache is already initialized.
        self.is_initialized = True

    # ---------------------------------------------------------------------------
    # Properties
    # ---------------------------------------------------------------------------

    @property
    def buffer_capacity(self) -> int:
        """Current number of slots allocated per (batch, head) pair.

        Derived directly from self.keys rather than tracked separately, so it is
        always consistent with the actual buffer after expansion.
        """
        return self.keys.shape[2]

    # ---------------------------------------------------------------------------
    # Primary API
    # ---------------------------------------------------------------------------

    def update(  # type: ignore[override]
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        active_mask: torch.Tensor,
        cache_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Scatter active key/value states using an explicit loop; return the full cache state.

        Iterates over every (b, l, t) triple. For each position where active_mask is
        True, the key and value are written to the next available slot for that
        (batch, head) pair and the count is incremented. Causal ordering is guaranteed
        because the t dimension is traversed from 0 to T-1 and counts are updated
        immediately after each write.

        Before any writes, buffer_capacity is doubled, as many times as needed, if the
        incoming tokens would cause any slot to overflow the current capacity.

        Args:
            key_states: Shape (B, L, T, u). Post-RoPE key vectors in expert-choice layout.
            value_states: Shape (B, L, T, u). Value vectors in expert-choice layout.
            active_mask: Shape (B, L, T), bool. True for real tokens, False for padding.
            cache_kwargs: Unused; present to satisfy the CacheLayerMixin signature.

        Returns:
            Tuple of (keys, values, active_mask), where C is buffer_capacity after any
            expansion:
                keys: (B, L, C, u) float. Full key buffer, including junk slots.
                values: (B, L, C, u) float. Full value buffer, including junk slots.
                active_mask: (B, L, C) bool. True iff slot (b, l, c) has been written.
        """
        B, L, T = active_mask.shape

        # The expansion trigger uses the total active tokens per slot, same as the
        # vectorized implementation, so both expand under identical conditions. A single
        # chunk can exceed twice the current capacity, so keep doubling until every
        # slot fits; one doubling alone would lead to an out-of-bounds write below.
        incoming_delta = active_mask.long().sum(dim=2)  # (B, L)
        required = (self._counts + incoming_delta).max().item()
        while required > self.buffer_capacity:
            self._expand()

        # Write each active position into the next available slot for its (batch, head)
        # pair. Iterating t from 0 to T-1 preserves causal ordering within each slot.
        for b in range(B):
            for l in range(L):
                for t in range(T):
                    if active_mask[b, l, t]:
                        pos = self._counts[b, l].item()
                        self.keys[b, l, pos, :] = key_states[b, l, t, :]
                        self.values[b, l, pos, :] = value_states[b, l, t, :]
                        self._counts[b, l] += 1

        return self.keys, self.values, self._make_active_mask()

    def get_heads_lengths(self) -> torch.Tensor:
        """Return the per-(batch, head) token count for this layer.

        This is the authoritative occupancy tensor, consumed by BEA for attention
        masking and by the Unit 10.A position computation for semantic-sequence
        positions.

        Returns:
            Integer tensor of shape (B, L) where entry [b, l] is the number of valid
            tokens stored in the (b, l) slot. Zero for slots with no writes yet.
        """
        return self._counts

    # ---------------------------------------------------------------------------
    # CacheLayerMixin: overridden coordination methods
    # ---------------------------------------------------------------------------

    def reset(self) -> None:
        """Clear all cached key and value tensors.

        Zeroes self.keys, self.values, and _counts in place. Storage remains allocated
        and is_initialized remains True; only the contents are cleared.
        """
        self.keys.zero_()
        self.values.zero_()
        self._counts.zero_()

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Reorder the batch dimension of all cached tensors for beam search.

        Applied atomically across self.keys, self.values, and _counts. Beam search
        must reorder all three together, or the occupancy counts and buffer contents
        will correspond to different beam hypotheses.

        Overrides the parent because the parent's implementation calls get_seq_length(),
        which is not supported for this cache.

        Args:
            beam_idx: Source batch indices of shape (batch,) chosen by the beam search
                algorithm. Entries may repeat, so this is a gather rather than a
                strict permutation.
        """
        self.keys = self.keys[beam_idx]
        self.values = self.values[beam_idx]
        self._counts = self._counts[beam_idx]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Expand the batch dimension by repeating each entry `repeats` times.

        Used at beam search initialisation to expand the cache from batch size B to
        B * repeats, matching the expanded beam candidate batch. Applied atomically
        across keys, values, and _counts; batch_size is updated to reflect the new size.

        Args:
            repeats: Number of times to repeat each batch entry.
        """
        self.keys = self.keys.repeat_interleave(repeats, dim=0)
        self.values = self.values.repeat_interleave(repeats, dim=0)
        self._counts = self._counts.repeat_interleave(repeats, dim=0)
        self.batch_size = self.batch_size * repeats

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Select a subset of batch entries by index.

        Used in contrastive search to retain only the selected candidate entries.
        Applied atomically across keys, values, and _counts; batch_size is updated
        to reflect the number of retained entries.

        Args:
            indices: 1-D integer tensor of batch indices to retain.
        """
        self.keys = self.keys[indices]
        self.values = self.values[indices]
        self._counts = self._counts[indices]
        self.batch_size = indices.shape[0]

    def offload(self) -> None:
        """Offload all cached tensors to CPU.

        Extends the parent to also offload _counts, which the parent does not know
        about. All three tensors are moved together so device state remains consistent.
        """
        super().offload()
        self._counts = self._counts.to("cpu", non_blocking=True)

    def prefetch(self) -> None:
        """Move all cached tensors back to the model device ahead of time.

        Extends the parent to also prefetch _counts, which the parent does not know
        about. _counts is synced to self.keys.device after the parent moves keys and
        values, so all three remain consistent.
        """
        super().prefetch()
        if self._counts.device != self.keys.device:
            self._counts = self._counts.to(self.keys.device, non_blocking=True)

    def lazy_initialization(  # type: ignore[override]
        self, key_states: torch.Tensor, value_states: torch.Tensor
    ) -> None:
        """No-op: storage is fully allocated at construction time."""
        pass

    # ---------------------------------------------------------------------------
    # CacheLayerMixin: unsupported abstract methods
    # ---------------------------------------------------------------------------

    def get_seq_length(self) -> int:  # type: ignore[override]
        """Not supported: no single sequence length represents this cache's state.

        MoSRAH heads accumulate independently, so (batch, head) slots have different
        lengths depending on routing history. There is no meaningful scalar summary.
        Use get_heads_lengths() for per-head occupancy.
        """
        raise NotImplementedError(
            "SlowMoSRAHCache has no single sequence length. "
            "Use get_heads_lengths() for per-head occupancy."
        )

    def get_max_cache_shape(self) -> int:  # type: ignore[override]
        """Not supported: SlowMoSRAHCache is dynamic and unbounded."""
        raise NotImplementedError(
            "SlowMoSRAHCache is unbounded; get_max_cache_shape() is not supported."
        )

    def get_mask_sizes(  # type: ignore[override]
        self,
        cache_position: torch.Tensor,
    ) -> tuple[int, int]:
        """Not supported: SlowMoSRAHCache does not participate in HF mask construction."""
        raise NotImplementedError(
            "SlowMoSRAHCache does not support get_mask_sizes()."
        )

    # ---------------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------------

    def _make_active_mask(self) -> torch.Tensor:
        """Construct the (B, L, C) active mask from the current counts.

        Returns True at position [b, l, c] iff c < _counts[b, l], i.e. the slot
        has been written. Positions at or beyond the count are junk and must be
        excluded by downstream attention.
        """
        cap = self.buffer_capacity
        return (
            torch.arange(cap, device=self.keys.device)
            .expand(self.batch_size, self.num_mosrah_heads, cap)
            < self._counts.unsqueeze(-1)
        )

    def _expand(self) -> None:
        """Double the buffer capacity, preserving existing data.

        Called by update(), repeatedly if necessary, when an incoming batch of tokens
        would cause any (batch, head) slot to exceed the current buffer capacity. All
        existing key and value data is copied into the low half of the new buffer; the
        high half is zero-initialised and will be filled by subsequent writes.
        After reassignment, buffer_capacity reflects the new size automatically.
        """
        old_cap = self.buffer_capacity
        new_cap = old_cap * 2
        dev = self.keys.device
        dtype = self.keys.dtype  # preserve precision across expansion
        new_keys = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim,
            device=dev, dtype=dtype,
        )
        new_values = torch.zeros(
            self.batch_size, self.num_mosrah_heads, new_cap, self.head_dim,
            device=dev, dtype=dtype,
        )
        new_keys[:, :, :old_cap, :] = self.keys
        new_values[:, :, :old_cap, :] = self.values
        self.keys = new_keys
        self.values = new_values
```
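The trust chain in the module docstring rests on a simple property. For a single (batch, head) slot, the loop in `update()` writes the i-th active position of an incoming chunk to slot `count + i`. A vectorized implementation can compute the same destinations with an inclusive prefix sum over the active mask. The stdlib-only sketch below checks that equivalence through differential testing against the loop. `loop_destinations` mirrors the loop. `prefix_sum_destinations` is a hypothetical stand-in for vectorized index arithmetic; it is not the actual MoSRAHCache code, which this file does not show.

```python
import random
from itertools import accumulate


def loop_destinations(count, mask):
    """Loop oracle for one (batch, head) slot, mirroring SlowMoSRAHCache.update()."""
    dest = []
    for active in mask:  # traverse t = 0 .. T-1 in order
        if active:
            dest.append(count)  # next free slot
            count += 1
        else:
            dest.append(None)  # padding is never written
    return dest, count


def prefix_sum_destinations(count, mask):
    """Hypothetical index-arithmetic version: slot = count + (active up to t) - 1."""
    running = list(accumulate(int(a) for a in mask))  # inclusive prefix sum
    dest = [count + r - 1 if a else None for a, r in zip(mask, running)]
    new_count = count + (running[-1] if running else 0)
    return dest, new_count


# Differential test: the "fast" version must agree exactly with the oracle.
rng = random.Random(0)
for _ in range(1000):
    start = rng.randint(0, 10)
    mask = [rng.random() < 0.5 for _ in range(rng.randint(0, 12))]
    assert prefix_sum_destinations(start, mask) == loop_destinations(start, mask)

print(loop_destinations(3, [True, False, True]))  # ([3, None, 4], 5)
```

Because each destination depends only on the count of earlier active positions, the prefix sum preserves the same causal ordering that the loop enforces by construction.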
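Each call to `_expand()` doubles the capacity, and `update()` calls it before writing whenever `counts + incoming active tokens` exceeds `buffer_capacity` for any slot. The arithmetic can be sketched as a pure function. `capacity_after_expansion` is a hypothetical helper that works on flat lists of per-slot counts, not part of the module. It also shows that a single update can require more than one doubling. For example, a 300-token chunk routed entirely to one head needs a capacity of 512 when starting from the default 64.

```python
def capacity_after_expansion(capacity, counts, incoming):
    """Double `capacity` until every slot fits counts[i] + incoming[i].

    counts and incoming are flat lists over (batch, head) slots.
    Returns (new_capacity, number_of_doublings).
    """
    if capacity <= 0:
        raise ValueError("capacity must be positive, or doubling never terminates")
    needed = max(c + d for c, d in zip(counts, incoming))
    doublings = 0
    while needed > capacity:
        capacity *= 2
        doublings += 1
    return capacity, doublings


print(capacity_after_expansion(64, [10, 60], [5, 30]))  # (128, 1)
print(capacity_after_expansion(64, [0, 0], [300, 12]))  # (512, 3)
```

Geometric doubling keeps the amortized copy cost per token constant. The trade-off is that up to half of the buffer may sit unused as junk slots.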
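Slots fill at different rates, which is why `get_seq_length()` has no meaningful answer and `_make_active_mask()` is derived per (batch, head) from `_counts`. The sketch below mirrors that mask with nested lists. It then shows one plausible way a consumer could use the mask: a masked softmax that gives junk slots zero weight. Both helper names are hypothetical. How the real downstream attention (BEA) applies the mask is not shown in this file and is an assumption here.

```python
import math


def active_mask_from_counts(counts, capacity):
    """Nested-list analogue of _make_active_mask: True iff c < counts[b][l]."""
    return [[[c < n for c in range(capacity)] for n in row] for row in counts]


def masked_softmax(scores, mask):
    """Softmax over one (batch, head) slot that gives junk positions zero weight."""
    valid = [s for s, keep in zip(scores, mask) if keep]
    if not valid:
        return [0.0] * len(scores)  # empty slot: nothing to attend to
    m = max(valid)  # subtract the max over valid scores for numerical stability
    exps = [math.exp(s - m) if keep else 0.0 for s, keep in zip(scores, mask)]
    total = sum(exps)
    return [e / total for e in exps]


counts = [[2, 0, 3]]  # B = 1, L = 3 heads with different occupancies
mask = active_mask_from_counts(counts, capacity=4)
print(mask[0][0])  # [True, True, False, False]

# Junk slots may hold stale or arbitrary values; the mask keeps them out.
weights = masked_softmax([1.0, 1.0, 99.0, -5.0], mask[0][0])
print(weights)  # [0.5, 0.5, 0.0, 0.0]
```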
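`reorder_cache`, `batch_repeat_interleave`, and `batch_select_indices` all apply the same batch-dimension operation to keys, values, and `_counts`. The invariant they protect is that, for every batch entry, the number of written slots equals its recorded count. The list-based sketch below simplifies to one head. It uses hypothetical helpers and marks junk slots with `None`; in the real cache, junk slots are zeros and only the counts tell them apart. The invariant holds when buffers and counts move together and breaks when only the buffers are reordered.

```python
def repeat_interleave(seq, repeats):
    """List analogue of Tensor.repeat_interleave(repeats, dim=0)."""
    return [item for item in seq for _ in range(repeats)]


def select(seq, indices):
    """List analogue of indexing dim 0 with an index tensor (entries may repeat)."""
    return [seq[i] for i in indices]


def consistent(keys, counts):
    """True iff each batch entry's number of written slots matches its count."""
    return all(sum(k is not None for k in row) == n for row, n in zip(keys, counts))


# One head, capacity 4; None marks a junk (unwritten) slot.
keys = [["a0", "a1", None, None], ["b0", None, None, None]]
counts = [2, 1]

# Beam search initialisation: expand batch 2 -> 4 (repeats=2).
keys, counts = repeat_interleave(keys, 2), repeat_interleave(counts, 2)

# A decoding step selects source beams; index 3 is chosen twice.
beam_idx = [3, 3, 0, 1]
reordered_keys = select(keys, beam_idx)
print(consistent(reordered_keys, select(counts, beam_idx)))  # True
print(consistent(reordered_keys, counts))  # False: counts were left behind
```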