import torch
import torch._dynamo as dynamo
from torch.nn.attention.flex_attention import (
    create_block_mask,
    flex_attention,
    or_masks,
)
from transformers.utils import is_torchdynamo_compiling


# Allow more Dynamo recompilations than the default before it falls back to eager;
# the flex attention kernels below may be recompiled for many different mask shapes.
dynamo.config.recompile_limit = 64


class WrappedFlexAttention:
    """
    Singleton class so that flex attention is compiled only once, the first
    time it is called.
    """

    _instance = None
    _is_flex_compiled = False
    _compiled_flex_attention = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # Create a new instance only if one doesn't already exist
            cls._instance = super().__new__(cls)
        return cls._instance

    @torch.compiler.disable(recursive=False)
    def __init__(self):
        """
        Compile flex_attention the first time the singleton is instantiated.
        """
        if not self._is_flex_compiled:
            self._compiled_flex_attention = torch.compile(flex_attention)
            self._is_flex_compiled = True

    def __call__(self):
        return self._compiled_flex_attention


def compile_friendly_flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    # Use the pre-compiled flex_attention unless we are already tracing under
    # torch.compile; in that case fall back to the uncompiled kernel so that
    # torch.compile calls are not nested.
    flex_attention_compiled = (
        WrappedFlexAttention()() if not is_torchdynamo_compiling() else flex_attention
    )
    return flex_attention_compiled(
        query,
        key,
        value,
        **kwargs,
    )
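

# Example usage (a minimal sketch; the shapes, dtype and device below are
# illustrative assumptions, not requirements of this module):
#
#     q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
#     k = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
#     v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
#     out = compile_friendly_flex_attention(q, k, v)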


class WrappedCreateBlockMask:
    """
    Singleton class so that create_block_mask is compiled only once, the first
    time it is called.
    """

    _instance = None
    _is_create_block_mask_compiled = False
    _compiled_create_block_mask = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # Create a new instance only if one doesn't already exist
            cls._instance = super().__new__(cls)
        return cls._instance

    @torch.compiler.disable(recursive=False)
    def __init__(self):
        """
        Compile create_block_mask the first time the singleton is instantiated.
        """
        if not self._is_create_block_mask_compiled:
            self._compiled_create_block_mask = torch.compile(create_block_mask)
            self._is_create_block_mask_compiled = True

    def __call__(self):
        return self._compiled_create_block_mask


def compile_friendly_create_block_mask(
    mask_mod,
    B,
    H,
    Q_LEN,
    KV_LEN,
    device,
):
    # Use the pre-compiled create_block_mask unless we are already tracing
    # under torch.compile; in that case fall back to the uncompiled version.
    create_block_mask_compiled = (
        WrappedCreateBlockMask()()
        if not is_torchdynamo_compiling()
        else create_block_mask
    )
    return create_block_mask_compiled(
        mask_mod,
        B,
        H,
        Q_LEN,
        KV_LEN,
        device,
    )
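

# Example usage (a minimal sketch; the causal mask_mod and the sizes below are
# illustrative assumptions):
#
#     def causal(b, h, q_idx, kv_idx):
#         return q_idx >= kv_idx
#
#     block_mask = compile_friendly_create_block_mask(
#         causal, B=1, H=None, Q_LEN=128, KV_LEN=128, device="cuda"
#     )
#     out = compile_friendly_flex_attention(q, k, v, block_mask=block_mask)
#
# (q, k, v as in the flex attention example above.)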


def generate_eagle3_mask(
    seq_lengths: torch.Tensor, Q_LEN: int, KV_LEN: int, lck: int = 0
):
    """
    Build a FlexAttention mask_mod for EAGLE-3.

    The returned mask is the union (or_masks) of:
      * a causal mask over the first Q_LEN key positions, limited to the
        unpadded part of each sequence (seq_lengths[b]);
      * a suffix mask over key positions at or beyond Q_LEN that keeps only
        keys at the same offset modulo Q_LEN as the query, again limited to
        the unpadded part of the sequence.

    lck is only used to disambiguate the generated mask_mod's name.
    """

    def causal_mask(b, h, q_idx, kv_idx):
        causal_mask = q_idx >= kv_idx
        padding_mask = (kv_idx < seq_lengths[b]) & (q_idx < seq_lengths[b])
        return causal_mask & padding_mask

    def suffix_mask(b, h, q_idx, kv_idx):
        suffix_mask = kv_idx >= Q_LEN
        padding_mask = kv_idx % Q_LEN < seq_lengths[b]
        diagonal_mask = (kv_idx - q_idx) % Q_LEN == 0
        return suffix_mask & padding_mask & diagonal_mask

    mask_mod = or_masks(causal_mask, suffix_mask)
    mask_mod.__name__ = f"eagle3_mask_Q_{Q_LEN}_KV_{KV_LEN}_lck_{lck}"
    return mask_mod
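

# End-to-end sketch (illustrative assumptions: a batch of 2 sequences padded to
# Q_LEN=128 with true lengths 100 and 80, KV_LEN=256, running on CUDA, with
# q of shape (2, H, 128, D) and k, v of shape (2, H, 256, D)):
#
#     seq_lengths = torch.tensor([100, 80], device="cuda")
#     mask_mod = generate_eagle3_mask(seq_lengths, Q_LEN=128, KV_LEN=256)
#     block_mask = compile_friendly_create_block_mask(
#         mask_mod, 2, None, 128, 256, "cuda"
#     )
#     out = compile_friendly_flex_attention(q, k, v, block_mask=block_mask)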