"""
Intelligent Tokenizer v6.2.0 - 6-Layer Decoder with Multi-Level Cross-Attention
Incorporates GPT-5 suggestions for KV cache optimization
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import math
class KVCacheOptimizedAttention(nn.Module):
"""
KV Cache Optimized Attention - GPT-5 suggestion
    16 Q heads → 2 K/V heads for an 8x KV-cache memory reduction
"""
def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_compression: int = 8):
super().__init__()
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.kv_heads = max(2, num_heads // kv_compression) # 16/8 = 2 KV heads
self.head_dim = hidden_dim // num_heads # 80
# Query uses all heads
self.q_proj = nn.Linear(hidden_dim, hidden_dim) # 16 heads
# Key/Value use fewer heads (GPT-5 suggestion)
self.k_proj = nn.Linear(hidden_dim, self.kv_heads * self.head_dim) # 2 heads
self.v_proj = nn.Linear(hidden_dim, self.kv_heads * self.head_dim) # 2 heads
# Output projection
self.o_proj = nn.Linear(hidden_dim, hidden_dim)
# KV cache for inference
self.register_buffer('cached_keys', None)
self.register_buffer('cached_values', None)
def forward(self,
hidden_states: torch.Tensor,
encoder_hidden: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False) -> Tuple[torch.Tensor, Optional[Tuple]]:
"""
Forward pass with KV cache optimization
"""
batch_size, seq_len = hidden_states.shape[:2]
# Query projection (all heads)
Q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim)
Q = Q.transpose(1, 2) # [batch, heads, seq, dim]
# Key/Value source (self or cross)
kv_source = encoder_hidden if encoder_hidden is not None else hidden_states
# Key/Value projection (fewer heads)
K = self.k_proj(kv_source).view(batch_size, -1, self.kv_heads, self.head_dim)
V = self.v_proj(kv_source).view(batch_size, -1, self.kv_heads, self.head_dim)
        K = K.transpose(1, 2)  # [batch, kv_heads, seq, dim]
        V = V.transpose(1, 2)
        # Cache management for incremental generation (GPT suggestion).
        # K/V are cached in their compressed kv_heads form so the cache itself
        # keeps the 8x reduction.
        if use_cache:
            # For incremental generation only the new token is projected;
            # append its K/V to the cached prefix.
            if self.cached_keys is not None and hidden_states.size(1) == 1:
                K = torch.cat([self.cached_keys, K], dim=2)
                V = torch.cat([self.cached_values, V], dim=2)
            # Update cache
            self.cached_keys = K
            self.cached_values = V
        # Repeat KV heads to match Q heads (grouped-query broadcast)
        K = K.repeat_interleave(self.num_heads // self.kv_heads, dim=1)
        V = V.repeat_interleave(self.num_heads // self.kv_heads, dim=1)
# Scaled dot-product attention
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
# Use additive mask (GPT suggestion)
if attention_mask is not None:
scores = scores + attention_mask # additive mask: -inf where masked, 0 elsewhere
attn_weights = F.softmax(scores, dim=-1)
attn_output = torch.matmul(attn_weights, V)
# Reshape and project
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim)
output = self.o_proj(attn_output)
        return output, ((self.cached_keys, self.cached_values) if use_cache else None)
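# A minimal, illustrative helper added alongside the original code (not part of
# the v6.2.0 model itself; the function name and dummy sizes are arbitrary).
# It runs one prefix pass plus one cached single-token step through
# KVCacheOptimizedAttention and prints the K/V projection width that gives the
# claimed ~8x KV-cache reduction.
def _demo_kv_cache_attention(batch: int = 2, prefix_len: int = 8) -> None:
    attn = KVCacheOptimizedAttention(hidden_dim=1280, num_heads=16, kv_compression=8)
    # Prefix pass populates the cache
    prefix = torch.randn(batch, prefix_len, 1280)
    out, cache = attn(prefix, use_cache=True)
    assert out.shape == (batch, prefix_len, 1280)
    # One incremental token: the cached K/V should grow by a single position
    step = torch.randn(batch, 1, 1280)
    _, cache = attn(step, use_cache=True)
    print(f"Q projection width: {attn.q_proj.out_features}, "
          f"K/V projection width: {attn.k_proj.out_features}, "
          f"cached K length: {cache[0].size(2)}")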
class SelectiveCrossAttention(nn.Module):
"""
Selective cross-attention - only attend to relevant encoder layers
    Reduces 24 → 8 cross-attention connections for efficiency
"""
def __init__(self, hidden_dim: int = 1280, layer_id: int = 0):
super().__init__()
self.hidden_dim = hidden_dim
self.layer_id = layer_id
# Define which encoder layers this decoder layer should attend to
        self.encoder_connections = {
            0: [0],     # Decoder L0 → Encoder L0 (byte info)
            1: [0],     # Decoder L1 → Encoder L0 (byte info)
            2: [1, 2],  # Decoder L2 → Encoder L1,2 (language info)
            3: [1, 2],  # Decoder L3 → Encoder L1,2 (language info)
            4: [3],     # Decoder L4 → Encoder L3 (semantic info)
            5: [3],     # Decoder L5 → Encoder L3 (semantic info)
        }
# Get connections for this layer
self.connected_layers = self.encoder_connections.get(layer_id, [0])
# Create attention modules only for connected layers
self.cross_attentions = nn.ModuleList([
KVCacheOptimizedAttention(hidden_dim, num_heads=16, kv_compression=8)
for _ in self.connected_layers
])
# Lightweight fusion with weighted sum (GPT suggestion)
self.fusion = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.SiLU(),
nn.Dropout(0.1)
)
# Learnable weights for connected layers only
self.layer_weights = nn.Parameter(torch.ones(len(self.connected_layers)) / len(self.connected_layers))
def forward(self,
decoder_hidden: torch.Tensor,
encoder_all_hidden: List[torch.Tensor],
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Selectively attend to relevant encoder layers only
"""
# Only attend to connected encoder layers
cross_outputs = []
for i, layer_idx in enumerate(self.connected_layers):
if layer_idx < len(encoder_all_hidden):
encoder_hidden = encoder_all_hidden[layer_idx]
cross_out, _ = self.cross_attentions[i](
hidden_states=decoder_hidden,
encoder_hidden=encoder_hidden,
attention_mask=attention_mask
)
cross_outputs.append(cross_out)
# Weighted sum fusion for connected layers only
if len(cross_outputs) > 1:
weighted_outputs = torch.stack(cross_outputs, dim=0) # [N, batch, seq, hidden]
weights = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
fused = (weighted_outputs * weights).sum(dim=0) # [batch, seq, hidden]
else:
# Single connection - no fusion needed
fused = cross_outputs[0] if cross_outputs else decoder_hidden
# Apply lightweight fusion layer
fused = self.fusion(fused)
return fused
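# A minimal, illustrative helper added alongside the original code (not part of
# the v6.2.0 model itself; the name and dummy sizes are arbitrary). It builds
# SelectiveCrossAttention for all 6 decoder layers, counts the resulting
# cross-attention modules (the 24 → 8 reduction described above), and runs one
# layer against dummy encoder outputs to confirm the fused shape.
def _demo_selective_cross_attention(batch: int = 2, seq: int = 6) -> None:
    layers = [SelectiveCrossAttention(hidden_dim=1280, layer_id=i) for i in range(6)]
    total_cross = sum(len(layer.cross_attentions) for layer in layers)
    print(f"Cross-attention modules across 6 decoder layers: {total_cross}")
    encoder_all_hidden = [torch.randn(batch, seq, 1280) for _ in range(4)]
    decoder_hidden = torch.randn(batch, seq, 1280)
    # Decoder layer 2 attends to encoder layers 1 and 2 per encoder_connections
    fused = layers[2](decoder_hidden, encoder_all_hidden)
    assert fused.shape == (batch, seq, 1280)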
class SwiGLU(nn.Module):
"""SwiGLU activation for better convergence (GPT suggestion)"""
def __init__(self, dim: int, mult: float = 2.66):
super().__init__()
inner = int(round(dim * mult / 2)) * 2 # Even alignment
self.w1 = nn.Linear(dim, inner // 2)
self.w2 = nn.Linear(dim, inner // 2)
self.w3 = nn.Linear(inner // 2, dim)
def forward(self, x):
return self.w3(F.silu(self.w1(x)) * self.w2(x))
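# A minimal, illustrative helper added alongside the original code (not part of
# the v6.2.0 model itself; the name and sizes are arbitrary). It checks that
# SwiGLU preserves the model width and reports the gated hidden width implied
# by mult (roughly dim * mult / 2).
def _demo_swiglu(dim: int = 1280) -> None:
    ffn = SwiGLU(dim, mult=2.66)
    x = torch.randn(2, 4, dim)
    y = ffn(x)
    assert y.shape == x.shape
    print(f"SwiGLU hidden width: {ffn.w1.out_features} (input/output width: {dim})")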
class DecoderLayer(nn.Module):
"""
Single decoder layer with self-attention and selective cross-attention
"""
def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, layer_id: int = 0):
super().__init__()
self.hidden_dim = hidden_dim
self.layer_id = layer_id
# Self-attention (with KV cache optimization)
self.self_attn = KVCacheOptimizedAttention(hidden_dim, num_heads, kv_compression=8)
self.self_attn_norm = nn.LayerNorm(hidden_dim)
# Selective cross-attention to specific encoder layers
self.cross_attn = SelectiveCrossAttention(hidden_dim, layer_id=layer_id)
self.cross_attn_norm = nn.LayerNorm(hidden_dim)
# Feed-forward network with SwiGLU (GPT suggestion)
self.ffn = SwiGLU(hidden_dim, mult=2.66)
self.ffn_norm = nn.LayerNorm(hidden_dim)
# Dropout for residual connections
self.dropout = nn.Dropout(0.1)
def forward(self,
hidden_states: torch.Tensor,
encoder_all_hidden: List[torch.Tensor],
self_attention_mask: Optional[torch.Tensor] = None,
cross_attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False) -> Tuple[torch.Tensor, Optional[Tuple]]:
"""
Forward pass through decoder layer
"""
# Self-attention with residual
residual = hidden_states
hidden_states = self.self_attn_norm(hidden_states)
self_attn_out, cache = self.self_attn(
hidden_states,
attention_mask=self_attention_mask,
use_cache=use_cache
)
hidden_states = residual + self.dropout(self_attn_out)
# Cross-attention with residual
residual = hidden_states
hidden_states = self.cross_attn_norm(hidden_states)
cross_attn_out = self.cross_attn(
hidden_states,
encoder_all_hidden,
attention_mask=cross_attention_mask
)
hidden_states = residual + self.dropout(cross_attn_out)
# FFN with residual
residual = hidden_states
hidden_states = self.ffn_norm(hidden_states)
ffn_out = self.ffn(hidden_states)
hidden_states = residual + self.dropout(ffn_out)
return hidden_states, cache
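# A minimal, illustrative helper added alongside the original code (not part of
# the v6.2.0 model itself; the name and dummy sizes are arbitrary). It runs a
# single DecoderLayer end to end with random tensors to confirm the pre-norm
# residual blocks preserve the [batch, seq, hidden] shape.
def _demo_decoder_layer(batch: int = 2, seq: int = 6) -> None:
    layer = DecoderLayer(hidden_dim=1280, num_heads=16, layer_id=0)
    hidden = torch.randn(batch, seq, 1280)
    encoder_all_hidden = [torch.randn(batch, seq, 1280) for _ in range(4)]
    out, cache = layer(hidden, encoder_all_hidden)  # cache is None without use_cache
    assert out.shape == (batch, seq, 1280)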
class DecoderV62(nn.Module):
"""
6-Layer Decoder with Multi-Level Cross-Attention
Reduced from 8 layers but compensated with better cross-attention
"""
def __init__(self, config: Optional[Dict] = None):
super().__init__()
# Configuration
self.hidden_dim = 1280
self.num_heads = 16
self.num_layers = 6 # Reduced from 8
        self.vocab_size = 260  # 256 bytes + 4 special tokens (PAD, BOS, EOS, MASK)
self.max_seq_len = 48
# Token constants (GPT suggestion - explicit constants)
self.PAD = 256
self.BOS = 257
self.EOS = 258
self.MASK = 259
# Token embedding and position encoding
self.token_embedding = nn.Embedding(self.vocab_size, self.hidden_dim)
self.position_embedding = nn.Embedding(self.max_seq_len, self.hidden_dim)
# 6 decoder layers with layer-specific cross-attention
self.layers = nn.ModuleList([
DecoderLayer(self.hidden_dim, self.num_heads, layer_id=i)
for i in range(self.num_layers)
])
# Output projection
self.output_norm = nn.LayerNorm(self.hidden_dim)
self.output_projection = nn.Linear(self.hidden_dim, self.vocab_size)
        # Monitoring (GPT-5 suggestion): track the importance of the 4 encoder
        # layers the decoder cross-attends to
        self.register_buffer('layer_importance', torch.zeros(4))
def forward(self,
encoder_all_hidden: List[torch.Tensor],
decoder_input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
past_key_values: Optional[List] = None) -> Dict[str, torch.Tensor]:
"""
Forward pass through decoder
Args:
encoder_all_hidden: All encoder layer outputs (4 layers)
decoder_input_ids: Input token IDs for teacher forcing
attention_mask: Attention mask
use_cache: Whether to cache KV for inference
            past_key_values: Cached K/V from previous steps (used here to offset
                position ids; the attention modules keep their own internal cache)
"""
batch_size = encoder_all_hidden[0].size(0)
device = encoder_all_hidden[0].device
# If no decoder input, start with compressed representation
if decoder_input_ids is None:
# Use encoder's final compressed output as starting point
hidden_states = encoder_all_hidden[-1] # [batch, M tokens, 1280]
seq_len = hidden_states.size(1)
else:
# Teacher forcing mode: use provided tokens
seq_len = decoder_input_ids.size(1)
            # Embeddings (offset positions by the cached prefix length so that
            # single-token incremental steps get the correct position ids)
            past_length = past_key_values[0][0].size(2) if past_key_values else 0
            token_embeds = self.token_embedding(decoder_input_ids)
            position_ids = torch.arange(past_length, past_length + seq_len, device=device).expand(batch_size, -1)
            position_embeds = self.position_embedding(position_ids)
            hidden_states = token_embeds + position_embeds
# Create causal mask for self-attention (additive mask - GPT suggestion)
causal_mask = torch.full((1, 1, seq_len, seq_len), float('-inf'), device=device)
causal_mask = torch.triu(causal_mask, diagonal=1) # [1, 1, seq, seq]
        # Cross-attention mask over encoder hidden states (additive: 0 = attend,
        # -inf = mask). It does not depend on the decoder layer, so build it once.
        if encoder_all_hidden is not None and len(encoder_all_hidden) > 0:
            S_enc = encoder_all_hidden[0].size(1)  # Encoder sequence length
            cross_mask = torch.zeros((batch_size, 1, 1, S_enc), device=hidden_states.device)
        else:
            cross_mask = None
        # Pass through decoder layers
        all_hidden_states = []
        all_caches = [] if use_cache else None
        for i, layer in enumerate(self.layers):
hidden_states, cache = layer(
hidden_states,
encoder_all_hidden,
self_attention_mask=causal_mask,
cross_attention_mask=cross_mask, # Use proper cross mask
use_cache=use_cache
)
all_hidden_states.append(hidden_states)
if use_cache:
all_caches.append(cache)
# Final output projection
hidden_states = self.output_norm(hidden_states)
logits = self.output_projection(hidden_states)
        # Update monitoring: track encoder layer importance
        # (in practice this would be computed from the cross-attention weights)
        with torch.no_grad():
            # Simplified: assume equal importance for now; fill the buffer
            # in-place so it keeps its device and dtype
            self.layer_importance.fill_(0.25)
outputs = {
'logits': logits,
'last_hidden_state': hidden_states,
'all_hidden_states': all_hidden_states,
'encoder_layer_importance': self.layer_importance
}
if use_cache:
outputs['past_key_values'] = all_caches
return outputs
    @torch.no_grad()
    def generate(self,
encoder_all_hidden: List[torch.Tensor],
max_length: int = 48,
temperature: float = 1.0,
top_k: int = 50,
top_p: float = 0.95) -> torch.Tensor:
"""
Autoregressive generation
"""
batch_size = encoder_all_hidden[0].size(0)
device = encoder_all_hidden[0].device
        # Clear any stale self-attention caches left over from a previous call
        for layer in self.layers:
            layer.self_attn.cached_keys = None
            layer.self_attn.cached_values = None
        # Start with BOS token
        generated = torch.full((batch_size, 1), self.BOS, dtype=torch.long, device=device)
        # Generate tokens one by one
        past_key_values = None
for _ in range(max_length - 1):
# GPT optimization: Only pass last token for O(T) complexity
if past_key_values is not None:
decoder_input = generated[:, -1:] # Last token only
else:
decoder_input = generated # Full sequence for first step
outputs = self.forward(
encoder_all_hidden,
decoder_input_ids=decoder_input,
use_cache=True,
past_key_values=past_key_values
)
logits = outputs['logits'][:, -1, :] / temperature
# Top-k filtering
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = float('-inf')
# Top-p (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above threshold
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = float('-inf')
# Sample
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
# Append to generated sequence
generated = torch.cat([generated, next_token], dim=1)
# Check for EOS
if (next_token == self.EOS).all():
break
past_key_values = outputs.get('past_key_values')
return generated
def get_memory_usage(self) -> Dict[str, float]:
"""
Calculate memory usage with KV cache optimization (GPT-5 metric)
"""
        # KV cache per layer per sequence in fp32:
        # 2 tensors (K and V) x num_kv_heads x max_seq_len x head_dim(80) x 4 bytes
        # Standard attention: 16 K/V heads
        standard_kv_memory = 2 * 16 * self.max_seq_len * 80 * 4  # bytes
        # Optimized: 2 K/V heads
        optimized_kv_memory = 2 * 2 * self.max_seq_len * 80 * 4  # bytes
return {
'standard_kv_mb': standard_kv_memory / (1024 * 1024),
'optimized_kv_mb': optimized_kv_memory / (1024 * 1024),
'reduction_ratio': standard_kv_memory / optimized_kv_memory,
'total_params_m': sum(p.numel() for p in self.parameters()) / 1e6
}
if __name__ == "__main__":
# Test the decoder
decoder = DecoderV62()
# Simulate encoder outputs (4 layers, 6 tokens each)
batch_size = 2
num_tokens = 6 # After progressive splitting
hidden_dim = 1280
encoder_outputs = [
torch.randn(batch_size, num_tokens, hidden_dim)
for _ in range(4)
]
# Test with teacher forcing
decoder_input = torch.randint(0, 256, (batch_size, 48))
output = decoder(encoder_outputs, decoder_input_ids=decoder_input)
print(f"Decoder output shape: {output['logits'].shape}")
print(f"Encoder layer importance: {output['encoder_layer_importance']}")
# Test generation
generated = decoder.generate(encoder_outputs, max_length=48)
print(f"Generated shape: {generated.shape}")
# Memory usage
memory_stats = decoder.get_memory_usage()
print(f"Memory optimization: {memory_stats['reduction_ratio']:.1f}x reduction")
print(f"Total parameters: {memory_stats['total_params_m']:.1f}M")