| """
|
| VortexLocalAttention: Local windowed attention with global token support.
|
| Uses a sliding window of 512 tokens for efficiency, with special handling
|
| for global tokens that can attend across the entire sequence.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import Optional, Tuple
|
|
|
|
|
class VortexLocalAttention(nn.Module):
    """
    Local windowed attention with global-token support.

    Science documents have strong local coherence — equations reference
    nearby text, not distant paragraphs — so each token attends only to a
    symmetric window of ``window_size // 2`` neighbours on either side.
    Designated global tokens (e.g. special [SCIENCE] markers) attend to, and
    are attended by, every position; their Q/K/V come from a separate
    ``global_qkv`` projection so they can learn a distinct "broadcast" role.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        window_size: int = 512,
        use_flash_attention: bool = True,
    ):
        """
        Initialize local windowed attention.

        Args:
            d_model: Model dimension.
            num_heads: Number of attention heads.
            window_size: Size of the local attention window (a token sees
                positions within ``window_size // 2`` on either side).
            use_flash_attention: Use Flash Attention 2 if available (CUDA only).
        """
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.window_size = window_size
        self.use_flash_attention = use_flash_attention

        # Fused Q/K/V projection for ordinary (local) tokens.
        self.qkv = nn.Linear(d_model, d_model * 3, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        # Separate projection used only at global-token positions.
        self.global_qkv = nn.Linear(d_model, d_model * 3, bias=False)

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """Initialize all projection weights with N(0, 0.02)."""
        for module in (self.qkv, self.global_qkv, self.out_proj):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _project(
        self, x: torch.Tensor, proj: nn.Linear
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply a fused QKV projection; return (q, k, v), each (batch, heads, seq, head_dim)."""
        batch, seq_len, _ = x.shape
        q, k, v = proj(x).chunk(3, dim=-1)
        shape = (batch, seq_len, self.num_heads, self.head_dim)
        return tuple(t.view(shape).transpose(1, 2) for t in (q, k, v))

    def forward(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass with local windowed attention.

        Implemented as dense attention under a band mask (plus global
        rows/columns), which is vectorized and exact.  Note: scores are
        materialized as (batch, heads, seq, seq), so memory is quadratic in
        sequence length; use ``forward_optimized`` for long sequences.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            global_mask: Boolean mask marking global tokens (attend everywhere),
                shape (batch, seq_len), or None for no global tokens.
            attention_mask: Optional padding mask (batch, seq_len); positions
                equal to 0 are masked out as attention *keys*.

        Returns:
            Output tensor (batch, seq_len, d_model).
        """
        batch, seq_len, _ = x.shape
        device = x.device

        if global_mask is None:
            global_mask = torch.zeros(batch, seq_len, dtype=torch.bool, device=device)

        q, k, v = self._project(x, self.qkv)

        if global_mask.any():
            # Global tokens use their own projection; substitute it at the
            # global positions only.  (The previous implementation concatenated
            # global K/V for *every* position — double-counting window tokens —
            # and consulted only batch element 0's mask.)
            gq, gk, gv = self._project(x, self.global_qkv)
            gm = global_mask[:, None, :, None]  # (batch, 1, seq, 1), broadcasts over heads/dim
            q = torch.where(gm, gq, q)
            k = torch.where(gm, gk, k)
            v = torch.where(gm, gv, v)

        # Allowed connections: a symmetric band |i - j| <= window_size // 2,
        # plus full rows (global queries) and columns (global keys) for global
        # tokens, minus padded key positions.
        positions = torch.arange(seq_len, device=device)
        band = (positions[None, :] - positions[:, None]).abs() <= self.window_size // 2
        allowed = band[None, :, :] | global_mask[:, :, None] | global_mask[:, None, :]
        if attention_mask is not None:
            # Fixes the old crash where the padding mask was sliced to the
            # window while the key set also contained concatenated global K/V.
            allowed = allowed & attention_mask.bool()[:, None, :]

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_scores = attn_scores.masked_fill(~allowed[:, None, :, :], -1e9)
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        return self.out_proj(attn_output)

    def forward_optimized(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Dispatch to the fastest available implementation.

        Flash Attention only covers dense (full) attention here, so it is used
        when the window spans the whole sequence; otherwise the windowed
        implementation runs.
        """
        _, seq_len, _ = x.shape
        if self.use_flash_attention and self.window_size >= seq_len:
            return self._flash_attention_forward(x, attention_mask)
        return self._windowed_attention_forward(x, global_mask, attention_mask)

    def _flash_attention_forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Full attention via Flash Attention 2 (requires ``pip install flash-attn``).

        Falls back to ``_standard_attention`` when the package is missing or a
        padding mask is supplied: ``flash_attn_func`` takes no mask argument
        (the old code silently dropped the mask), and variable-length batches
        would need ``flash_attn_varlen_func``.
        """
        if attention_mask is not None:
            return self._standard_attention(x, attention_mask)
        try:
            from flash_attn import flash_attn_func
        except ImportError:
            print("Flash Attention not available, falling back to standard attention")
            return self._standard_attention(x, attention_mask)

        batch, seq_len, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # flash_attn_func expects (batch, seq, heads, head_dim) — no transpose.
        q = q.view(batch, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch, seq_len, self.num_heads, self.head_dim)
        v = v.view(batch, seq_len, self.num_heads, self.head_dim)

        output = flash_attn_func(
            q, k, v,
            causal=False,
            softmax_scale=1.0 / (self.head_dim ** 0.5),
        )
        output = output.view(batch, seq_len, self.d_model)
        return self.out_proj(output)

    def _standard_attention(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Standard full (quadratic) attention over the whole sequence."""
        batch, seq_len, _ = x.shape
        q, k, v = self._project(x, self.qkv)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if attention_mask is not None:
            # Mask padded *key* positions across all heads and queries.
            attn_scores = attn_scores.masked_fill(
                attention_mask.unsqueeze(1).unsqueeze(2) == 0,
                -1e9,
            )

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, seq_len, self.d_model)
        return self.out_proj(attn_output)

    def _windowed_attention_forward(
        self,
        x: torch.Tensor,
        global_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Efficient windowed attention with optional global tokens.

        Delegates to ``forward``, which implements the band + global mask
        directly.  The previous unfold-based version produced seq_len + 1
        windows for even window sizes (off-by-one), permuted K/V into shapes
        incompatible with the query matmul, and applied ``out_proj`` twice to
        global-token rows (once inside ``_standard_attention`` and once at the
        end).
        """
        return self.forward(x, global_mask=global_mask, attention_mask=attention_mask)
|
|
|
|
|
def test_vortex_local_attention():
    """Smoke-test VortexLocalAttention on random input."""
    batch_size, seq_len = 2, 256
    d_model, num_heads, window_size = 4096, 32, 512

    layer = VortexLocalAttention(d_model, num_heads, window_size, use_flash_attention=False)
    inputs = torch.randn(batch_size, seq_len, d_model)

    # Plain forward: output must mirror the input shape.
    result = layer(inputs)
    print(f"Input shape: {inputs.shape}")
    print(f"Output shape: {result.shape}")
    assert result.shape == inputs.shape, f"Expected {inputs.shape}, got {result.shape}"

    # Mark one global token per batch element and run again.
    marks = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    marks[0, 0] = True
    marks[1, -1] = True
    result_with_globals = layer(inputs, global_mask=marks)
    assert result_with_globals.shape == inputs.shape

    print("VortexLocalAttention test passed!")
|
|
|
|
|
| if __name__ == "__main__":
|
| test_vortex_local_attention()
|
|
|