"""
Intelligent Tokenizer v6.2.0 - 6-Layer Decoder with Multi-Level Cross-Attention

Incorporates GPT-5 suggestions for KV cache optimization
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import math


class KVCacheOptimizedAttention(nn.Module):
    """
    KV Cache Optimized Attention - GPT-5 suggestion
    16 query heads → 2 key/value heads for an 8x KV cache memory reduction
    """

    def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_compression: int = 8):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.kv_heads = max(2, num_heads // kv_compression)
        self.head_dim = hidden_dim // num_heads

        self.q_proj = nn.Linear(hidden_dim, hidden_dim)

        self.k_proj = nn.Linear(hidden_dim, self.kv_heads * self.head_dim)
        self.v_proj = nn.Linear(hidden_dim, self.kv_heads * self.head_dim)

        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

        self.register_buffer('cached_keys', None)
        self.register_buffer('cached_values', None)
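
    # Worked example of the head arithmetic above (illustrative numbers only):
    # with the defaults num_heads=16, kv_compression=8 and head_dim=80,
    # kv_heads = max(2, 16 // 8) = 2, so each cached token stores 2 * 80 values
    # for K and for V instead of 16 * 80 -- the 8x reduction cited in the
    # class docstring.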

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                use_cache: bool = False) -> Tuple[torch.Tensor, Optional[Tuple]]:
        """
        Forward pass with KV cache optimization
        """
        batch_size, seq_len = hidden_states.shape[:2]

        # Queries keep the full head count
        Q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim)
        Q = Q.transpose(1, 2)

        # Cross-attention reads K/V from the encoder; self-attention from the decoder states
        kv_source = encoder_hidden if encoder_hidden is not None else hidden_states

        K = self.k_proj(kv_source).view(batch_size, -1, self.kv_heads, self.head_dim)
        V = self.v_proj(kv_source).view(batch_size, -1, self.kv_heads, self.head_dim)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Expand the compressed K/V heads to match the query heads (grouped-query attention)
        K = K.repeat_interleave(self.num_heads // self.kv_heads, dim=1)
        V = V.repeat_interleave(self.num_heads // self.kv_heads, dim=1)

        if use_cache:
            # Extend the cache only during single-token (incremental) decoding
            if self.cached_keys is not None and hidden_states.size(1) == 1:
                K = torch.cat([self.cached_keys, K], dim=2)
                V = torch.cat([self.cached_values, V], dim=2)

            self.cached_keys = K
            self.cached_values = V

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            scores = scores + attention_mask

        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # Merge heads and project back to the model dimension
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim)
        output = self.o_proj(attn_output)

        return output, (K, V) if use_cache else None


class SelectiveCrossAttention(nn.Module):
    """
    Selective cross-attention - only attend to relevant encoder layers
    Reduces 24 → 8 cross-attention connections for efficiency
    """

    def __init__(self, hidden_dim: int = 1280, layer_id: int = 0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_id = layer_id

        # Decoder layer id -> encoder layers it attends to
        self.encoder_connections = {
            0: [0],
            1: [0],
            2: [1, 2],
            3: [1, 2],
            4: [3],
            5: [3],
        }

        self.connected_layers = self.encoder_connections.get(layer_id, [0])

        self.cross_attentions = nn.ModuleList([
            KVCacheOptimizedAttention(hidden_dim, num_heads=16, kv_compression=8)
            for _ in self.connected_layers
        ])

        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(0.1)
        )

        # Learnable weights for fusing the connected encoder layers
        self.layer_weights = nn.Parameter(torch.ones(len(self.connected_layers)) / len(self.connected_layers))
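
    # Connection budget (derived from the mapping above): a 6-layer decoder
    # attending to all 4 encoder layers would need 6 * 4 = 24 cross-attention
    # paths; the mapping keeps 1+1+2+2+1+1 = 8, which is the 24 → 8 reduction
    # mentioned in the class docstring.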

    def forward(self,
                decoder_hidden: torch.Tensor,
                encoder_all_hidden: List[torch.Tensor],
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Selectively attend to relevant encoder layers only
        """
        cross_outputs = []
        for i, layer_idx in enumerate(self.connected_layers):
            if layer_idx < len(encoder_all_hidden):
                encoder_hidden = encoder_all_hidden[layer_idx]
                cross_out, _ = self.cross_attentions[i](
                    hidden_states=decoder_hidden,
                    encoder_hidden=encoder_hidden,
                    attention_mask=attention_mask
                )
                cross_outputs.append(cross_out)

        # Fuse the connected encoder layers with learned softmax weights
        if len(cross_outputs) > 1:
            weighted_outputs = torch.stack(cross_outputs, dim=0)
            weights = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
            fused = (weighted_outputs * weights).sum(dim=0)
        else:
            # Single connection (or none): fall back to the decoder states
            fused = cross_outputs[0] if cross_outputs else decoder_hidden

        fused = self.fusion(fused)

        return fused


class SwiGLU(nn.Module):
    """SwiGLU activation for better convergence (GPT suggestion)"""

    def __init__(self, dim: int, mult: float = 2.66):
        super().__init__()
        inner = int(round(dim * mult / 2)) * 2
        self.w1 = nn.Linear(dim, inner // 2)
        self.w2 = nn.Linear(dim, inner // 2)
        self.w3 = nn.Linear(inner // 2, dim)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))
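
# The block above computes SwiGLU(x) = W3(SiLU(W1 x) * W2 x). mult=2.66 is
# close to the common 8/3 expansion, which keeps the gated FFN's parameter
# count roughly equal to a standard 4x MLP (interpretation, not stated in the
# original source).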


class DecoderLayer(nn.Module):
    """
    Single decoder layer with self-attention and selective cross-attention
    """

    def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, layer_id: int = 0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_id = layer_id

        # Causal self-attention with compressed KV heads
        self.self_attn = KVCacheOptimizedAttention(hidden_dim, num_heads, kv_compression=8)
        self.self_attn_norm = nn.LayerNorm(hidden_dim)

        # Cross-attention to the selected encoder layers
        self.cross_attn = SelectiveCrossAttention(hidden_dim, layer_id=layer_id)
        self.cross_attn_norm = nn.LayerNorm(hidden_dim)

        # Feed-forward network
        self.ffn = SwiGLU(hidden_dim, mult=2.66)
        self.ffn_norm = nn.LayerNorm(hidden_dim)

        self.dropout = nn.Dropout(0.1)

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_all_hidden: List[torch.Tensor],
                self_attention_mask: Optional[torch.Tensor] = None,
                cross_attention_mask: Optional[torch.Tensor] = None,
                use_cache: bool = False) -> Tuple[torch.Tensor, Optional[Tuple]]:
        """
        Forward pass through decoder layer
        """
        # Self-attention (pre-norm residual)
        residual = hidden_states
        hidden_states = self.self_attn_norm(hidden_states)
        self_attn_out, cache = self.self_attn(
            hidden_states,
            attention_mask=self_attention_mask,
            use_cache=use_cache
        )
        hidden_states = residual + self.dropout(self_attn_out)

        # Selective cross-attention (pre-norm residual)
        residual = hidden_states
        hidden_states = self.cross_attn_norm(hidden_states)
        cross_attn_out = self.cross_attn(
            hidden_states,
            encoder_all_hidden,
            attention_mask=cross_attention_mask
        )
        hidden_states = residual + self.dropout(cross_attn_out)

        # Feed-forward (pre-norm residual)
        residual = hidden_states
        hidden_states = self.ffn_norm(hidden_states)
        ffn_out = self.ffn(hidden_states)
        hidden_states = residual + self.dropout(ffn_out)

        return hidden_states, cache


class DecoderV62(nn.Module):
    """
    6-Layer Decoder with Multi-Level Cross-Attention
    Reduced from 8 layers, compensated by multi-level cross-attention
    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()

        # Fixed architecture hyperparameters (the config argument is currently unused)
        self.hidden_dim = 1280
        self.num_heads = 16
        self.num_layers = 6
        self.vocab_size = 260
        self.max_seq_len = 48

        # Special token ids
        self.PAD = 256
        self.BOS = 257
        self.EOS = 258
        self.MASK = 259

        self.token_embedding = nn.Embedding(self.vocab_size, self.hidden_dim)
        self.position_embedding = nn.Embedding(self.max_seq_len, self.hidden_dim)

        self.layers = nn.ModuleList([
            DecoderLayer(self.hidden_dim, self.num_heads, layer_id=i)
            for i in range(self.num_layers)
        ])

        self.output_norm = nn.LayerNorm(self.hidden_dim)
        self.output_projection = nn.Linear(self.hidden_dim, self.vocab_size)

        # Tracks the relative importance of the 4 encoder layers
        self.register_buffer('layer_importance', torch.zeros(4))
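
    # Vocabulary layout (inferred from the constants above, not stated in the
    # original source): ids 0-255 appear to be raw byte values, so
    # vocab_size = 260 covers 256 bytes plus the PAD/BOS/EOS/MASK specials.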

    def forward(self,
                encoder_all_hidden: List[torch.Tensor],
                decoder_input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                past_key_values: Optional[List] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass through decoder

        Args:
            encoder_all_hidden: All encoder layer outputs (4 layers)
            decoder_input_ids: Input token IDs for teacher forcing
            attention_mask: Attention mask
            use_cache: Whether to cache KV for inference
            past_key_values: Cached KV from previous steps
        """
        batch_size = encoder_all_hidden[0].size(0)
        device = encoder_all_hidden[0].device

        if decoder_input_ids is None:
            # No target tokens: start directly from the top encoder representation
            hidden_states = encoder_all_hidden[-1]
            seq_len = hidden_states.size(1)
        else:
            # Teacher forcing / incremental decoding: embed the provided tokens
            seq_len = decoder_input_ids.size(1)

            # Offset positions by the cached prefix length during incremental decoding
            past_length = 0
            if past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None:
                past_length = past_key_values[0][0].size(2)

            token_embeds = self.token_embedding(decoder_input_ids)
            position_ids = torch.arange(past_length, past_length + seq_len, device=device).expand(batch_size, -1)
            position_embeds = self.position_embedding(position_ids)

            hidden_states = token_embeds + position_embeds

        # Causal mask: upper-triangular -inf added to the attention scores
        causal_mask = torch.full((1, 1, seq_len, seq_len), float('-inf'), device=device)
        causal_mask = torch.triu(causal_mask, diagonal=1)

        all_hidden_states = []
        all_caches = [] if use_cache else None

        for i, layer in enumerate(self.layers):
            # Cross-attention mask over the encoder sequence (no positions masked)
            if encoder_all_hidden is not None and len(encoder_all_hidden) > 0:
                S_enc = encoder_all_hidden[0].size(1)
                cross_mask = torch.zeros((batch_size, 1, 1, S_enc), device=hidden_states.device)
            else:
                cross_mask = None

            hidden_states, cache = layer(
                hidden_states,
                encoder_all_hidden,
                self_attention_mask=causal_mask,
                cross_attention_mask=cross_mask,
                use_cache=use_cache
            )

            all_hidden_states.append(hidden_states)
            if use_cache:
                all_caches.append(cache)

        hidden_states = self.output_norm(hidden_states)
        logits = self.output_projection(hidden_states)

        # Placeholder: uniform importance over the 4 encoder layers
        with torch.no_grad():
            self.layer_importance = torch.tensor([0.25, 0.25, 0.25, 0.25], device=device)

        outputs = {
            'logits': logits,
            'last_hidden_state': hidden_states,
            'all_hidden_states': all_hidden_states,
            'encoder_layer_importance': self.layer_importance
        }

        if use_cache:
            outputs['past_key_values'] = all_caches

        return outputs
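
    # Note on cached decoding: generate() below feeds one token per step with
    # use_cache=True; each KVCacheOptimizedAttention then appends that token's
    # keys/values to its internal buffers, so self-attention still spans the
    # full generated prefix even though only the newest token is embedded here.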

    def generate(self,
                 encoder_all_hidden: List[torch.Tensor],
                 max_length: int = 48,
                 temperature: float = 1.0,
                 top_k: int = 50,
                 top_p: float = 0.95) -> torch.Tensor:
        """
        Autoregressive generation
        """
        batch_size = encoder_all_hidden[0].size(0)
        device = encoder_all_hidden[0].device

        # Clear any stale self-attention KV cache left over from a previous call
        for layer in self.layers:
            layer.self_attn.cached_keys = None
            layer.self_attn.cached_values = None

        # Start every sequence with the BOS token
        generated = torch.full((batch_size, 1), self.BOS, device=device)

        past_key_values = None
        for _ in range(max_length - 1):
            # With a warm cache, only the newest token needs to be fed in
            if past_key_values is not None:
                decoder_input = generated[:, -1:]
            else:
                decoder_input = generated

            outputs = self.forward(
                encoder_all_hidden,
                decoder_input_ids=decoder_input,
                use_cache=True,
                past_key_values=past_key_values
            )

            logits = outputs['logits'][:, -1, :] / temperature

            # Top-k filtering
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Shift so the first token above the threshold is kept
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = float('-inf')

            # Sample the next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            generated = torch.cat([generated, next_token], dim=1)

            # Stop once every sequence has emitted EOS
            if (next_token == self.EOS).all():
                break

            past_key_values = outputs.get('past_key_values')

        return generated

    def get_memory_usage(self) -> Dict[str, float]:
        """
        Calculate memory usage with KV cache optimization (GPT-5 metric)
        """
        # Per layer and sequence: K and V (2) * 16 heads * seq_len * head_dim (80) * 4 bytes (fp32)
        standard_kv_memory = 2 * 16 * self.max_seq_len * 80 * 4

        # With compression, only 2 KV heads are cached
        optimized_kv_memory = 2 * 2 * self.max_seq_len * 80 * 4

        return {
            'standard_kv_mb': standard_kv_memory / (1024 * 1024),
            'optimized_kv_mb': optimized_kv_memory / (1024 * 1024),
            'reduction_ratio': standard_kv_memory / optimized_kv_memory,
            'total_params_m': sum(p.numel() for p in self.parameters()) / 1e6
        }


if __name__ == "__main__":
    # Smoke test with random encoder outputs
    decoder = DecoderV62()

    batch_size = 2
    num_tokens = 6
    hidden_dim = 1280

    # Fake outputs from a 4-layer encoder
    encoder_outputs = [
        torch.randn(batch_size, num_tokens, hidden_dim)
        for _ in range(4)
    ]

    # Teacher-forced forward pass
    decoder_input = torch.randint(0, 256, (batch_size, 48))
    output = decoder(encoder_outputs, decoder_input_ids=decoder_input)

    print(f"Decoder output shape: {output['logits'].shape}")
    print(f"Encoder layer importance: {output['encoder_layer_importance']}")

    # Autoregressive generation with KV caching
    generated = decoder.generate(encoder_outputs, max_length=48)
    print(f"Generated shape: {generated.shape}")
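
    # Minimal sketch (not part of the original demo): exercise the grouped-query
    # attention module on its own to confirm the compressed K/V projections.
    attn = KVCacheOptimizedAttention(hidden_dim=hidden_dim, num_heads=16, kv_compression=8)
    attn_out, _ = attn(torch.randn(batch_size, num_tokens, hidden_dim))
    print(f"GQA attention output shape: {attn_out.shape}")
    print(f"K/V projection width: {attn.k_proj.out_features} vs Q projection width: {attn.q_proj.out_features}")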

    # KV cache memory comparison
    memory_stats = decoder.get_memory_usage()
    print(f"Memory optimization: {memory_stats['reduction_ratio']:.1f}x reduction")
    print(f"Total parameters: {memory_stats['total_params_m']:.1f}M")