Upload folder using huggingface_hub

- core/__pycache__/decoder.cpython-310.pyc +0 -0
- core/__pycache__/decoder.cpython-313.pyc +0 -0
- core/__pycache__/encoder.cpython-310.pyc +0 -0
- core/__pycache__/encoder.cpython-313.pyc +0 -0
- core/__pycache__/tokenizer.cpython-310.pyc +0 -0
- core/__pycache__/tokenizer.cpython-313.pyc +0 -0
- core/__pycache__/unified_model.cpython-310.pyc +0 -0
- core/decoder.py +485 -0
- core/encoder.py +588 -0
- core/intelligent_loss.py +589 -0
- core/scheduler.py +669 -0
- core/tokenizer.py +477 -0
- core/unified_model.py +481 -695
core/__pycache__/decoder.cpython-310.pyc
ADDED
Binary file (11.8 kB)

core/__pycache__/decoder.cpython-313.pyc
ADDED
Binary file (20.8 kB)

core/__pycache__/encoder.cpython-310.pyc
ADDED
Binary file (13.4 kB)

core/__pycache__/encoder.cpython-313.pyc
ADDED
Binary file (24.8 kB)

core/__pycache__/tokenizer.cpython-310.pyc
ADDED
Binary file (12.3 kB)

core/__pycache__/tokenizer.cpython-313.pyc
ADDED
Binary file (19.8 kB)

core/__pycache__/unified_model.cpython-310.pyc
ADDED
Binary file (12.2 kB)
core/decoder.py
ADDED
@@ -0,0 +1,485 @@
"""
Intelligent Tokenizer v6.2.0 - 6-Layer Decoder with Multi-Level Cross-Attention
Incorporates GPT-5 suggestions for KV cache optimization
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import math


class KVCacheOptimizedAttention(nn.Module):
    """
    KV Cache Optimized Attention - GPT-5 suggestion
    16Q → 2K/V for 8x memory reduction
    """

    def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_compression: int = 8):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.kv_heads = max(2, num_heads // kv_compression)  # 16/8 = 2 KV heads
        self.head_dim = hidden_dim // num_heads  # 80

        # Query uses all heads
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)  # 16 heads

        # Key/Value use fewer heads (GPT-5 suggestion)
        self.k_proj = nn.Linear(hidden_dim, self.kv_heads * self.head_dim)  # 2 heads
        self.v_proj = nn.Linear(hidden_dim, self.kv_heads * self.head_dim)  # 2 heads

        # Output projection
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

        # KV cache for inference
        self.register_buffer('cached_keys', None)
        self.register_buffer('cached_values', None)

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                use_cache: bool = False) -> Tuple[torch.Tensor, Optional[Tuple]]:
        """
        Forward pass with KV cache optimization
        """
        batch_size, seq_len = hidden_states.shape[:2]

        # Query projection (all heads)
        Q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim)
        Q = Q.transpose(1, 2)  # [batch, heads, seq, dim]

        # Key/Value source (self or cross)
        kv_source = encoder_hidden if encoder_hidden is not None else hidden_states

        # Key/Value projection (fewer heads)
        K = self.k_proj(kv_source).view(batch_size, -1, self.kv_heads, self.head_dim)
        V = self.v_proj(kv_source).view(batch_size, -1, self.kv_heads, self.head_dim)
        K = K.transpose(1, 2)  # [batch, kv_heads, seq, dim]
        V = V.transpose(1, 2)

        # Repeat KV heads to match Q heads (broadcast)
        K = K.repeat_interleave(self.num_heads // self.kv_heads, dim=1)
        V = V.repeat_interleave(self.num_heads // self.kv_heads, dim=1)

        # Cache management for incremental generation (GPT suggestion)
        if use_cache:
            # For incremental generation, only process new token
            if self.cached_keys is not None and hidden_states.size(1) == 1:
                # Append new K/V to cache
                K = torch.cat([self.cached_keys, K], dim=2)
                V = torch.cat([self.cached_values, V], dim=2)
            # Update cache
            self.cached_keys = K
            self.cached_values = V

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Use additive mask (GPT suggestion)
        if attention_mask is not None:
            scores = scores + attention_mask  # additive mask: -inf where masked, 0 elsewhere

        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # Reshape and project
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim)
        output = self.o_proj(attn_output)

        return output, (K, V) if use_cache else None


class SelectiveCrossAttention(nn.Module):
    """
    Selective cross-attention - only attend to relevant encoder layers
    Reduces 24 → 8 cross-attentions for efficiency
    """

    def __init__(self, hidden_dim: int = 1280, layer_id: int = 0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_id = layer_id

        # Define which encoder layers this decoder layer should attend to
        self.encoder_connections = {
            0: [0],     # Decoder L0 → Encoder L0 (byte info)
            1: [0],     # Decoder L1 → Encoder L0 (byte info)
            2: [1, 2],  # Decoder L2 → Encoder L1,2 (language info)
            3: [1, 2],  # Decoder L3 → Encoder L1,2 (language info)
            4: [3],     # Decoder L4 → Encoder L3 (semantic info)
            5: [3],     # Decoder L5 → Encoder L3 (semantic info)
        }

        # Get connections for this layer
        self.connected_layers = self.encoder_connections.get(layer_id, [0])

        # Create attention modules only for connected layers
        self.cross_attentions = nn.ModuleList([
            KVCacheOptimizedAttention(hidden_dim, num_heads=16, kv_compression=8)
            for _ in self.connected_layers
        ])

        # Lightweight fusion with weighted sum (GPT suggestion)
        self.fusion = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(0.1)
        )

        # Learnable weights for connected layers only
        self.layer_weights = nn.Parameter(torch.ones(len(self.connected_layers)) / len(self.connected_layers))

    def forward(self,
                decoder_hidden: torch.Tensor,
                encoder_all_hidden: List[torch.Tensor],
                attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Selectively attend to relevant encoder layers only
        """
        # Only attend to connected encoder layers
        cross_outputs = []
        for i, layer_idx in enumerate(self.connected_layers):
            if layer_idx < len(encoder_all_hidden):
                encoder_hidden = encoder_all_hidden[layer_idx]
                cross_out, _ = self.cross_attentions[i](
                    hidden_states=decoder_hidden,
                    encoder_hidden=encoder_hidden,
                    attention_mask=attention_mask
                )
                cross_outputs.append(cross_out)

        # Weighted sum fusion for connected layers only
        if len(cross_outputs) > 1:
            weighted_outputs = torch.stack(cross_outputs, dim=0)  # [N, batch, seq, hidden]
            weights = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
            fused = (weighted_outputs * weights).sum(dim=0)  # [batch, seq, hidden]
        else:
            # Single connection - no fusion needed
            fused = cross_outputs[0] if cross_outputs else decoder_hidden

        # Apply lightweight fusion layer
        fused = self.fusion(fused)

        return fused


class SwiGLU(nn.Module):
    """SwiGLU activation for better convergence (GPT suggestion)"""

    def __init__(self, dim: int, mult: float = 2.66):
        super().__init__()
        inner = int(round(dim * mult / 2)) * 2  # Even alignment
        self.w1 = nn.Linear(dim, inner // 2)
        self.w2 = nn.Linear(dim, inner // 2)
        self.w3 = nn.Linear(inner // 2, dim)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class DecoderLayer(nn.Module):
    """
    Single decoder layer with self-attention and selective cross-attention
    """

    def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, layer_id: int = 0):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_id = layer_id

        # Self-attention (with KV cache optimization)
        self.self_attn = KVCacheOptimizedAttention(hidden_dim, num_heads, kv_compression=8)
        self.self_attn_norm = nn.LayerNorm(hidden_dim)

        # Selective cross-attention to specific encoder layers
        self.cross_attn = SelectiveCrossAttention(hidden_dim, layer_id=layer_id)
        self.cross_attn_norm = nn.LayerNorm(hidden_dim)

        # Feed-forward network with SwiGLU (GPT suggestion)
        self.ffn = SwiGLU(hidden_dim, mult=2.66)
        self.ffn_norm = nn.LayerNorm(hidden_dim)

        # Dropout for residual connections
        self.dropout = nn.Dropout(0.1)

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_all_hidden: List[torch.Tensor],
                self_attention_mask: Optional[torch.Tensor] = None,
                cross_attention_mask: Optional[torch.Tensor] = None,
                use_cache: bool = False) -> Tuple[torch.Tensor, Optional[Tuple]]:
        """
        Forward pass through decoder layer
        """
        # Self-attention with residual
        residual = hidden_states
        hidden_states = self.self_attn_norm(hidden_states)
        self_attn_out, cache = self.self_attn(
            hidden_states,
            attention_mask=self_attention_mask,
            use_cache=use_cache
        )
        hidden_states = residual + self.dropout(self_attn_out)

        # Cross-attention with residual
        residual = hidden_states
        hidden_states = self.cross_attn_norm(hidden_states)
        cross_attn_out = self.cross_attn(
            hidden_states,
            encoder_all_hidden,
            attention_mask=cross_attention_mask
        )
        hidden_states = residual + self.dropout(cross_attn_out)

        # FFN with residual
        residual = hidden_states
        hidden_states = self.ffn_norm(hidden_states)
        ffn_out = self.ffn(hidden_states)
        hidden_states = residual + self.dropout(ffn_out)

        return hidden_states, cache


class DecoderV62(nn.Module):
    """
    6-Layer Decoder with Multi-Level Cross-Attention
    Reduced from 8 layers but compensated with better cross-attention
    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()

        # Configuration
        self.hidden_dim = 1280
        self.num_heads = 16
        self.num_layers = 6  # Reduced from 8
        self.vocab_size = 260  # 256 bytes + special tokens
        self.max_seq_len = 48

        # Token constants (GPT suggestion - explicit constants)
        self.PAD = 256
        self.BOS = 257
        self.EOS = 258
        self.MASK = 259

        # Token embedding and position encoding
        self.token_embedding = nn.Embedding(self.vocab_size, self.hidden_dim)
        self.position_embedding = nn.Embedding(self.max_seq_len, self.hidden_dim)

        # 6 decoder layers with layer-specific cross-attention
        self.layers = nn.ModuleList([
            DecoderLayer(self.hidden_dim, self.num_heads, layer_id=i)
            for i in range(self.num_layers)
        ])

        # Output projection
        self.output_norm = nn.LayerNorm(self.hidden_dim)
        self.output_projection = nn.Linear(self.hidden_dim, self.vocab_size)

        # Monitoring (GPT-5 suggestion)
        # Track importance of ENCODER layers (4) used by decoder
        self.register_buffer('layer_importance', torch.zeros(4))  # Track importance of 4 encoder layers

    def forward(self,
                encoder_all_hidden: List[torch.Tensor],
                decoder_input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                use_cache: bool = False,
                past_key_values: Optional[List] = None) -> Dict[str, torch.Tensor]:
        """
        Forward pass through decoder

        Args:
            encoder_all_hidden: All encoder layer outputs (4 layers)
            decoder_input_ids: Input token IDs for teacher forcing
            attention_mask: Attention mask
            use_cache: Whether to cache KV for inference
            past_key_values: Cached KV from previous steps
        """
        batch_size = encoder_all_hidden[0].size(0)
        device = encoder_all_hidden[0].device

        # If no decoder input, start with compressed representation
        if decoder_input_ids is None:
            # Use encoder's final compressed output as starting point
            hidden_states = encoder_all_hidden[-1]  # [batch, M tokens, 1280]
            seq_len = hidden_states.size(1)
        else:
            # Teacher forcing mode: use provided tokens
            seq_len = decoder_input_ids.size(1)

            # Embeddings
            token_embeds = self.token_embedding(decoder_input_ids)
            position_ids = torch.arange(seq_len, device=device).expand(batch_size, -1)
            position_embeds = self.position_embedding(position_ids)

            hidden_states = token_embeds + position_embeds

        # Create causal mask for self-attention (additive mask - GPT suggestion)
        causal_mask = torch.full((1, 1, seq_len, seq_len), float('-inf'), device=device)
        causal_mask = torch.triu(causal_mask, diagonal=1)  # [1, 1, seq, seq]

        # Pass through decoder layers
        all_hidden_states = []
        all_caches = [] if use_cache else None

        for i, layer in enumerate(self.layers):
            # GPT final check: Create proper cross-attention mask for encoder hidden states
            if encoder_all_hidden is not None and len(encoder_all_hidden) > 0:
                S_enc = encoder_all_hidden[0].size(1)  # Encoder sequence length
                # Create additive mask (0 = attend, -inf = mask)
                cross_mask = torch.zeros((batch_size, 1, 1, S_enc), device=hidden_states.device)
            else:
                cross_mask = None

            hidden_states, cache = layer(
                hidden_states,
                encoder_all_hidden,
                self_attention_mask=causal_mask,
                cross_attention_mask=cross_mask,  # Use proper cross mask
                use_cache=use_cache
            )

            all_hidden_states.append(hidden_states)
            if use_cache:
                all_caches.append(cache)

        # Final output projection
        hidden_states = self.output_norm(hidden_states)
        logits = self.output_projection(hidden_states)

        # Update monitoring: track encoder layer importance
        # (This would be computed based on cross-attention weights in practice)
        with torch.no_grad():
            # Simplified: assume equal importance for now
            self.layer_importance = torch.tensor([0.25, 0.25, 0.25, 0.25])

        outputs = {
            'logits': logits,
            'last_hidden_state': hidden_states,
            'all_hidden_states': all_hidden_states,
            'encoder_layer_importance': self.layer_importance
        }

        if use_cache:
            outputs['past_key_values'] = all_caches

        return outputs

    def generate(self,
                 encoder_all_hidden: List[torch.Tensor],
                 max_length: int = 48,
                 temperature: float = 1.0,
                 top_k: int = 50,
                 top_p: float = 0.95) -> torch.Tensor:
        """
        Autoregressive generation
        """
        batch_size = encoder_all_hidden[0].size(0)
        device = encoder_all_hidden[0].device

        # Start with BOS token
        generated = torch.full((batch_size, 1), self.BOS, device=device)

        # Generate tokens one by one
        past_key_values = None
        for _ in range(max_length - 1):
            # GPT optimization: Only pass last token for O(T) complexity
            if past_key_values is not None:
                decoder_input = generated[:, -1:]  # Last token only
            else:
                decoder_input = generated  # Full sequence for first step

            outputs = self.forward(
                encoder_all_hidden,
                decoder_input_ids=decoder_input,
                use_cache=True,
                past_key_values=past_key_values
            )

            logits = outputs['logits'][:, -1, :] / temperature

            # Top-k filtering
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')

            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                # Remove tokens with cumulative probability above threshold
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = float('-inf')

            # Sample
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to generated sequence
            generated = torch.cat([generated, next_token], dim=1)

            # Check for EOS
            if (next_token == self.EOS).all():
                break

            past_key_values = outputs.get('past_key_values')

        return generated

    def get_memory_usage(self) -> Dict[str, float]:
        """
        Calculate memory usage with KV cache optimization (GPT-5 metric)
        """
        # Standard attention: 16 heads for K and V
        standard_kv_memory = 2 * 16 * self.max_seq_len * 80 * 4  # bytes

        # Optimized: 2 heads for K and V
        optimized_kv_memory = 2 * 2 * self.max_seq_len * 80 * 4  # bytes

        return {
            'standard_kv_mb': standard_kv_memory / (1024 * 1024),
            'optimized_kv_mb': optimized_kv_memory / (1024 * 1024),
            'reduction_ratio': standard_kv_memory / optimized_kv_memory,
            'total_params_m': sum(p.numel() for p in self.parameters()) / 1e6
        }


if __name__ == "__main__":
    # Test the decoder
    decoder = DecoderV62()

    # Simulate encoder outputs (4 layers, 6 tokens each)
    batch_size = 2
    num_tokens = 6  # After progressive splitting
    hidden_dim = 1280

    encoder_outputs = [
        torch.randn(batch_size, num_tokens, hidden_dim)
        for _ in range(4)
    ]

    # Test with teacher forcing
    decoder_input = torch.randint(0, 256, (batch_size, 48))
    output = decoder(encoder_outputs, decoder_input_ids=decoder_input)

    print(f"Decoder output shape: {output['logits'].shape}")
    print(f"Encoder layer importance: {output['encoder_layer_importance']}")

    # Test generation
    generated = decoder.generate(encoder_outputs, max_length=48)
    print(f"Generated shape: {generated.shape}")

    # Memory usage
    memory_stats = decoder.get_memory_usage()
    print(f"Memory optimization: {memory_stats['reduction_ratio']:.1f}x reduction")
    print(f"Total parameters: {memory_stats['total_params_m']:.1f}M")
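For orientation, a minimal sketch of the grouped K/V idea that KVCacheOptimizedAttention relies on: with 16 query heads but only 2 cached key/value heads of dimension 80, the per-position cache shrinks by a factor of 16/2 = 8, the same ratio get_memory_usage reports. The snippet below is illustrative only; its shapes and variable names are assumptions, not part of the committed files.

import torch

# Illustrative sketch only (not part of the committed files): grouped K/V caching.
num_q_heads, num_kv_heads, head_dim, seq_len = 16, 2, 80, 48

k = torch.randn(1, num_kv_heads, seq_len, head_dim)  # what actually gets cached
v = torch.randn(1, num_kv_heads, seq_len, head_dim)

# K/V are broadcast up to the query head count only at attention time.
k_full = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1)  # [1, 16, 48, 80]
assert k_full.shape == (1, num_q_heads, seq_len, head_dim)

standard_cache = 2 * num_q_heads * seq_len * head_dim  # K+V values with full heads
grouped_cache = 2 * num_kv_heads * seq_len * head_dim  # K+V values actually stored
print(standard_cache / grouped_cache)  # 8.0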
core/encoder.py
ADDED
@@ -0,0 +1,588 @@
"""
Intelligent Tokenizer v6.2.0 - Progressive Splitting Encoder
With GPT-5 suggested improvements
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Tuple
import math


class RoPEPositionalEncoding(nn.Module):
    """
    Rotary Position Embedding (RoPE) - GPT-5 suggestion
    Better for handling chunk boundaries and variable sequence lengths
    """

    def __init__(self, dim: int, max_seq_len: int = 48, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Precompute sinusoidal frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Precompute positional encodings
        t = torch.arange(max_seq_len).type_as(self.inv_freq)
        freqs = torch.outer(t, self.inv_freq)
        self.register_buffer('cos_cached', freqs.cos())
        self.register_buffer('sin_cached', freqs.sin())

    def forward(self, x: torch.Tensor, seq_len: int = None) -> torch.Tensor:
        """
        Apply RoPE to input tensor
        Handles chunk boundary corrections as suggested by GPT-5
        """
        if seq_len is None:
            seq_len = x.shape[1]

        # Get cached cos/sin values
        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]

        # Apply rotary embedding
        x_rot = self._apply_rotary_emb(x, cos, sin)

        return x_rot

    def _apply_rotary_emb(self, x, cos, sin):
        """Apply rotary embedding to input"""
        x1, x2 = x[..., ::2], x[..., 1::2]
        x_rot = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos
        ], dim=-1).flatten(-2)
        return x_rot


class GatedCrossAttention(nn.Module):
    """
    Gated Cross-Attention with MQA - GPT-5 suggestion
    Monitor gate values for quality assessment
    16Q → 2K/V for 8x memory reduction
    """

    def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_heads: int = 2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.kv_heads = kv_heads  # Reduced KV heads (GPT suggestion)
        self.head_dim = hidden_dim // num_heads  # 80

        # Multi-Query Attention projections
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)  # 16 heads
        self.k_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim)  # 2 heads
        self.v_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim)  # 2 heads
        self.o_proj = nn.Linear(hidden_dim, hidden_dim)

        # Gating mechanism (GPT-5 suggestion)
        self.gate = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Sigmoid()
        )

        # Gate monitoring (for analysis)
        self.register_buffer('gate_values', torch.zeros(1))

        # Warmup factor (GPT suggestion)
        self.register_buffer('warmup_alpha', torch.tensor(1.0))

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass with gate monitoring
        Returns: (output, gate_values)
        """
        batch_size, seq_len = query.shape[:2]

        # Multi-head attention projections
        Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.k_proj(key).view(batch_size, -1, self.kv_heads, self.head_dim)
        V = self.v_proj(value).view(batch_size, -1, self.kv_heads, self.head_dim)

        # Transpose for attention computation
        Q = Q.transpose(1, 2)  # [batch, heads, seq, dim]
        K = K.transpose(1, 2)  # [batch, kv_heads, seq, dim]
        V = V.transpose(1, 2)

        # Repeat KV heads to match Q heads if necessary
        if self.kv_heads < self.num_heads:
            repeat_factor = self.num_heads // self.kv_heads
            K = K.repeat_interleave(repeat_factor, dim=1)
            V = V.repeat_interleave(repeat_factor, dim=1)

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)

        # Reshape back
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim)
        attn_output = self.o_proj(attn_output)

        # Gating mechanism
        gate_input = torch.cat([query, attn_output], dim=-1)
        gate_values = self.gate(gate_input)

        # Store gate values for monitoring (keep tensor shape consistent)
        self.gate_values[0] = gate_values.mean().detach()

        # Apply gate with warmup factor (GPT suggestion)
        gate_values = gate_values * self.warmup_alpha
        output = gate_values * attn_output + (1 - gate_values) * query

        return output, gate_values


class ProgressiveSplittingLayer(nn.Module):
    """
    Core innovation: 48 bytes → 1 token → N tokens → M tokens
    """

    def __init__(self, hidden_dim: int = 1280, config: Optional[Dict] = None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.config = config or {}

        # Dynamic splitting: 1~4 tokens for efficiency
        # 48 bytes / 4 tokens = 12:1 compression (still beats BPE's 4:1)
        self.min_tokens = 1  # 48:1 compression
        self.max_tokens = 4  # 12:1 compression (still 3x better than BPE)

        # Initial compression: 48 bytes → 1 super token
        self.byte_embed = nn.Embedding(260, 64)  # Small embedding
        self.initial_compressor = nn.Sequential(
            nn.Linear(48 * 64, 2048),
            nn.LayerNorm(2048),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(2048, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )

        # Language-aware splitting: 1 → N tokens (config-based)
        self.language_splitter = nn.ModuleDict({
            'analyzer': nn.Sequential(
                nn.Linear(hidden_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 256)  # Language features
            ),
            'split_predictor': nn.Linear(256, self.max_tokens),  # Predict 1~4 tokens
            # Single unified expander that can produce any number of tokens
            'dynamic_expander': nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 2),
                nn.LayerNorm(hidden_dim * 2),
                nn.GELU(),  # Better than ReLU for transformers
                nn.Linear(hidden_dim * 2, hidden_dim * self.max_tokens)  # Can produce up to 4 tokens
            ),
            # Token-wise importance predictor
            'importance_predictor': nn.Sequential(
                nn.Linear(hidden_dim, 256),
                nn.ReLU(),
                nn.Linear(256, self.max_tokens),  # Importance for each potential token
                nn.Softmax(dim=-1)
            )
        })

        # Boundary refinement: N → M tokens with linguistic awareness
        self.boundary_refiner = nn.ModuleDict({
            'scorer': nn.Sequential(
                nn.Linear(hidden_dim, 512),
                nn.ReLU(),
                nn.Linear(512, 1)
            ),
            'morpheme_detector': nn.Conv1d(256, 64, 3),  # morphemes
            'word_detector': nn.Conv1d(256, 64, 5),      # words
            'phrase_detector': nn.Conv1d(256, 64, 7),    # phrases
            'adjuster': nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=16,
                dim_feedforward=4 * hidden_dim,
                dropout=0.1,
                batch_first=True
            )
        })

        # Initialize split_predictor bias to prefer 1 token initially
        # This ensures untrained model starts with maximum compression
        with torch.no_grad():
            self.language_splitter['split_predictor'].bias.data = torch.tensor([2.0, -1.0, -1.0, -1.0])
            # High bias for 1 token, negative for others

    def forward(self, input_ids: torch.Tensor, temperature: float = 1.0) -> Dict[str, torch.Tensor]:
        """
        Progressive splitting forward pass

        Args:
            input_ids: Input byte sequence [batch, seq_len]
            temperature: Gumbel-Softmax temperature for annealing
        """
        batch_size = input_ids.size(0)

        # Step 1: 48 bytes → 1 super token
        byte_embeddings = self.byte_embed(input_ids)  # [batch, 48, 64]
        flattened = byte_embeddings.view(batch_size, -1)  # [batch, 3072]
        super_token = self.initial_compressor(flattened)  # [batch, 1280]
        super_token = super_token.unsqueeze(1)  # [batch, 1, 1280]

        # Step 2: Language analysis and splitting (1 → N)
        lang_features = self.language_splitter['analyzer'](super_token)
        split_logits = self.language_splitter['split_predictor'](lang_features)
        split_weights = F.softmax(split_logits, dim=-1)  # [batch, 1, 4]

        # Direct transformation from super token to initial representation
        # No hardcoded splits - let the model learn everything
        lang_tokens = super_token  # Start with compressed representation

        # TRUE adaptive expansion - model learns the optimal split (1~4 tokens)
        # Analyze content to decide how many tokens are needed
        expansion_features = self.language_splitter['analyzer'](lang_tokens)  # [batch, 1, 256]

        # Dynamic expansion: generate up to 4 tokens from super token
        expanded = self.language_splitter['dynamic_expander'](lang_tokens.squeeze(1))  # [batch, hidden_dim*4]
        expanded = expanded.reshape(batch_size, self.max_tokens, self.hidden_dim)  # [batch, 4, hidden_dim]

        # Predict how many tokens we actually need (1~4)
        split_logits = self.language_splitter['split_predictor'](expansion_features.squeeze(1))  # [batch, 4]
        # Clamp logits to prevent extreme values that cause NaN
        split_logits = torch.clamp(split_logits, min=-10, max=10)
        # Ensure minimum temperature to prevent instability
        safe_temperature = max(temperature, 0.5)
        split_weights = F.gumbel_softmax(split_logits, tau=safe_temperature, hard=False, dim=-1)  # [batch, 4]

        # Predict importance for each potential token position
        importance = self.language_splitter['importance_predictor'](lang_tokens.squeeze(1))  # [batch, 4]

        # Dynamic token selection with importance-weighted allocation
        # Create cumulative mask for progressive token usage
        # If split_weights = [0.1, 0.2, 0.6, 0.1], we mainly use 3 tokens

        # Create progressive masks for 1, 2, 3, 4 tokens
        masks = []
        for n in range(1, self.max_tokens + 1):
            mask = torch.zeros(batch_size, self.max_tokens, 1, device=expanded.device)
            mask[:, :n, :] = 1.0
            masks.append(mask)

        # Apply importance-weighted masking
        # Important parts get more tokens, less important parts get fewer
        weighted_outputs = []
        for i, mask in enumerate(masks):
            num_tokens = i + 1
            # Weight by both split decision and importance
            token_weight = split_weights[:, i:i+1].unsqueeze(-1)  # [batch, 1, 1]

            # Apply importance modulation for asymmetric splits
            if num_tokens > 1:
                # Redistribute tokens based on importance
                importance_adjusted = importance[:, :num_tokens].unsqueeze(-1)  # [batch, n, 1]
                masked = expanded[:, :num_tokens] * importance_adjusted
            else:
                masked = expanded[:, :num_tokens]

            # Pad to max length
            if num_tokens < self.max_tokens:
                padding = torch.zeros(batch_size, self.max_tokens - num_tokens, self.hidden_dim,
                                      device=expanded.device)
                masked = torch.cat([masked, padding], dim=1)

            weighted_outputs.append(masked * token_weight)

        # Sum all weighted possibilities (differentiable selection)
        lang_tokens = sum(weighted_outputs)

        # Determine effective number of tokens (for monitoring)
        # Weighted average of token counts
        token_counts = torch.arange(1, self.max_tokens + 1, device=split_weights.device, dtype=torch.float32)
        avg_tokens = (split_weights * token_counts).sum(dim=-1).mean().item()

        k = lang_tokens.size(1)

        # Step 3: Boundary refinement (N → M)
        # Calculate boundary scores for each token position
        boundary_scores = self.boundary_refiner['scorer'](lang_tokens)  # [batch, N, 1]

        # Detect linguistic boundaries (morpheme, word, phrase)
        # Extract features for boundary detection
        if hasattr(lang_tokens, 'shape') and len(lang_tokens.shape) == 3:
            batch_size, num_tokens, hidden_dim = lang_tokens.shape

            # For boundary detection, we need to consider the original byte sequence
            # But we're working with compressed tokens here
            # So we detect boundaries based on learned representations

            # Apply boundary adjustment with TransformerEncoderLayer
            # This learns to adjust token boundaries based on context
            refined_tokens = self.boundary_refiner['adjuster'](lang_tokens)

            # The adjuster should learn to:
            # 1. Respect UTF-8 boundaries (learned during training)
            # 2. Align with word/phrase boundaries (learned from language patterns)
            # 3. Maintain semantic coherence within each token
        else:
            refined_tokens = lang_tokens

        # Determine actual number of tokens based on highest probability
        # During inference, use argmax. During training, use weighted average.
        if self.training:
            # During training, use weighted average for differentiability
            actual_num_tokens = avg_tokens
        else:
            # During inference, select the split with highest probability
            split_decision = torch.argmax(split_weights, dim=-1)  # [batch]
            actual_num_tokens = (split_decision.float().mean() + 1).item()  # +1 because indices are 0-3

        # Calculate compression ratio based on actual tokens used
        compression_ratio = 48.0 / max(1, actual_num_tokens)

        return {
            'tokens': refined_tokens,
            'num_tokens': actual_num_tokens,
            'compression_ratio': torch.tensor(compression_ratio, device=refined_tokens.device),
            'gate_values': None,  # Will be filled by cross-attention
            'language_features': lang_features,
            'split_weights': split_weights,
            'avg_tokens': avg_tokens if 'avg_tokens' in locals() else refined_tokens.size(1),
            'split_distribution': split_weights.mean(dim=0) if 'split_weights' in locals() else None
        }


class EncoderV62(nn.Module):
    """
    4-Layer Progressive Splitting Encoder with Cross-Attention
    All layers: 1280 dimensions
    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()

        # Store config for later use
        self.config = config or {}

        # Configuration
        self.hidden_dim = 1280
        self.num_heads = 16
        self.num_layers = 4
        self.max_seq_len = 48
        self.dropout = 0.1

        # RoPE positional encoding (GPT-5 suggestion)
        self.rope = RoPEPositionalEncoding(self.hidden_dim, self.max_seq_len)

        # Layer 0: Progressive Splitting (48→1→N→M) - Pass config
        self.progressive_splitter = ProgressiveSplittingLayer(self.hidden_dim, config)

        # Layers 1-3: Transformer encoders with cross-attention
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=self.hidden_dim,
                nhead=self.num_heads,
                dim_feedforward=4 * self.hidden_dim,  # 5120
                dropout=self.dropout,
                batch_first=True
            ) for _ in range(3)
        ])

        # Cross-attention between layers with MQA (GPT-5 suggestion)
        self.cross_attentions = nn.ModuleList([
            GatedCrossAttention(self.hidden_dim, self.num_heads, kv_heads=2)  # 8x memory reduction
            for _ in range(3)
        ])

        # Output heads for different tasks
        self.boundary_head = nn.Linear(self.hidden_dim, 4)
        self.language_head = nn.Linear(self.hidden_dim, 128)  # Reduced from 512 (GPT suggestion)
        self.compression_head = nn.Linear(self.hidden_dim, self.hidden_dim)

        # Monitoring metrics (GPT-5 suggestion)
        self.register_buffer('compression_ratios', torch.zeros(1))
        self.register_buffer('gate_averages', torch.zeros(3))

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                temperature: float = 1.0) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the encoder

        Args:
            input_ids: Input byte sequence
            attention_mask: Optional attention mask
            temperature: Gumbel-Softmax temperature for annealing
        """
        # Layer 0: Progressive splitting with temperature
        split_output = self.progressive_splitter(input_ids, temperature)
        x = split_output['tokens']  # [batch, M, 1280]

        # Apply RoPE
        x = self.rope(x, x.size(1))

        # Store all hidden states for decoder
        all_hidden_states = [x]
        gate_values_list = []

        # Layers 1-3 with cross-attention
        for i, (encoder_layer, cross_attn) in enumerate(
            zip(self.encoder_layers, self.cross_attentions)
        ):
            # Self-attention through transformer layer
            # GPT final check: Don't pass mask after progressive splitting changes sequence length
            x = encoder_layer(x)  # No mask needed (no padding after compression)

            # Cross-attention with previous layer
            if i > 0:
                x, gate_values = cross_attn(
                    query=x,
                    key=all_hidden_states[-1],
                    value=all_hidden_states[-1],
                    mask=None  # Mask not applicable after compression
                )
                gate_values_list.append(gate_values)
                # Keep tensor shape consistent - store in existing buffer element
                self.gate_averages[i-1] = gate_values.mean().detach().item()  # Fix indexing

            all_hidden_states.append(x)

        # Output projections
        boundaries = self.boundary_head(x)
        language_clusters = self.language_head(x)
        compressed = self.compression_head(x)

        # Update monitoring metrics
        # Ensure tensor is 1-dimensional for buffer assignment
        compression_ratio = split_output['compression_ratio']
        if compression_ratio.dim() == 0:  # Scalar tensor
            self.compression_ratios[0] = compression_ratio
        else:
            self.compression_ratios = compression_ratio

        return {
            'last_hidden_state': x,
            'all_hidden_states': all_hidden_states,
            'boundaries': boundaries,
            'language_clusters': language_clusters,
            'compressed': compressed,
            'compression_ratio': split_output['compression_ratio'],
            'num_tokens': split_output['num_tokens'],
            'splitting_probs': split_output.get('split_weights', None),  # Add for diagnostics
            'gate_values': gate_values_list,
            'gate_averages': self.gate_averages,
            'split_info': {
                'language_features': split_output['language_features'],
                'split_weights': split_output['split_weights']
            }
        }

    def get_monitoring_stats(self) -> Dict[str, float]:
        """
        Get monitoring statistics (GPT-5 suggestion)
        """
        return {
            'avg_compression_ratio': self.compression_ratios.item(),
            'gate_layer1': self.gate_averages[0].item(),
            'gate_layer2': self.gate_averages[1].item(),
            'gate_layer3': self.gate_averages[2].item(),
        }

    def set_warmup_step(self, step: int, total_warmup: int = 1000):
        """
        Set warmup alpha for all gates (GPT suggestion)
        Gradually increase gate influence from 0 to 1
        """
        alpha = min(1.0, step / total_warmup)
        for cross_attn in self.cross_attentions:
            cross_attn.warmup_alpha = torch.tensor(alpha, device=cross_attn.warmup_alpha.device)

    def adaptive_compression_control(self, reconstruction_loss: float):
        """
        Adaptive compression based on reconstruction quality
        No fixed phases - model learns optimal compression
        """
        # If reconstruction is poor, model will learn to use more tokens
        # This happens automatically through gradient descent
        # No manual phase control needed
        pass  # Let gradients handle it


class DualSlidingWindowEncoder(EncoderV62):
    """
    Extension with dual sliding window system
    Handles both chunk-level and token-level boundaries
    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__(config)

        # Chunk-level sliding window
        self.chunk_window = nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=8,  # 8-byte overlap
            stride=40,      # 48-8=40 stride
            padding=4
        )

        # Token-level sliding window
        self.token_window = nn.MultiheadAttention(
            embed_dim=self.hidden_dim,
            num_heads=self.num_heads,
            batch_first=True
        )

    def process_long_sequence(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Handle sequences longer than 48 bytes with sliding windows
        """
        batch_size, seq_len = input_ids.shape

        if seq_len <= 48:
            return super().forward(input_ids)

        # Process in chunks with overlap
        chunks = []
        for i in range(0, seq_len - 48 + 1, 40):  # 8-byte overlap
            chunk = input_ids[:, i:i+48]
            chunk_output = super().forward(chunk)
            chunks.append(chunk_output['last_hidden_state'])

        # Combine chunks with attention
        combined = torch.cat(chunks, dim=1)
        attended, _ = self.token_window(combined, combined, combined)

        return {
            'last_hidden_state': attended,
            'num_chunks': len(chunks),
            'total_compression': seq_len / attended.size(1)
        }


if __name__ == "__main__":
    # Test the encoder
    encoder = EncoderV62()

    # Test input
    batch_size = 2
    input_ids = torch.randint(0, 256, (batch_size, 48))

    # Forward pass
    output = encoder(input_ids)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output tokens: {output['num_tokens']}")
    print(f"Compression ratio: {output['compression_ratio']:.2f}:1")
    print(f"Gate averages: {output['gate_averages']}")
    print(f"Monitoring stats: {encoder.get_monitoring_stats()}")
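The two __main__ blocks above exercise the encoder and decoder separately. Below is a minimal wiring sketch of how they would plug together; it assumes the encoder's all_hidden_states list (4 stages) is what the decoder consumes as encoder_all_hidden, which is inferred from the matching shapes rather than shown in this commit.

import torch
from core.encoder import EncoderV62
from core.decoder import DecoderV62

# Illustrative sketch only (not part of the committed files): encoder -> decoder wiring.
encoder = EncoderV62()
decoder = DecoderV62()

input_ids = torch.randint(0, 256, (2, 48))  # one 48-byte chunk per batch item
enc_out = encoder(input_ids)

# all_hidden_states holds the 4 encoder stages that SelectiveCrossAttention
# routes to (stage 0 ~ byte level, stage 3 ~ semantic level).
generated = decoder.generate(enc_out['all_hidden_states'], max_length=48)
print(generated.shape)  # [2, <=48] byte/special-token ids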
core/intelligent_loss.py
ADDED
@@ -0,0 +1,589 @@
| 1 |
+
"""
|
| 2 |
+
Intelligent Loss Functions for v6.2.0
|
| 3 |
+
Multi-objective loss with GPT-5 suggested improvements
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from typing import Dict, Optional, Tuple
|
| 10 |
+
import math
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class IntelligentLoss(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
Comprehensive loss function for progressive splitting tokenizer
|
| 16 |
+
Combines multiple objectives with dynamic weighting
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(self, config: Optional[Dict] = None):
|
| 20 |
+
super().__init__()
|
| 21 |
+
|
| 22 |
+
# Default configuration
|
| 23 |
+
self.config = config or {}
|
| 24 |
+
|
| 25 |
+
# Special tokens (must match tokenizer)
|
| 26 |
+
self.PAD = 256
|
| 27 |
+
self.BOS = 257
|
| 28 |
+
self.EOS = 258
|
| 29 |
+
self.MASK = 259
|
| 30 |
+
|
| 31 |
+
# Loss components
|
| 32 |
+
self.reconstruction_loss = ReconstructionLoss(self.PAD)
|
| 33 |
+
self.compression_loss = CompressionLoss()
|
| 34 |
+
self.boundary_loss = BoundaryLoss()
|
| 35 |
+
self.language_loss = LanguageLoss()
|
| 36 |
+
self.consistency_loss = ConsistencyLoss()
|
| 37 |
+
|
| 38 |
+
# Dynamic weight adjustment
|
| 39 |
+
self.use_dynamic_weights = True
|
| 40 |
+
self.weight_history = {
|
| 41 |
+
'reconstruction': [],
|
| 42 |
+
'compression': [],
|
| 43 |
+
'boundary': [],
|
| 44 |
+
'language': [],
|
| 45 |
+
'consistency': []
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
def estimate_language_difficulty(self, targets: Dict) -> float:
|
| 49 |
+
"""Estimate language difficulty based on input characteristics"""
|
| 50 |
+
if 'input_ids' not in targets:
|
| 51 |
+
return 1.0
|
| 52 |
+
|
| 53 |
+
input_ids = targets['input_ids']
|
| 54 |
+
if input_ids.numel() == 0:
|
| 55 |
+
return 1.0
|
| 56 |
+
|
| 57 |
+
# Higher entropy = more complex language
|
| 58 |
+
unique_tokens = input_ids.unique().numel()
|
| 59 |
+
total_tokens = input_ids.numel()
|
| 60 |
+
diversity = min(1.0, (unique_tokens / total_tokens) * 2)
|
| 61 |
+
|
| 62 |
+
return diversity
|
| 63 |
+
|
| 64 |
+
def forward(self,
|
| 65 |
+
outputs: Dict[str, torch.Tensor],
|
| 66 |
+
targets: Dict[str, torch.Tensor],
|
| 67 |
+
weights: Optional[Dict[str, float]] = None) -> Dict[str, torch.Tensor]:
|
| 68 |
+
"""
|
| 69 |
+
Compute combined loss with all objectives
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
outputs: Model outputs dictionary
|
| 73 |
+
targets: Target values dictionary
|
| 74 |
+
weights: Optional weight overrides
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Dictionary with total loss and individual components
|
| 78 |
+
"""
|
| 79 |
+
losses = {}
|
| 80 |
+
|
| 81 |
+
# 1. Reconstruction loss (primary objective)
|
| 82 |
+
if 'logits' in outputs and 'input_ids' in targets:
|
| 83 |
+
losses['reconstruction'] = self.reconstruction_loss(
|
| 84 |
+
outputs['logits'],
|
| 85 |
+
targets['input_ids'],
|
| 86 |
+
targets.get('attention_mask')
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# 2. Compression loss (encourage optimal compression)
|
| 90 |
+
if 'compression_ratio' in outputs:
|
| 91 |
+
losses['compression'] = self.compression_loss(
|
| 92 |
+
outputs['compression_ratio'],
|
| 93 |
+
outputs.get('num_tokens')
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
# 3. Boundary loss (learn meaningful boundaries)
|
| 97 |
+
if 'boundaries' in outputs and 'boundary_targets' in targets:
|
| 98 |
+
losses['boundary'] = self.boundary_loss(
|
| 99 |
+
outputs['boundaries'],
|
| 100 |
+
targets['boundary_targets'],
|
| 101 |
+
targets.get('boundary_mask')
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# 4. Language loss (language identification/clustering)
|
| 105 |
+
if 'language_clusters' in outputs and 'language_targets' in targets:
|
| 106 |
+
losses['language'] = self.language_loss(
|
| 107 |
+
outputs['language_clusters'],
|
| 108 |
+
targets['language_targets']
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
# 5. Consistency loss (encoder-decoder consistency)
|
| 112 |
+
if 'encoder_hidden' in outputs and 'decoder_hidden' in outputs:
|
| 113 |
+
losses['consistency'] = self.consistency_loss(
|
| 114 |
+
outputs['encoder_hidden'],
|
| 115 |
+
outputs['decoder_hidden']
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
# Apply weights (either provided or dynamic)
|
| 119 |
+
if weights is None and self.use_dynamic_weights:
|
| 120 |
+
weights = self.compute_dynamic_weights(losses)
|
| 121 |
+
elif weights is None:
|
| 122 |
+
weights = {
|
| 123 |
+
'reconstruction': 1.0,
|
| 124 |
+
'compression': 1.0,
|
| 125 |
+
'boundary': 1.0,
|
| 126 |
+
'language': 0.5,
|
| 127 |
+
'consistency': 0.5
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
# Weighted sum
|
| 131 |
+
total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device)
|
| 132 |
+
for key, loss in losses.items():
|
| 133 |
+
weight = weights.get(key, 1.0)
|
| 134 |
+
total_loss = total_loss + weight * loss
|
| 135 |
+
losses[f'{key}_weighted'] = weight * loss
|
| 136 |
+
|
| 137 |
+
losses['total'] = total_loss
|
| 138 |
+
|
| 139 |
+
# Update weight history
|
| 140 |
+
for key in self.weight_history:
|
| 141 |
+
if key in losses:
|
| 142 |
+
self.weight_history[key].append(losses[key].item())
|
| 143 |
+
|
| 144 |
+
return losses
|
| 145 |
+
|
| 146 |
+
def compute_dynamic_weights(self, losses: Dict[str, torch.Tensor]) -> Dict[str, float]:
|
| 147 |
+
"""
|
| 148 |
+
Dynamically adjust weights based on loss magnitudes and progress
|
| 149 |
+
GPT-5 suggestion: balance loss magnitudes for stable training
|
| 150 |
+
"""
|
| 151 |
+
weights = {}
|
| 152 |
+
eps = 1e-8 # GPT fix: prevent division by zero
|
| 153 |
+
|
| 154 |
+
# Get loss magnitudes with NaN protection
|
| 155 |
+
magnitudes = {}
|
| 156 |
+
for k, v in losses.items():
|
| 157 |
+
if torch.isnan(v) or torch.isinf(v):
|
| 158 |
+
magnitudes[k] = 1.0 # Default safe value
|
| 159 |
+
else:
|
| 160 |
+
magnitudes[k] = v.item()
|
| 161 |
+
|
| 162 |
+
# Compute relative scales (GPT fix: add epsilon)
|
| 163 |
+
avg_magnitude = max(eps, sum(magnitudes.values()) / len(magnitudes))
|
| 164 |
+
|
| 165 |
+
for key, magnitude in magnitudes.items():
|
| 166 |
+
# Inverse scaling to balance magnitudes (GPT fix: add epsilon)
|
| 167 |
+
weights[key] = avg_magnitude / max(eps, magnitude)
|
| 168 |
+
|
| 169 |
+
# Dynamic adjustment based on loss ratios
|
| 170 |
+
if 'reconstruction' in magnitudes and 'compression' in magnitudes:
|
| 171 |
+
recon_loss = magnitudes['reconstruction']
|
| 172 |
+
comp_loss = magnitudes['compression']
|
| 173 |
+
|
| 174 |
+
# If reconstruction loss is too high relative to compression
|
| 175 |
+
if recon_loss > comp_loss * 10:
|
| 176 |
+
# Drastically reduce compression pressure
|
| 177 |
+
weights['compression'] *= 0.1
|
| 178 |
+
weights['reconstruction'] *= 5.0
|
| 179 |
+
elif recon_loss > comp_loss * 5:
|
| 180 |
+
# Moderate adjustment
|
| 181 |
+
weights['compression'] *= 0.5
|
| 182 |
+
weights['reconstruction'] *= 2.0
|
| 183 |
+
elif recon_loss < comp_loss * 0.5:
|
| 184 |
+
# Good reconstruction, can push compression
|
| 185 |
+
weights['compression'] *= 2.0
|
| 186 |
+
weights['reconstruction'] *= 0.5
|
| 187 |
+
|
| 188 |
+
# Normalize weights to prevent explosion
|
| 189 |
+
total_weight = sum(weights.values())
|
| 190 |
+
if total_weight > 0:
|
| 191 |
+
weights = {k: min(10.0, v / total_weight * len(weights)) for k, v in weights.items()}
|
| 192 |
+
|
| 193 |
+
return weights
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class ReconstructionLoss(nn.Module):
|
| 197 |
+
"""
|
| 198 |
+
Cross-entropy loss for sequence reconstruction
|
| 199 |
+
With label smoothing and focal loss options
|
| 200 |
+
"""
|
| 201 |
+
|
| 202 |
+
def __init__(self, pad_token: int = 256, label_smoothing: float = 0.1):
|
| 203 |
+
super().__init__()
|
| 204 |
+
self.pad_token = pad_token
|
| 205 |
+
self.label_smoothing = label_smoothing
|
| 206 |
+
self.focal_alpha = 0.25
|
| 207 |
+
self.focal_gamma = 2.0
|
| 208 |
+
self.use_focal = False
|
| 209 |
+
|
| 210 |
+
def forward(self,
|
| 211 |
+
logits: torch.Tensor,
|
| 212 |
+
targets: torch.Tensor,
|
| 213 |
+
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 214 |
+
"""
|
| 215 |
+
Compute reconstruction loss
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
logits: [batch, seq_len, vocab_size]
|
| 219 |
+
targets: [batch, seq_len]
|
| 220 |
+
mask: [batch, seq_len] attention mask
|
| 221 |
+
"""
|
| 222 |
+
batch_size, seq_len, vocab_size = logits.shape
|
| 223 |
+
|
| 224 |
+
# Reshape for loss computation
|
| 225 |
+
logits_flat = logits.reshape(-1, vocab_size)
|
| 226 |
+
targets_flat = targets.reshape(-1)
|
| 227 |
+
|
| 228 |
+
if self.use_focal:
|
| 229 |
+
# Focal loss for hard examples
|
| 230 |
+
ce_loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
|
| 231 |
+
pt = torch.exp(-ce_loss)
|
| 232 |
+
focal_loss = self.focal_alpha * (1 - pt) ** self.focal_gamma * ce_loss
|
| 233 |
+
|
| 234 |
+
if mask is not None:
|
| 235 |
+
mask_flat = mask.reshape(-1)
|
| 236 |
+
focal_loss = focal_loss * mask_flat
|
| 237 |
+
loss = focal_loss.sum() / mask_flat.sum()
|
| 238 |
+
else:
|
| 239 |
+
loss = focal_loss.mean()
|
| 240 |
+
else:
|
| 241 |
+
# Standard cross-entropy with label smoothing
|
| 242 |
+
if mask is not None:
|
| 243 |
+
mask_flat = mask.reshape(-1).bool() # GPT fix: ensure bool dtype
|
| 244 |
+
loss = F.cross_entropy(
|
| 245 |
+
logits_flat[mask_flat],
|
| 246 |
+
targets_flat[mask_flat],
|
| 247 |
+
ignore_index=self.pad_token,
|
| 248 |
+
label_smoothing=self.label_smoothing
|
| 249 |
+
)
|
| 250 |
+
else:
|
| 251 |
+
loss = F.cross_entropy(
|
| 252 |
+
logits_flat,
|
| 253 |
+
targets_flat,
|
| 254 |
+
ignore_index=self.pad_token,
|
| 255 |
+
label_smoothing=self.label_smoothing
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
return loss
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class CompressionLoss(nn.Module):
|
| 262 |
+
"""
|
| 263 |
+
Aggressive compression loss - push for high compression
|
| 264 |
+
Must beat existing tokenizers (4 bytes/token = 4:1)
|
| 265 |
+
"""
|
| 266 |
+
|
| 267 |
+
def __init__(self):
|
| 268 |
+
super().__init__()
|
| 269 |
+
# Dynamic compression based on token count
|
| 270 |
+
# 1 token = 48:1, 2 = 24:1, 3 = 16:1, 4 = 12:1
|
| 271 |
+
self.min_ratio = 12.0 # 4 tokens (worst case, still 3x better than BPE)
|
| 272 |
+
self.target_ratio = 24.0 # 2 tokens (optimal balance)
|
| 273 |
+
self.max_ratio = 48.0 # 1 token (best compression)
|
| 274 |
+
|
| 275 |
+
def forward(self,
|
| 276 |
+
compression_ratio: torch.Tensor,
|
| 277 |
+
num_tokens: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 278 |
+
"""
|
| 279 |
+
Compute compression loss (GPT fix: fully vectorized)
|
| 280 |
+
|
| 281 |
+
Args:
|
| 282 |
+
compression_ratio: Current compression ratio (scalar or batch)
|
| 283 |
+
num_tokens: Number of tokens used (for additional penalty)
|
| 284 |
+
"""
|
| 285 |
+
# Ensure tensor (GPT fix: handle device properly)
|
| 286 |
+
if not torch.is_tensor(compression_ratio):
|
| 287 |
+
device = num_tokens.device if torch.is_tensor(num_tokens) else torch.device('cpu')
|
| 288 |
+
compression_ratio = torch.tensor(compression_ratio, dtype=torch.float32, device=device)
|
| 289 |
+
|
| 290 |
+
# Aggressive compression enforcement
|
| 291 |
+
# MUST achieve at least 16:1 to be viable
|
| 292 |
+
if compression_ratio < self.min_ratio:
|
| 293 |
+
# Moderate penalty for falling below minimum (reduced for stability)
|
| 294 |
+
under_loss = ((self.min_ratio - compression_ratio) / self.min_ratio) * 0.5
|
| 295 |
+
else:
|
| 296 |
+
under_loss = torch.tensor(0.0, dtype=compression_ratio.dtype, device=compression_ratio.device)
|
| 297 |
+
|
| 298 |
+
# Reward getting close to target (24:1)
|
| 299 |
+
if self.min_ratio <= compression_ratio < self.target_ratio:
|
| 300 |
+
# Encourage reaching target
|
| 301 |
+
target_loss = ((self.target_ratio - compression_ratio) / self.target_ratio) * 0.5
|
| 302 |
+
elif compression_ratio >= self.target_ratio:
|
| 303 |
+
# Excellent compression - small reward for going higher
|
| 304 |
+
target_loss = -0.1 * torch.log(compression_ratio / self.target_ratio + 1.0)
|
| 305 |
+
else:
|
| 306 |
+
target_loss = torch.tensor(0.0, dtype=compression_ratio.dtype, device=compression_ratio.device)
|
| 307 |
+
|
| 308 |
+
# Only mild penalty for extreme compression (>48:1)
|
| 309 |
+
if compression_ratio > self.max_ratio:
|
| 310 |
+
over_loss = ((compression_ratio - self.max_ratio) / self.max_ratio) * 0.2
|
| 311 |
+
else:
|
| 312 |
+
over_loss = torch.tensor(0.0, dtype=compression_ratio.dtype, device=compression_ratio.device)
|
| 313 |
+
|
| 314 |
+
loss = under_loss + target_loss + over_loss
|
| 315 |
+
|
| 316 |
+
# Additional penalty based on token count (GPT fix: vectorized)
|
| 317 |
+
if num_tokens is not None:
|
| 318 |
+
if not torch.is_tensor(num_tokens):
|
| 319 |
+
num_tokens = torch.tensor(num_tokens, dtype=torch.float32, device=compression_ratio.device)
|
| 320 |
+
token_penalty = 0.1 * torch.clamp(num_tokens - 8, min=0.0) ** 2
|
| 321 |
+
loss = loss + token_penalty
|
| 322 |
+
|
| 323 |
+
return loss.mean() if loss.dim() > 0 else loss
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class BoundaryLoss(nn.Module):
|
| 327 |
+
"""
|
| 328 |
+
Learn meaningful chunk boundaries
|
| 329 |
+
Combines multiple boundary objectives
|
| 330 |
+
"""
|
| 331 |
+
|
| 332 |
+
def __init__(self):
|
| 333 |
+
super().__init__()
|
| 334 |
+
self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
|
| 335 |
+
|
| 336 |
+
def forward(self,
|
| 337 |
+
predicted: torch.Tensor,
|
| 338 |
+
target: torch.Tensor,
|
| 339 |
+
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 340 |
+
"""
|
| 341 |
+
Compute boundary loss
|
| 342 |
+
|
| 343 |
+
Args:
|
| 344 |
+
predicted: [batch, seq_len, boundary_classes] predicted boundaries
|
| 345 |
+
target: [batch, seq_len, boundary_classes] target boundaries
|
| 346 |
+
mask: [batch, seq_len] valid positions mask
|
| 347 |
+
"""
|
| 348 |
+
# Binary cross-entropy for boundary prediction
|
| 349 |
+
loss = self.bce_loss(predicted, target.float())
|
| 350 |
+
|
| 351 |
+
if mask is not None:
|
| 352 |
+
# Apply mask
|
| 353 |
+
mask_expanded = mask.unsqueeze(-1).expand_as(loss)
|
| 354 |
+
loss = loss * mask_expanded
|
| 355 |
+
loss = loss.sum() / mask_expanded.sum()
|
| 356 |
+
else:
|
| 357 |
+
loss = loss.mean()
|
| 358 |
+
|
| 359 |
+
# Add regularization for boundary sparsity
|
| 360 |
+
# (boundaries should be relatively rare)
|
| 361 |
+
boundary_probs = torch.sigmoid(predicted)
|
| 362 |
+
sparsity_loss = 0.01 * boundary_probs.mean()
|
| 363 |
+
|
| 364 |
+
# Add smoothness regularization
|
| 365 |
+
# (boundaries should be somewhat smooth/continuous)
|
| 366 |
+
if predicted.size(1) > 1:
|
| 367 |
+
diff = predicted[:, 1:] - predicted[:, :-1]
|
| 368 |
+
smoothness_loss = 0.01 * (diff ** 2).mean()
|
| 369 |
+
else:
|
| 370 |
+
smoothness_loss = 0.0
|
| 371 |
+
|
| 372 |
+
total_loss = loss + sparsity_loss + smoothness_loss
|
| 373 |
+
|
| 374 |
+
return total_loss
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
class LanguageLoss(nn.Module):
|
| 378 |
+
"""
|
| 379 |
+
Language identification/clustering loss
|
| 380 |
+
Supports both classification and clustering objectives
|
| 381 |
+
"""
|
| 382 |
+
|
| 383 |
+
def __init__(self, num_languages: int = 128, temperature: float = 0.07):
|
| 384 |
+
super().__init__()
|
| 385 |
+
self.num_languages = num_languages
|
| 386 |
+
self.temperature = temperature
|
| 387 |
+
|
| 388 |
+
# For supervised language classification
|
| 389 |
+
self.ce_loss = nn.CrossEntropyLoss()
|
| 390 |
+
|
| 391 |
+
def forward(self,
|
| 392 |
+
predicted: torch.Tensor,
|
| 393 |
+
target: torch.Tensor,
|
| 394 |
+
mode: str = 'classification') -> torch.Tensor:
|
| 395 |
+
"""
|
| 396 |
+
Compute language loss
|
| 397 |
+
|
| 398 |
+
Args:
|
| 399 |
+
predicted: [batch, seq_len, num_languages] or [batch, num_languages]
|
| 400 |
+
target: Language labels or cluster assignments
|
| 401 |
+
mode: 'classification' or 'clustering'
|
| 402 |
+
"""
|
| 403 |
+
if mode == 'classification':
|
| 404 |
+
# Standard classification loss
|
| 405 |
+
if predicted.dim() == 3:
|
| 406 |
+
# Sequence-level predictions
|
| 407 |
+
batch_size, seq_len, _ = predicted.shape
|
| 408 |
+
predicted = predicted.reshape(-1, self.num_languages)
|
| 409 |
+
target = target.reshape(-1)
|
| 410 |
+
|
| 411 |
+
loss = self.ce_loss(predicted, target)
|
| 412 |
+
|
| 413 |
+
elif mode == 'clustering':
|
| 414 |
+
# Contrastive clustering loss (similar to SimCLR)
|
| 415 |
+
# Normalize embeddings
|
| 416 |
+
predicted = F.normalize(predicted, dim=-1)
|
| 417 |
+
|
| 418 |
+
# Compute similarity matrix
|
| 419 |
+
sim_matrix = torch.matmul(predicted, predicted.t()) / self.temperature
|
| 420 |
+
|
| 421 |
+
# Create labels (assuming batch contains similar samples)
|
| 422 |
+
batch_size = predicted.size(0)
|
| 423 |
+
labels = torch.arange(batch_size, device=predicted.device)
|
| 424 |
+
|
| 425 |
+
# Contrastive loss
|
| 426 |
+
loss = F.cross_entropy(sim_matrix, labels)
|
| 427 |
+
|
| 428 |
+
else:
|
| 429 |
+
raise ValueError(f"Unknown mode: {mode}")
|
| 430 |
+
|
| 431 |
+
return loss
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class ConsistencyLoss(nn.Module):
|
| 435 |
+
"""
|
| 436 |
+
Ensure consistency between encoder and decoder representations
|
| 437 |
+
GPT-5 suggestion: helps with training stability
|
| 438 |
+
"""
|
| 439 |
+
|
| 440 |
+
def __init__(self, margin: float = 0.5):
|
| 441 |
+
super().__init__()
|
| 442 |
+
self.margin = margin
|
| 443 |
+
|
| 444 |
+
def forward(self,
|
| 445 |
+
encoder_hidden: torch.Tensor,
|
| 446 |
+
decoder_hidden: torch.Tensor) -> torch.Tensor:
|
| 447 |
+
"""
|
| 448 |
+
Compute consistency loss between encoder and decoder
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
encoder_hidden: [batch, seq_len, hidden_dim]
|
| 452 |
+
decoder_hidden: [batch, seq_len, hidden_dim]
|
| 453 |
+
"""
|
| 454 |
+
# Ensure same shape
|
| 455 |
+
if encoder_hidden.shape != decoder_hidden.shape:
|
| 456 |
+
# Align sequence lengths if different
|
| 457 |
+
min_len = min(encoder_hidden.size(1), decoder_hidden.size(1))
|
| 458 |
+
encoder_hidden = encoder_hidden[:, :min_len]
|
| 459 |
+
decoder_hidden = decoder_hidden[:, :min_len]
|
| 460 |
+
|
| 461 |
+
# L2 distance
|
| 462 |
+
l2_loss = F.mse_loss(encoder_hidden, decoder_hidden)
|
| 463 |
+
|
| 464 |
+
# Cosine similarity loss
|
| 465 |
+
encoder_norm = F.normalize(encoder_hidden, dim=-1)
|
| 466 |
+
decoder_norm = F.normalize(decoder_hidden, dim=-1)
|
| 467 |
+
cosine_sim = (encoder_norm * decoder_norm).sum(dim=-1)
|
| 468 |
+
cosine_loss = 1.0 - cosine_sim.mean()
|
| 469 |
+
|
| 470 |
+
# Combined loss
|
| 471 |
+
loss = l2_loss + 0.5 * cosine_loss
|
| 472 |
+
|
| 473 |
+
return loss
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
class AdaptiveLossScheduler:
|
| 477 |
+
"""
|
| 478 |
+
Dynamically adjust loss weights during training
|
| 479 |
+
Based on training progress and performance
|
| 480 |
+
"""
|
| 481 |
+
|
| 482 |
+
def __init__(self, config: Dict):
|
| 483 |
+
self.config = config
|
| 484 |
+
self.current_phase = 0
|
| 485 |
+
self.phase_epochs = [30, 60, 100] # Phase transition points
|
| 486 |
+
|
| 487 |
+
# Define phase-specific weights
|
| 488 |
+
self.phase_weights = [
|
| 489 |
+
# Phase 1: Boundary mastery
|
| 490 |
+
{
|
| 491 |
+
'reconstruction': 2.0,
|
| 492 |
+
'compression': 0.5,
|
| 493 |
+
'boundary': 3.0,
|
| 494 |
+
'language': 0.5,
|
| 495 |
+
'consistency': 0.5
|
| 496 |
+
},
|
| 497 |
+
# Phase 2: Compression focus
|
| 498 |
+
{
|
| 499 |
+
'reconstruction': 2.0,
|
| 500 |
+
'compression': 3.0,
|
| 501 |
+
'boundary': 1.0,
|
| 502 |
+
'language': 1.0,
|
| 503 |
+
'consistency': 1.0
|
| 504 |
+
},
|
| 505 |
+
# Phase 3: Balanced optimization
|
| 506 |
+
{
|
| 507 |
+
'reconstruction': 3.0,
|
| 508 |
+
'compression': 2.0,
|
| 509 |
+
'boundary': 1.0,
|
| 510 |
+
'language': 1.0,
|
| 511 |
+
'consistency': 1.5
|
| 512 |
+
}
|
| 513 |
+
]
|
| 514 |
+
|
| 515 |
+
def get_weights(self, epoch: int, metrics: Optional[Dict] = None) -> Dict[str, float]:
|
| 516 |
+
"""
|
| 517 |
+
Get current loss weights based on training phase
|
| 518 |
+
|
| 519 |
+
Args:
|
| 520 |
+
epoch: Current training epoch
|
| 521 |
+
metrics: Optional performance metrics for adaptive adjustment
|
| 522 |
+
"""
|
| 523 |
+
# Determine current phase
|
| 524 |
+
for i, phase_end in enumerate(self.phase_epochs):
|
| 525 |
+
if epoch <= phase_end:
|
| 526 |
+
self.current_phase = i
|
| 527 |
+
break
|
| 528 |
+
|
| 529 |
+
weights = self.phase_weights[self.current_phase].copy()
|
| 530 |
+
|
| 531 |
+
# Adaptive adjustments based on metrics
|
| 532 |
+
if metrics:
|
| 533 |
+
# If reconstruction is poor, increase its weight
|
| 534 |
+
if metrics.get('reconstruction_accuracy', 1.0) < 0.9:
|
| 535 |
+
weights['reconstruction'] *= 1.5
|
| 536 |
+
|
| 537 |
+
# If compression is off target, adjust weight
|
| 538 |
+
compression_ratio = metrics.get('compression_ratio', 16.0)
|
| 539 |
+
if compression_ratio < 8.0 or compression_ratio > 20.0:
|
| 540 |
+
weights['compression'] *= 1.5
|
| 541 |
+
|
| 542 |
+
return weights
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
if __name__ == "__main__":
|
| 546 |
+
# Test losses
|
| 547 |
+
print("Testing Intelligent Loss Functions")
|
| 548 |
+
|
| 549 |
+
# Create loss module
|
| 550 |
+
loss_fn = IntelligentLoss()
|
| 551 |
+
|
| 552 |
+
# Create dummy data
|
| 553 |
+
batch_size = 2
|
| 554 |
+
seq_len = 48
|
| 555 |
+
vocab_size = 260
|
| 556 |
+
hidden_dim = 1280
|
| 557 |
+
|
| 558 |
+
outputs = {
|
| 559 |
+
'logits': torch.randn(batch_size, seq_len, vocab_size),
|
| 560 |
+
'compression_ratio': torch.tensor(16.0),
|
| 561 |
+
'num_tokens': torch.tensor(3),
|
| 562 |
+
'boundaries': torch.randn(batch_size, seq_len, 4),
|
| 563 |
+
'language_clusters': torch.randn(batch_size, 128),
|
| 564 |
+
'encoder_hidden': torch.randn(batch_size, seq_len, hidden_dim),
|
| 565 |
+
'decoder_hidden': torch.randn(batch_size, seq_len, hidden_dim)
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
targets = {
|
| 569 |
+
'input_ids': torch.randint(0, 256, (batch_size, seq_len)),
|
| 570 |
+
'attention_mask': torch.ones(batch_size, seq_len),
|
| 571 |
+
'boundary_targets': torch.zeros(batch_size, seq_len, 4),
|
| 572 |
+
'language_targets': torch.randint(0, 128, (batch_size,))
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
# Compute losses
|
| 576 |
+
losses = loss_fn(outputs, targets)
|
| 577 |
+
|
| 578 |
+
print("\nLoss components:")
|
| 579 |
+
for key, value in losses.items():
|
| 580 |
+
if isinstance(value, torch.Tensor):
|
| 581 |
+
print(f" {key}: {value.item():.4f}")
|
| 582 |
+
|
| 583 |
+
# Test adaptive scheduler
|
| 584 |
+
scheduler = AdaptiveLossScheduler({})
|
| 585 |
+
|
| 586 |
+
print("\nPhase weights:")
|
| 587 |
+
for epoch in [10, 40, 70]:
|
| 588 |
+
weights = scheduler.get_weights(epoch)
|
| 589 |
+
print(f" Epoch {epoch}: {weights}")
|
core/scheduler.py
ADDED
|
@@ -0,0 +1,669 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Learning Rate Schedulers for v6.2.0
|
| 3 |
+
Advanced scheduling with warmup and phase-based adjustments
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import math
|
| 8 |
+
from typing import Optional, Dict, List, Any
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class WarmupCosineScheduler:
|
| 13 |
+
"""
|
| 14 |
+
Cosine annealing with linear warmup
|
| 15 |
+
GPT-5 suggested: Essential for stable progressive splitting training
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self,
|
| 19 |
+
optimizer: torch.optim.Optimizer,
|
| 20 |
+
warmup_steps: int,
|
| 21 |
+
total_steps: int,
|
| 22 |
+
min_lr: float = 1e-6,
|
| 23 |
+
max_lr: Optional[float] = None):
|
| 24 |
+
self.optimizer = optimizer
|
| 25 |
+
self.warmup_steps = warmup_steps
|
| 26 |
+
self.total_steps = total_steps
|
| 27 |
+
self.min_lr = min_lr
|
| 28 |
+
self.max_lr = max_lr or optimizer.param_groups[0]['lr']
|
| 29 |
+
self.current_step = 0
|
| 30 |
+
|
| 31 |
+
def step(self):
|
| 32 |
+
"""Update learning rate"""
|
| 33 |
+
self.current_step += 1
|
| 34 |
+
|
| 35 |
+
if self.current_step <= self.warmup_steps:
|
| 36 |
+
# Linear warmup
|
| 37 |
+
lr = self.max_lr * (self.current_step / self.warmup_steps)
|
| 38 |
+
else:
|
| 39 |
+
# Cosine annealing (GPT fix: guard against division by zero)
|
| 40 |
+
if self.total_steps <= self.warmup_steps:
|
| 41 |
+
lr = self.min_lr
|
| 42 |
+
else:
|
| 43 |
+
progress = (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
|
| 44 |
+
progress = min(1.0, max(0.0, progress)) # Clamp to [0, 1]
|
| 45 |
+
lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
|
| 46 |
+
|
| 47 |
+
for param_group in self.optimizer.param_groups:
|
| 48 |
+
param_group['lr'] = lr
|
| 49 |
+
|
| 50 |
+
return lr
|
| 51 |
+
|
| 52 |
+
def get_lr(self):
|
| 53 |
+
"""Get current learning rate"""
|
| 54 |
+
return self.optimizer.param_groups[0]['lr']
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class PhaseBasedScheduler:
|
| 58 |
+
"""
|
| 59 |
+
Curriculum learning scheduler with phase transitions
|
| 60 |
+
Adjusts learning rate based on training phases
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(self,
|
| 64 |
+
optimizer: torch.optim.Optimizer,
|
| 65 |
+
phase_configs: List[Dict],
|
| 66 |
+
current_epoch: int = 0):
|
| 67 |
+
"""
|
| 68 |
+
Args:
|
| 69 |
+
optimizer: PyTorch optimizer
|
| 70 |
+
phase_configs: List of phase configurations
|
| 71 |
+
[{
|
| 72 |
+
'epochs': (start, end),
|
| 73 |
+
'lr': learning_rate,
|
| 74 |
+
'warmup_epochs': warmup_duration
|
| 75 |
+
}, ...]
|
| 76 |
+
"""
|
| 77 |
+
self.optimizer = optimizer
|
| 78 |
+
self.phase_configs = phase_configs
|
| 79 |
+
self.current_epoch = current_epoch
|
| 80 |
+
self.current_phase = 0
|
| 81 |
+
self.base_lr = optimizer.param_groups[0]['lr']
|
| 82 |
+
|
| 83 |
+
def step(self, epoch: Optional[int] = None):
|
| 84 |
+
"""Update learning rate based on current phase"""
|
| 85 |
+
if epoch is not None:
|
| 86 |
+
self.current_epoch = epoch
|
| 87 |
+
|
| 88 |
+
# Find current phase
|
| 89 |
+
for i, phase in enumerate(self.phase_configs):
|
| 90 |
+
start_epoch, end_epoch = phase['epochs']
|
| 91 |
+
if start_epoch <= self.current_epoch <= end_epoch:
|
| 92 |
+
self.current_phase = i
|
| 93 |
+
break
|
| 94 |
+
|
| 95 |
+
phase = self.phase_configs[self.current_phase]
|
| 96 |
+
target_lr = phase['lr']
|
| 97 |
+
warmup_epochs = phase.get('warmup_epochs', 0)
|
| 98 |
+
start_epoch = phase['epochs'][0]
|
| 99 |
+
|
| 100 |
+
# Apply warmup if in warmup period
|
| 101 |
+
if self.current_epoch - start_epoch < warmup_epochs:
|
| 102 |
+
warmup_progress = (self.current_epoch - start_epoch + 1) / warmup_epochs
|
| 103 |
+
lr = target_lr * warmup_progress
|
| 104 |
+
else:
|
| 105 |
+
lr = target_lr
|
| 106 |
+
|
| 107 |
+
# Update optimizer
|
| 108 |
+
for param_group in self.optimizer.param_groups:
|
| 109 |
+
param_group['lr'] = lr
|
| 110 |
+
|
| 111 |
+
return lr
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class AdaptiveScheduler:
|
| 115 |
+
"""
|
| 116 |
+
Adaptive learning rate based on validation metrics
|
| 117 |
+
Reduces LR when metrics plateau
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(self,
|
| 121 |
+
optimizer: torch.optim.Optimizer,
|
| 122 |
+
mode: str = 'min',
|
| 123 |
+
factor: float = 0.5,
|
| 124 |
+
patience: int = 10,
|
| 125 |
+
threshold: float = 1e-4,
|
| 126 |
+
min_lr: float = 1e-7):
|
| 127 |
+
"""
|
| 128 |
+
Args:
|
| 129 |
+
optimizer: PyTorch optimizer
|
| 130 |
+
mode: 'min' or 'max' - whether to reduce LR when metric stops decreasing or increasing
|
| 131 |
+
factor: Factor to reduce LR by
|
| 132 |
+
patience: Number of epochs with no improvement to wait
|
| 133 |
+
threshold: Minimum change to qualify as improvement
|
| 134 |
+
min_lr: Minimum learning rate
|
| 135 |
+
"""
|
| 136 |
+
self.optimizer = optimizer
|
| 137 |
+
self.mode = mode
|
| 138 |
+
self.factor = factor
|
| 139 |
+
self.patience = patience
|
| 140 |
+
self.threshold = threshold
|
| 141 |
+
self.min_lr = min_lr
|
| 142 |
+
|
| 143 |
+
self.best_score = None
|
| 144 |
+
self.num_bad_epochs = 0
|
| 145 |
+
self.last_reduction = 0
|
| 146 |
+
|
| 147 |
+
def step(self, metric: float, epoch: int = 0):
|
| 148 |
+
"""Update learning rate based on metric"""
|
| 149 |
+
current_lr = self.optimizer.param_groups[0]['lr']
|
| 150 |
+
|
| 151 |
+
if self.best_score is None:
|
| 152 |
+
self.best_score = metric
|
| 153 |
+
else:
|
| 154 |
+
if self.mode == 'min':
|
| 155 |
+
improved = metric < self.best_score - self.threshold
|
| 156 |
+
else:
|
| 157 |
+
improved = metric > self.best_score + self.threshold
|
| 158 |
+
|
| 159 |
+
if improved:
|
| 160 |
+
self.best_score = metric
|
| 161 |
+
self.num_bad_epochs = 0
|
| 162 |
+
else:
|
| 163 |
+
self.num_bad_epochs += 1
|
| 164 |
+
|
| 165 |
+
# Reduce LR if patience exceeded
|
| 166 |
+
if self.num_bad_epochs >= self.patience:
|
| 167 |
+
new_lr = max(current_lr * self.factor, self.min_lr)
|
| 168 |
+
|
| 169 |
+
if new_lr < current_lr:
|
| 170 |
+
print(f"Reducing learning rate from {current_lr:.2e} to {new_lr:.2e}")
|
| 171 |
+
|
| 172 |
+
for param_group in self.optimizer.param_groups:
|
| 173 |
+
param_group['lr'] = new_lr
|
| 174 |
+
|
| 175 |
+
self.num_bad_epochs = 0
|
| 176 |
+
self.last_reduction = epoch
|
| 177 |
+
|
| 178 |
+
return current_lr
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class ProgressiveSplittingScheduler:
|
| 182 |
+
"""
|
| 183 |
+
Adaptive scheduler for progressive splitting
|
| 184 |
+
No fixed targets - adjusts based on quality feedback
|
| 185 |
+
"""
|
| 186 |
+
|
| 187 |
+
def __init__(self,
|
| 188 |
+
optimizer: torch.optim.Optimizer,
|
| 189 |
+
initial_lr: float = 1e-4,
|
| 190 |
+
min_reconstruction: float = 0.85,
|
| 191 |
+
ema: float = 0.98,
|
| 192 |
+
min_lr: float = 1e-7):
|
| 193 |
+
self.optimizer = optimizer
|
| 194 |
+
self.initial_lr = initial_lr
|
| 195 |
+
self.min_reconstruction = min_reconstruction # Quality threshold
|
| 196 |
+
self.ema = ema
|
| 197 |
+
self.min_lr = min_lr
|
| 198 |
+
|
| 199 |
+
# Adaptive multipliers based on performance
|
| 200 |
+
self.quality_multiplier = 1.0 # Adjusts with reconstruction quality
|
| 201 |
+
|
| 202 |
+
# No phases - continuous adaptation
|
| 203 |
+
self.current_state = 'learning'
|
| 204 |
+
|
| 205 |
+
# EMA tracking for smooth transitions
|
| 206 |
+
self._ema_comp = None
|
| 207 |
+
self._ema_recon = None
|
| 208 |
+
|
| 209 |
+
def step(self, metrics: Dict[str, float]):
|
| 210 |
+
"""
|
| 211 |
+
Update learning rate based on current metrics
|
| 212 |
+
GPT fix: EMA smoothing and minimum floor
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
metrics: Dictionary containing:
|
| 216 |
+
- compression_ratio: Current compression ratio
|
| 217 |
+
- reconstruction_acc: Reconstruction accuracy
|
| 218 |
+
"""
|
| 219 |
+
compression_ratio = float(metrics.get('compression_ratio', 0.0))
|
| 220 |
+
reconstruction_acc = float(metrics.get('reconstruction_acc', 0.0))
|
| 221 |
+
|
| 222 |
+
# Update EMA (GPT fix: smooth transitions)
|
| 223 |
+
if self._ema_comp is None:
|
| 224 |
+
self._ema_comp = compression_ratio
|
| 225 |
+
self._ema_recon = reconstruction_acc
|
| 226 |
+
else:
|
| 227 |
+
self._ema_comp = self.ema * self._ema_comp + (1 - self.ema) * compression_ratio
|
| 228 |
+
self._ema_recon = self.ema * self._ema_recon + (1 - self.ema) * reconstruction_acc
|
| 229 |
+
|
| 230 |
+
# Adaptive adjustment based on reconstruction quality only
|
| 231 |
+
# No fixed compression target - emerges from quality
|
| 232 |
+
if self._ema_recon < self.min_reconstruction:
|
| 233 |
+
# Poor reconstruction - reduce LR for careful learning
|
| 234 |
+
self.quality_multiplier = max(0.5, self._ema_recon)
|
| 235 |
+
else:
|
| 236 |
+
# Good reconstruction - normal learning
|
| 237 |
+
self.quality_multiplier = 1.0
|
| 238 |
+
|
| 239 |
+
# Smooth LR changes
|
| 240 |
+
reconstruction_factor = max(0.1, self._ema_recon)
|
| 241 |
+
|
| 242 |
+
# Combined learning rate (adaptive, no phase multiplier)
|
| 243 |
+
lr = self.initial_lr * self.quality_multiplier * reconstruction_factor
|
| 244 |
+
lr = max(lr, self.min_lr) # Ensure minimum LR
|
| 245 |
+
|
| 246 |
+
# Update optimizer
|
| 247 |
+
for param_group in self.optimizer.param_groups:
|
| 248 |
+
param_group['lr'] = lr
|
| 249 |
+
|
| 250 |
+
return lr
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class GumbelTemperatureScheduler:
|
| 254 |
+
"""
|
| 255 |
+
Temperature annealing for Gumbel-Softmax
|
| 256 |
+
GPT-5 suggestion: Critical for progressive splitting
|
| 257 |
+
"""
|
| 258 |
+
|
| 259 |
+
def __init__(self,
|
| 260 |
+
initial_temp: float = 1.0,
|
| 261 |
+
final_temp: float = 0.1,
|
| 262 |
+
anneal_rate: float = 0.99995,
|
| 263 |
+
anneal_steps: Optional[int] = None):
|
| 264 |
+
self.initial_temp = initial_temp
|
| 265 |
+
self.final_temp = final_temp
|
| 266 |
+
self.anneal_rate = anneal_rate
|
| 267 |
+
self.anneal_steps = anneal_steps
|
| 268 |
+
self.current_step = 0
|
| 269 |
+
self.current_temp = initial_temp
|
| 270 |
+
|
| 271 |
+
def step(self):
|
| 272 |
+
"""Update temperature"""
|
| 273 |
+
self.current_step += 1
|
| 274 |
+
|
| 275 |
+
if self.anneal_steps:
|
| 276 |
+
# Linear annealing
|
| 277 |
+
progress = min(1.0, self.current_step / self.anneal_steps)
|
| 278 |
+
self.current_temp = self.initial_temp + (self.final_temp - self.initial_temp) * progress
|
| 279 |
+
else:
|
| 280 |
+
# Exponential annealing
|
| 281 |
+
self.current_temp = max(
|
| 282 |
+
self.final_temp,
|
| 283 |
+
self.initial_temp * (self.anneal_rate ** self.current_step)
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
return self.current_temp
|
| 287 |
+
|
| 288 |
+
def get_temperature(self):
|
| 289 |
+
"""Get current temperature"""
|
| 290 |
+
return self.current_temp
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class CompressionRatioScheduler:
|
| 294 |
+
"""
|
| 295 |
+
Schedule target compression ratio during training
|
| 296 |
+
Gradually increase compression requirements
|
| 297 |
+
"""
|
| 298 |
+
|
| 299 |
+
def __init__(self,
|
| 300 |
+
initial_ratio: float = 8.0,
|
| 301 |
+
target_ratio: float = 24.0,
|
| 302 |
+
warmup_epochs: int = 10,
|
| 303 |
+
total_epochs: int = 100):
|
| 304 |
+
self.initial_ratio = initial_ratio
|
| 305 |
+
self.target_ratio = target_ratio
|
| 306 |
+
self.warmup_epochs = warmup_epochs
|
| 307 |
+
self.total_epochs = total_epochs
|
| 308 |
+
self.current_epoch = 0
|
| 309 |
+
|
| 310 |
+
def step(self, epoch: Optional[int] = None):
|
| 311 |
+
"""Update target compression ratio"""
|
| 312 |
+
if epoch is not None:
|
| 313 |
+
self.current_epoch = epoch
|
| 314 |
+
else:
|
| 315 |
+
self.current_epoch += 1
|
| 316 |
+
|
| 317 |
+
if self.current_epoch < self.warmup_epochs:
|
| 318 |
+
# Start with lower compression requirement
|
| 319 |
+
ratio = self.initial_ratio
|
| 320 |
+
else:
|
| 321 |
+
# Gradually increase to target
|
| 322 |
+
progress = (self.current_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs)
|
| 323 |
+
progress = min(1.0, progress)
|
| 324 |
+
ratio = self.initial_ratio + (self.target_ratio - self.initial_ratio) * progress
|
| 325 |
+
|
| 326 |
+
return ratio
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
class MultiScheduler:
|
| 330 |
+
"""
|
| 331 |
+
Combine multiple schedulers for comprehensive training control
|
| 332 |
+
"""
|
| 333 |
+
|
| 334 |
+
def __init__(self, schedulers: Dict):
|
| 335 |
+
"""
|
| 336 |
+
Args:
|
| 337 |
+
schedulers: Dictionary of schedulers
|
| 338 |
+
{
|
| 339 |
+
'lr': learning_rate_scheduler,
|
| 340 |
+
'gumbel': gumbel_temperature_scheduler,
|
| 341 |
+
'compression': compression_ratio_scheduler,
|
| 342 |
+
...
|
| 343 |
+
}
|
| 344 |
+
"""
|
| 345 |
+
self.schedulers = schedulers
|
| 346 |
+
|
| 347 |
+
def step(self, **kwargs):
|
| 348 |
+
"""
|
| 349 |
+
Step all schedulers
|
| 350 |
+
GPT fix: unified input convention
|
| 351 |
+
|
| 352 |
+
Returns:
|
| 353 |
+
Dictionary with all scheduler outputs
|
| 354 |
+
"""
|
| 355 |
+
results = {}
|
| 356 |
+
|
| 357 |
+
for name, scheduler in self.schedulers.items():
|
| 358 |
+
try:
|
| 359 |
+
# Check scheduler type and pass appropriate arguments
|
| 360 |
+
if hasattr(scheduler, '__class__'):
|
| 361 |
+
class_name = scheduler.__class__.__name__
|
| 362 |
+
|
| 363 |
+
if class_name == 'AdaptiveScheduler' and 'metric' in kwargs:
|
| 364 |
+
results[name] = scheduler.step(kwargs['metric'], kwargs.get('epoch', 0))
|
| 365 |
+
elif class_name == 'PhaseBasedScheduler' and 'epoch' in kwargs:
|
| 366 |
+
results[name] = scheduler.step(kwargs['epoch'])
|
| 367 |
+
elif class_name == 'CompressionRatioScheduler' and 'epoch' in kwargs:
|
| 368 |
+
results[name] = scheduler.step(kwargs['epoch'])
|
| 369 |
+
elif class_name == 'ProgressiveSplittingScheduler' and 'metrics' in kwargs:
|
| 370 |
+
results[name] = scheduler.step(kwargs['metrics'])
|
| 371 |
+
elif hasattr(scheduler, 'step'):
|
| 372 |
+
# Generic step (no arguments)
|
| 373 |
+
results[name] = scheduler.step()
|
| 374 |
+
else:
|
| 375 |
+
if hasattr(scheduler, 'step'):
|
| 376 |
+
results[name] = scheduler.step()
|
| 377 |
+
except Exception as e:
|
| 378 |
+
print(f"Warning: Scheduler '{name}' step failed: {e}")
|
| 379 |
+
results[name] = None
|
| 380 |
+
|
| 381 |
+
return results
|
| 382 |
+
|
| 383 |
+
def get_current_values(self):
|
| 384 |
+
"""Get current values from all schedulers"""
|
| 385 |
+
values = {}
|
| 386 |
+
|
| 387 |
+
for name, scheduler in self.schedulers.items():
|
| 388 |
+
if hasattr(scheduler, 'get_lr'):
|
| 389 |
+
values[name] = scheduler.get_lr()
|
| 390 |
+
elif hasattr(scheduler, 'get_temperature'):
|
| 391 |
+
values[name] = scheduler.get_temperature()
|
| 392 |
+
elif hasattr(scheduler, 'current_temp'):
|
| 393 |
+
values[name] = scheduler.current_temp
|
| 394 |
+
elif hasattr(scheduler, 'current_epoch'):
|
| 395 |
+
values[name] = scheduler.current_epoch
|
| 396 |
+
|
| 397 |
+
return values
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
class GateWarmupScheduler:
|
| 401 |
+
"""게이트 파라미터 웜업 스케줄러
|
| 402 |
+
|
| 403 |
+
초기: 모든 레이어 동등 사용 (gate=1.0)
|
| 404 |
+
웜업: 점진적 게이트 학습 시작
|
| 405 |
+
후기: 최적 게이트 값으로 수렴
|
| 406 |
+
"""
|
| 407 |
+
|
| 408 |
+
def __init__(
|
| 409 |
+
self,
|
| 410 |
+
optimizer: torch.optim.Optimizer,
|
| 411 |
+
warmup_steps: int = 1000,
|
| 412 |
+
gate_param_group_name: str = 'gates',
|
| 413 |
+
importance_param_group_name: str = 'importance'
|
| 414 |
+
):
|
| 415 |
+
"""
|
| 416 |
+
Args:
|
| 417 |
+
optimizer: 옵티마이저
|
| 418 |
+
warmup_steps: 웜업 스텝 수
|
| 419 |
+
gate_param_group_name: 게이트 파라미터 그룹 이름
|
| 420 |
+
importance_param_group_name: 중요도 파라미터 그룹 이름
|
| 421 |
+
"""
|
| 422 |
+
self.optimizer = optimizer
|
| 423 |
+
self.warmup_steps = warmup_steps
|
| 424 |
+
self.gate_group_name = gate_param_group_name
|
| 425 |
+
self.importance_group_name = importance_param_group_name
|
| 426 |
+
|
| 427 |
+
# 초기 학습률 저장
|
| 428 |
+
self.base_lrs = {}
|
| 429 |
+
for group in optimizer.param_groups:
|
| 430 |
+
if 'name' in group:
|
| 431 |
+
self.base_lrs[group['name']] = group['lr']
|
| 432 |
+
|
| 433 |
+
def get_gate_factor(self, step: int) -> float:
|
| 434 |
+
"""게이트 학습률 계수 계산
|
| 435 |
+
|
| 436 |
+
웜업 기간 동안은 낮은 학습률,
|
| 437 |
+
이후 정상 학습률로 전환
|
| 438 |
+
"""
|
| 439 |
+
if step < self.warmup_steps:
|
| 440 |
+
# 웜업 기간: 선형 증가
|
| 441 |
+
return step / self.warmup_steps
|
| 442 |
+
else:
|
| 443 |
+
# 정상 학습
|
| 444 |
+
return 1.0
|
| 445 |
+
|
| 446 |
+
def get_importance_factor(self, step: int) -> float:
|
| 447 |
+
"""중요도 학습률 계수 계산
|
| 448 |
+
|
| 449 |
+
게이트보다 느리게 학습 시작
|
| 450 |
+
"""
|
| 451 |
+
delayed_warmup = self.warmup_steps * 1.5
|
| 452 |
+
if step < delayed_warmup:
|
| 453 |
+
return step / delayed_warmup * 0.5
|
| 454 |
+
else:
|
| 455 |
+
return 1.0
|
| 456 |
+
|
| 457 |
+
def step(self, current_step: int):
|
| 458 |
+
"""스케줄러 스텝
|
| 459 |
+
|
| 460 |
+
Args:
|
| 461 |
+
current_step: 현재 글로벌 스텝
|
| 462 |
+
"""
|
| 463 |
+
# 게이트 파라미터 그룹 학습률 조정
|
| 464 |
+
gate_factor = self.get_gate_factor(current_step)
|
| 465 |
+
importance_factor = self.get_importance_factor(current_step)
|
| 466 |
+
|
| 467 |
+
for group in self.optimizer.param_groups:
|
| 468 |
+
if 'name' not in group:
|
| 469 |
+
continue
|
| 470 |
+
|
| 471 |
+
if group['name'] == self.gate_group_name:
|
| 472 |
+
# 게이트 학습률 조정
|
| 473 |
+
group['lr'] = self.base_lrs[self.gate_group_name] * gate_factor
|
| 474 |
+
|
| 475 |
+
elif group['name'] == self.importance_group_name:
|
| 476 |
+
# 중요도 학습률 조정
|
| 477 |
+
group['lr'] = self.base_lrs[self.importance_group_name] * importance_factor
|
| 478 |
+
|
| 479 |
+
def get_lr(self) -> Dict[str, float]:
|
| 480 |
+
"""현재 학습률 반환"""
|
| 481 |
+
lrs = {}
|
| 482 |
+
for group in self.optimizer.param_groups:
|
| 483 |
+
if 'name' in group:
|
| 484 |
+
lrs[group['name']] = group['lr']
|
| 485 |
+
return lrs
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
class UniversalCosineScheduler:
|
| 489 |
+
"""Universal Cosine Annealing 스케줄러
|
| 490 |
+
|
| 491 |
+
모든 언어에 대해 동일한 스케줄 적용
|
| 492 |
+
"""
|
| 493 |
+
|
| 494 |
+
def __init__(
|
| 495 |
+
self,
|
| 496 |
+
optimizer: torch.optim.Optimizer,
|
| 497 |
+
warmup_steps: int = 1000,
|
| 498 |
+
total_steps: int = 10000,
|
| 499 |
+
min_lr_ratio: float = 0.1
|
| 500 |
+
):
|
| 501 |
+
self.optimizer = optimizer
|
| 502 |
+
self.warmup_steps = warmup_steps
|
| 503 |
+
self.total_steps = total_steps
|
| 504 |
+
self.min_lr_ratio = min_lr_ratio
|
| 505 |
+
self.current_step = 0
|
| 506 |
+
|
| 507 |
+
# 초기 학습률 저장
|
| 508 |
+
self.base_lrs = [group['lr'] for group in optimizer.param_groups]
|
| 509 |
+
|
| 510 |
+
def step(self):
|
| 511 |
+
"""스케줄러 스텝"""
|
| 512 |
+
self.current_step += 1
|
| 513 |
+
|
| 514 |
+
for idx, param_group in enumerate(self.optimizer.param_groups):
|
| 515 |
+
if self.current_step < self.warmup_steps:
|
| 516 |
+
# Warmup 단계
|
| 517 |
+
lr = self.base_lrs[idx] * (self.current_step / self.warmup_steps)
|
| 518 |
+
else:
|
| 519 |
+
# Cosine annealing
|
| 520 |
+
if self.total_steps <= self.warmup_steps:
|
| 521 |
+
# warmup_steps가 total_steps보다 크거나 같은 경우
|
| 522 |
+
lr = self.base_lrs[idx] * self.min_lr_ratio
|
| 523 |
+
else:
|
| 524 |
+
progress = min(1.0, (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps))
|
| 525 |
+
lr = self.base_lrs[idx] * (
|
| 526 |
+
self.min_lr_ratio + (1 - self.min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
param_group['lr'] = lr
|
| 530 |
+
|
| 531 |
+
def get_last_lr(self) -> List[float]:
|
| 532 |
+
"""마지막 학습률 반환"""
|
| 533 |
+
return [group['lr'] for group in self.optimizer.param_groups]
|
| 534 |
+
|
| 535 |
+
def state_dict(self) -> Dict[str, Any]:
|
| 536 |
+
"""스케줄러 상태 딕셔너리 반환 (체크포인트 저장용)"""
|
| 537 |
+
return {
|
| 538 |
+
'current_step': self.current_step,
|
| 539 |
+
'warmup_steps': self.warmup_steps,
|
| 540 |
+
'total_steps': self.total_steps,
|
| 541 |
+
'min_lr_ratio': self.min_lr_ratio,
|
| 542 |
+
'base_lrs': self.base_lrs
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
def load_state_dict(self, state_dict: Dict[str, Any]):
|
| 546 |
+
"""스케줄러 상태 로드 (체크포인트 재시작용)"""
|
| 547 |
+
self.current_step = state_dict['current_step']
|
| 548 |
+
self.warmup_steps = state_dict['warmup_steps']
|
| 549 |
+
self.total_steps = state_dict['total_steps']
|
| 550 |
+
self.min_lr_ratio = state_dict['min_lr_ratio']
|
| 551 |
+
self.base_lrs = state_dict['base_lrs']
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
class AdaptiveLayerScheduler:
|
| 555 |
+
"""레이어별 적응적 스케줄러
|
| 556 |
+
|
| 557 |
+
각 레이어의 학습 진행도에 따라 동적으로 조정
|
| 558 |
+
"""
|
| 559 |
+
|
| 560 |
+
def __init__(
|
| 561 |
+
self,
|
| 562 |
+
layer_builder,
|
| 563 |
+
threshold_active: float = 0.7,
|
| 564 |
+
threshold_skip: float = 0.3
|
| 565 |
+
):
|
| 566 |
+
"""
|
| 567 |
+
Args:
|
| 568 |
+
layer_builder: LayerBuilder 인스턴스
|
| 569 |
+
threshold_active: 활성 레이어 임계값
|
| 570 |
+
threshold_skip: 스킵 레이어 임계값
|
| 571 |
+
"""
|
| 572 |
+
self.layer_builder = layer_builder
|
| 573 |
+
self.threshold_active = threshold_active
|
| 574 |
+
self.threshold_skip = threshold_skip
|
| 575 |
+
|
| 576 |
+
# 레이어별 통계
|
| 577 |
+
self.layer_stats = {
|
| 578 |
+
'usage_count': torch.zeros(5),
|
| 579 |
+
'contribution': torch.zeros(5)
|
| 580 |
+
}
|
| 581 |
+
|
| 582 |
+
def update_stats(self, batch_output):
|
| 583 |
+
"""배치 출력으로 통계 업데이트"""
|
| 584 |
+
with torch.no_grad():
|
| 585 |
+
gates = torch.sigmoid(self.layer_builder.layer_gates)
|
| 586 |
+
|
| 587 |
+
# 사용 횟수 업데이트
|
| 588 |
+
self.layer_stats['usage_count'] += (gates > self.threshold_skip).float()
|
| 589 |
+
|
| 590 |
+
# 기여도 추정 (간단한 버전)
|
| 591 |
+
importance = torch.nn.functional.softmax(
|
| 592 |
+
self.layer_builder.layer_importance, dim=0
|
| 593 |
+
)
|
| 594 |
+
self.layer_stats['contribution'] += importance.detach()
|
| 595 |
+
|
| 596 |
+
def get_layer_status(self) -> Dict[int, str]:
|
| 597 |
+
"""각 레이어의 상태 반환"""
|
| 598 |
+
gates = torch.sigmoid(self.layer_builder.layer_gates)
|
| 599 |
+
status = {}
|
| 600 |
+
|
| 601 |
+
for i in range(5):
|
| 602 |
+
if gates[i] > self.threshold_active:
|
| 603 |
+
status[i] = "ACTIVE"
|
| 604 |
+
elif gates[i] > self.threshold_skip:
|
| 605 |
+
status[i] = "PARTIAL"
|
| 606 |
+
else:
|
| 607 |
+
status[i] = "SKIP"
|
| 608 |
+
|
| 609 |
+
return status
|
| 610 |
+
|
| 611 |
+
def suggest_pruning(self) -> List[int]:
|
| 612 |
+
"""프루닝 가능한 레이어 제안"""
|
| 613 |
+
gates = torch.sigmoid(self.layer_builder.layer_gates)
|
| 614 |
+
prunable = []
|
| 615 |
+
|
| 616 |
+
for i in range(5):
|
| 617 |
+
if gates[i] < self.threshold_skip:
|
| 618 |
+
# 낮은 게이트 값 + 낮은 기여도
|
| 619 |
+
if self.layer_stats['contribution'][i] < 0.1:
|
| 620 |
+
prunable.append(i)
|
| 621 |
+
|
| 622 |
+
return prunable
|
| 623 |
+
|
| 624 |
+
|
| 625 |
+
if __name__ == "__main__":
|
| 626 |
+
# Test schedulers
|
| 627 |
+
print("Testing Schedulers")
|
| 628 |
+
|
| 629 |
+
# Create dummy optimizer
|
| 630 |
+
model = torch.nn.Linear(10, 10)
|
| 631 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
| 632 |
+
|
| 633 |
+
# Test WarmupCosineScheduler
|
| 634 |
+
print("\n1. WarmupCosineScheduler:")
|
| 635 |
+
scheduler = WarmupCosineScheduler(optimizer, warmup_steps=100, total_steps=1000)
|
| 636 |
+
lrs = []
|
| 637 |
+
for step in range(200):
|
| 638 |
+
lr = scheduler.step()
|
| 639 |
+
if step % 20 == 0:
|
| 640 |
+
print(f" Step {step}: LR = {lr:.6f}")
|
| 641 |
+
lrs.append(lr)
|
| 642 |
+
|
| 643 |
+
# Test PhaseBasedScheduler
|
| 644 |
+
print("\n2. PhaseBasedScheduler:")
|
| 645 |
+
phase_configs = [
|
| 646 |
+
{'epochs': (0, 30), 'lr': 1e-4, 'warmup_epochs': 5},
|
| 647 |
+
{'epochs': (31, 60), 'lr': 5e-5, 'warmup_epochs': 2},
|
| 648 |
+
{'epochs': (61, 100), 'lr': 1e-5, 'warmup_epochs': 0}
|
| 649 |
+
]
|
| 650 |
+
scheduler = PhaseBasedScheduler(optimizer, phase_configs)
|
| 651 |
+
for epoch in [0, 5, 31, 35, 61, 80]:
|
| 652 |
+
lr = scheduler.step(epoch)
|
| 653 |
+
print(f" Epoch {epoch}: LR = {lr:.6f}")
|
| 654 |
+
|
| 655 |
+
# Test GumbelTemperatureScheduler
|
| 656 |
+
print("\n3. GumbelTemperatureScheduler:")
|
| 657 |
+
scheduler = GumbelTemperatureScheduler()
|
| 658 |
+
for step in [0, 100, 500, 1000, 5000]:
|
| 659 |
+
for _ in range(step - scheduler.current_step):
|
| 660 |
+
scheduler.step()
|
| 661 |
+
temp = scheduler.get_temperature()
|
| 662 |
+
print(f" Step {step}: Temperature = {temp:.4f}")
|
| 663 |
+
|
| 664 |
+
# Test CompressionRatioScheduler
|
| 665 |
+
print("\n4. CompressionRatioScheduler:")
|
| 666 |
+
scheduler = CompressionRatioScheduler()
|
| 667 |
+
for epoch in [0, 5, 10, 30, 50, 80, 100]:
|
| 668 |
+
ratio = scheduler.step(epoch)
|
| 669 |
+
print(f" Epoch {epoch}: Target ratio = {ratio:.1f}:1")
|
core/tokenizer.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Intelligent Tokenizer v6.2.0 - Byte Tokenizer with 46+2 Configuration
|
| 3 |
+
Handles chunking, sliding windows, and boundary adjustments
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def _trim_utf8_boundary(byte_seq: List[int], limit: int) -> int:
|
| 14 |
+
"""
|
| 15 |
+
Trim byte sequence to valid UTF-8 boundary (GPT suggestion)
|
| 16 |
+
"""
|
| 17 |
+
end = min(limit, len(byte_seq))
|
| 18 |
+
while end > 0:
|
| 19 |
+
try:
|
| 20 |
+
bytes(byte_seq[:end]).decode('utf-8')
|
| 21 |
+
return end
|
| 22 |
+
except UnicodeDecodeError:
|
| 23 |
+
end -= 1
|
| 24 |
+
return limit
|
| 25 |
+
|
| 26 |
+
|
class ByteTokenizerV62:
    """
    Pure byte-level tokenizer
    46 content bytes + 2 special tokens (BOS/EOS) = 48 total
    """

    def __init__(self, config: Optional[Dict] = None):
        # Configuration
        self.content_size = 46   # Actual content bytes
        self.max_seq_len = 48    # Total with BOS/EOS
        self.chunk_overlap = 8   # Overlap for sliding window

        # Special tokens
        self.PAD = 256
        self.BOS = 257
        self.EOS = 258
        self.MASK = 259
        self.vocab_size = 260    # 256 bytes + 4 special

    def encode(self,
               text: str,
               add_special_tokens: bool = True,
               return_chunks: bool = False) -> Dict[str, torch.Tensor]:
        """
        Encode text to byte sequences

        Args:
            text: Input text
            add_special_tokens: Whether to add BOS/EOS
            return_chunks: Return multiple chunks for long sequences
        """
        # Convert to UTF-8 bytes
        byte_sequence = list(text.encode('utf-8'))

        if return_chunks and len(byte_sequence) > self.content_size:
            # Handle long sequences with sliding window
            return self._encode_with_chunks(byte_sequence, add_special_tokens)

        # Single chunk processing with UTF-8 boundary (GPT suggestion)
        if len(byte_sequence) > self.content_size:
            cut_point = _trim_utf8_boundary(byte_sequence, self.content_size)
            byte_sequence = byte_sequence[:cut_point]

        # Add special tokens (GPT suggestion: cleaner padding order)
        if add_special_tokens:
            byte_sequence = [self.BOS] + byte_sequence + [self.EOS]

        # Pad to max_seq_len (after special tokens for cleaner structure)
        if len(byte_sequence) < self.max_seq_len:
            padding_length = self.max_seq_len - len(byte_sequence)
            byte_sequence = byte_sequence + [self.PAD] * padding_length

        input_ids = torch.tensor(byte_sequence, dtype=torch.long)
        attention_mask = (input_ids != self.PAD)  # bool type (GPT suggestion)

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'length': len(byte_sequence),
            'original_length': len(text.encode('utf-8'))
        }

    def _encode_with_chunks(self,
                            byte_sequence: List[int],
                            add_special_tokens: bool) -> Dict[str, torch.Tensor]:
        """
        Encode long sequences with sliding window chunks
        """
        chunks = []
        positions = []

        # Calculate stride (content_size - overlap)
        stride = self.content_size - self.chunk_overlap

        for i in range(0, len(byte_sequence), stride):
            # Extract chunk
            chunk = byte_sequence[i:i + self.content_size]

            # Skip if chunk is too small (last chunk)
            if len(chunk) < self.content_size // 2:
                if chunks:  # Merge with previous chunk if exists
                    last_chunk = chunks[-1]['input_ids'].tolist()
                    # Remove padding and special tokens from last chunk (GPT final check)
                    last_chunk = [b for b in last_chunk if b not in [self.PAD, self.BOS, self.EOS]]
                    # Add current chunk
                    merged = last_chunk + chunk + [self.EOS]
                    # Repad
                    if len(merged) < self.max_seq_len:
                        merged += [self.PAD] * (self.max_seq_len - len(merged))
                    merged_ids = torch.tensor(merged[:self.max_seq_len], dtype=torch.long)
                    merged_mask = (merged_ids != self.PAD)  # Recalculate mask (GPT suggestion)
                    chunks[-1]['input_ids'] = merged_ids
                    chunks[-1]['attention_mask'] = merged_mask
                break

            # Pad chunk if necessary
            if len(chunk) < self.content_size:
                chunk += [self.PAD] * (self.content_size - len(chunk))

            # Add special tokens
            if add_special_tokens:
                chunk_with_special = [self.BOS] + chunk + [self.EOS]
            else:
                chunk_with_special = chunk

            # Create tensors
            input_ids = torch.tensor(chunk_with_special, dtype=torch.long)
            attention_mask = (input_ids != self.PAD)  # bool type (GPT suggestion)

            chunks.append({
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'position': (i, min(i + self.content_size, len(byte_sequence)))
            })
            positions.append((i, min(i + self.content_size, len(byte_sequence))))

        # Stack all chunks
        all_input_ids = torch.stack([c['input_ids'] for c in chunks])
        all_attention_masks = torch.stack([c['attention_mask'] for c in chunks])

        return {
            'input_ids': all_input_ids,  # [num_chunks, seq_len]
            'attention_mask': all_attention_masks,
            'num_chunks': len(chunks),
            'chunk_positions': positions,
            'original_length': len(byte_sequence)
        }

    def reconstruct(self,
                    input_ids: torch.Tensor,
                    positions: List[Tuple[int, int]] = None,
                    skip_special_tokens: bool = True,
                    overlap: int = 8) -> str:
        """
        Reconstruct text from multiple chunks (GPT suggestion)

        Args:
            input_ids: [num_chunks, seq_len] for multi-chunk
            positions: List of (start, end) positions for each chunk
            skip_special_tokens: Whether to skip special tokens
            overlap: Overlap size between chunks
        """
        if input_ids.dim() == 1:
            # Single sequence, use regular decode
            return self.decode(input_ids, skip_special_tokens)

        # Multi-chunk reconstruction
        pieces = []
        for i, chunk_ids in enumerate(input_ids):
            chunk_ids = chunk_ids.cpu().numpy().tolist()

            # Remove special tokens and padding
            if skip_special_tokens:
                chunk_ids = [
                    b for b in chunk_ids
                    if b not in [self.PAD, self.BOS, self.EOS, self.MASK] and b < 256
                ]

            pieces.append(chunk_ids)

        # Merge chunks with overlap handling
        output = []
        for i, chunk in enumerate(pieces):
            if i == 0:
                output.extend(chunk)
            else:
                # Skip overlap bytes from current chunk
                output.extend(chunk[overlap:] if len(chunk) > overlap else chunk)

        # Convert to string
        try:
            text = bytes(output).decode('utf-8', errors='replace')
        except:
            text = ""

        return text

    def decode(self,
               input_ids: torch.Tensor,
               skip_special_tokens: bool = True) -> str:
        """
        Decode byte sequences back to text
        """
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.cpu().numpy().tolist()

        # Handle batch dimension
        if isinstance(input_ids[0], list):
            input_ids = input_ids[0]

        # Remove special tokens and padding
        if skip_special_tokens:
            input_ids = [
                b for b in input_ids
                if b not in [self.PAD, self.BOS, self.EOS, self.MASK] and b < 256
            ]

        # Convert bytes to string
        try:
            text = bytes(input_ids).decode('utf-8', errors='replace')
        except:
            text = ""

        return text

    def batch_encode(self,
                     texts: List[str],
                     add_special_tokens: bool = True) -> Dict[str, torch.Tensor]:
        """
        Encode multiple texts as a batch
        """
        encoded = [self.encode(text, add_special_tokens) for text in texts]

        # Find max length
        max_len = max(e['length'] for e in encoded)
        max_len = min(max_len, self.max_seq_len)

        # Create batch tensors
        batch_size = len(texts)
        input_ids = torch.full((batch_size, max_len), self.PAD, dtype=torch.long)
        attention_mask = torch.zeros((batch_size, max_len), dtype=torch.bool)  # bool type (GPT suggestion)

        for i, enc in enumerate(encoded):
            seq_len = min(enc['length'], max_len)
            if enc['input_ids'].dim() == 0:  # Handle scalar
                enc['input_ids'] = enc['input_ids'].unsqueeze(0)
            input_ids[i, :seq_len] = enc['input_ids'][:seq_len]
            attention_mask[i, :seq_len] = True

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'lengths': [e['length'] for e in encoded]
        }

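# Worked example (illustrative): "Hello, world!" is 13 UTF-8 bytes, so encode()
# returns [BOS] + 13 byte values + [EOS] followed by 33 PAD tokens (48 positions in
# total), with attention_mask True only for the first 15 positions.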
class ChunkBoundaryAdjuster(nn.Module):
    """
    Neural network for adjusting chunk boundaries
    Learns optimal splitting points
    """

    def __init__(self, hidden_dim: int = 256):
        super().__init__()

        # Boundary scoring network
        self.boundary_scorer = nn.Sequential(
            nn.Linear(256, hidden_dim),  # Input: byte embeddings
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),  # Output: boundary score
            nn.Sigmoid()
        )

        # UTF-8 boundary detector
        self.utf8_detector = nn.Sequential(
            nn.Conv1d(1, 16, kernel_size=4, padding=2),  # Detect multi-byte patterns
            nn.ReLU(),
            nn.Conv1d(16, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, byte_sequence: torch.Tensor) -> torch.Tensor:
        """
        Find optimal chunk boundaries

        Args:
            byte_sequence: [batch, seq_len, embedding_dim]

        Returns:
            boundary_scores: [batch, seq_len] - probability of boundary at each position
        """
        batch_size, seq_len = byte_sequence.shape[:2]

        # Score each position as potential boundary
        boundary_scores = self.boundary_scorer(byte_sequence).squeeze(-1)

        # Detect UTF-8 boundaries (avoid splitting multi-byte characters)
        byte_values = byte_sequence[..., 0].unsqueeze(1)  # [batch, 1, seq_len]
        utf8_scores = self.utf8_detector(byte_values).squeeze(1)  # [batch, seq_len]

        # Combine scores (prefer boundaries at valid UTF-8 positions)
        combined_scores = boundary_scores * utf8_scores

        # Apply constraints: boundaries should be ~46 bytes apart
        for i in range(0, seq_len, 46):
            if i < seq_len:
                # Boost score at expected positions
                combined_scores[:, i] = combined_scores[:, i] * 1.5

        return combined_scores


class SlidingWindowProcessor(nn.Module):
    """
    Process sequences with sliding windows at multiple scales
    """

    def __init__(self, window_sizes: List[int] = [8, 16, 32, 46]):
        super().__init__()
        self.window_sizes = window_sizes

        # Multi-scale convolutions for different window sizes
        self.convs = nn.ModuleList([
            nn.Conv1d(256, 128, kernel_size=ws, stride=ws//2, padding=ws//4)
            for ws in window_sizes
        ])

        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(128 * len(window_sizes), 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256)
        )

    def forward(self, byte_embeddings: torch.Tensor) -> torch.Tensor:
        """
        Apply multi-scale sliding windows

        Args:
            byte_embeddings: [batch, seq_len, embedding_dim]

        Returns:
            processed: [batch, seq_len, embedding_dim]
        """
        # Transpose for conv1d
        x = byte_embeddings.transpose(1, 2)  # [batch, embed, seq]

        # Apply multi-scale convolutions
        multi_scale_features = []
        for conv in self.convs:
            features = conv(x)  # Different seq lengths
            # Global average pooling to fixed size
            pooled = F.adaptive_avg_pool1d(features, byte_embeddings.size(1))
            multi_scale_features.append(pooled)

        # Concatenate and transpose back
        concat = torch.cat(multi_scale_features, dim=1)  # [batch, 128*scales, seq]
        concat = concat.transpose(1, 2)  # [batch, seq, 128*scales]

        # Fuse multi-scale features
        fused = self.fusion(concat)  # [batch, seq, 256]

        # Residual connection
        output = fused + byte_embeddings

        return output


class AdaptiveChunker:
    """
    Adaptive chunking based on content complexity
    Simple heuristic-based chunker for inference
    """

    def __init__(self):
        self.min_chunk = 32
        self.max_chunk = 46
        self.target_chunk = 46

    def determine_chunk_size(self, text: str) -> int:
        """
        Determine optimal chunk size based on text characteristics
        """
        byte_seq = text.encode('utf-8')

        # Check character types
        has_cjk = any(b >= 0x80 for b in byte_seq[:100])  # Non-ASCII
        has_arabic = any(0x0600 <= ord(c) <= 0x06FF for c in text[:100])

        # Adjust chunk size based on content
        if has_cjk:
            # CJK characters need smaller chunks (multi-byte)
            return self.min_chunk
        elif has_arabic:
            # Arabic also benefits from smaller chunks
            return 40
        else:
            # ASCII/Latin can use larger chunks
            return self.target_chunk

    def chunk_text(self, text: str) -> List[str]:
        """
        Split text into adaptive chunks
        """
        chunk_size = self.determine_chunk_size(text)
        byte_seq = text.encode('utf-8')
        chunks = []

        i = 0
        while i < len(byte_seq):
            # Find chunk boundary (don't split UTF-8 sequences)
            end = min(i + chunk_size, len(byte_seq))

            # Backtrack to valid UTF-8 boundary if needed
            while end > i and end < len(byte_seq):
                try:
                    _ = byte_seq[i:end].decode('utf-8')
                    break
                except:
                    end -= 1

            chunk_bytes = byte_seq[i:end]
            chunks.append(chunk_bytes.decode('utf-8', errors='replace'))
            i = end

        return chunks

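# Worked example (illustrative): determine_chunk_size() returns 32 as soon as any of
# the first 100 bytes is non-ASCII (Korean, CJK, etc.), 40 when the first 100 bytes
# are ASCII but an Arabic code point appears among the first 100 characters, and the
# full 46-byte target chunk for plain ASCII/Latin text.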
if __name__ == "__main__":
    # Test the tokenizer
    tokenizer = ByteTokenizerV62()

    # Test texts
    test_texts = [
        "Hello, world!",
        "안녕하세요, 세계!",
        "今天天气很好。",
        "مرحبا بالعالم",
        "A" * 100  # Long text
    ]

    for text in test_texts:
        print(f"\nText: {text[:50]}...")

        # Single chunk encoding
        encoded = tokenizer.encode(text)
        print(f"  Encoded shape: {encoded['input_ids'].shape}")
        print(f"  Original length: {encoded['original_length']} bytes")

        # Decode back
        decoded = tokenizer.decode(encoded['input_ids'])
        print(f"  Decoded: {decoded[:50]}...")

        # Check multi-chunk for long text
        if encoded['original_length'] > 46:
            multi_encoded = tokenizer.encode(text, return_chunks=True)
            print(f"  Chunks: {multi_encoded['num_chunks']}")

    # Test batch encoding
    batch = tokenizer.batch_encode(test_texts[:3])
    print(f"\nBatch shape: {batch['input_ids'].shape}")

    # Test adaptive chunker
    chunker = AdaptiveChunker()
    for text in test_texts[:3]:
        chunk_size = chunker.determine_chunk_size(text)
        print(f"\n{text[:30]}... → Chunk size: {chunk_size}")
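A short sketch of the sliding-window arithmetic above (illustrative; any 100-byte ASCII string behaves the same): with content_size=46 and chunk_overlap=8 the stride is 38, so the input is cut into three overlapping chunks and reassembled without duplication.

from core.tokenizer import ByteTokenizerV62

tok = ByteTokenizerV62()
text = "A" * 100                                  # 100 ASCII bytes
chunked = tok.encode(text, return_chunks=True)
print(chunked['num_chunks'])                      # 3
print(chunked['chunk_positions'])                 # [(0, 46), (38, 84), (76, 100)]
restored = tok.reconstruct(chunked['input_ids'], overlap=tok.chunk_overlap)
print(restored == text)                           # True: 46 + 38 + 16 bytes recovered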
core/unified_model.py
CHANGED
@@ -1,755 +1,541 @@

Removed: the previous v6.1 implementation. Its module docstring described 64-byte chunks for aggressive compression, 50-epoch checkpoints with automatic splitting, group relation learning for reconstruction, and boundary adjustment for semantic units. The deleted code included the v6.1 byte tokenizer (PAD=256, BOS=257, EOS=258, MASK=259 with encode / encode_batch / decode and a 64-byte max sequence length), the five-layer ByteEncoderV61 (768→896→1024→1152→1280d with progressive heads [12, 14, 16, 18, 20], character, eojeol and phrase boundary predictors, curriculum learning and group position encodings, targeting roughly 3:1 eojeol-to-phrase compression), a gated CrossAttention module with a relation head and a reconstruction-focused attention branch, an 8-layer TransformerDecoder with positional encoding, causal masking and conservative generation (temperature 0.1, top-k 10), and the v6.1 unified model whose forward pass combined reconstruction, boundary and relation losses.
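A minimal usage sketch for the rewritten v6.2.0 module listed below (illustrative only; it assumes EncoderV62 and DecoderV62 accept the same optional config dict, and it relies on the 0.1 loss weights set in __init__):

import torch
from core.unified_model import IntelligentTokenizerV62

model = IntelligentTokenizerV62()                  # default 46+2 byte configuration
batch = model.tokenizer.batch_encode(["Hello, world!", "안녕하세요, 세계!"])

# Training step: byte-level reconstruction, so labels are the inputs themselves
out = model(input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['input_ids'])
out['loss'].backward()   # weighted sum of reconstruction/compression/boundary terms

# Inference: compress with the encoder, then reconstruct autoregressively
with torch.no_grad():
    restored = model.generate(text="Hello, world!", max_length=48)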
| 1 |
"""
|
| 2 |
+
Intelligent Tokenizer v6.2.0 - Unified Model
|
| 3 |
+
Integrates encoder, decoder, and tokenizer with all GPT improvements
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
|
|
|
| 9 |
from typing import Dict, List, Optional, Tuple, Union
|
| 10 |
+
import math
|
| 11 |
|
| 12 |
+
# Import our components
|
| 13 |
+
try:
|
| 14 |
+
from .encoder import EncoderV62
|
| 15 |
+
from .decoder import DecoderV62
|
| 16 |
+
from .tokenizer import ByteTokenizerV62
|
| 17 |
+
except ImportError:
|
| 18 |
+
# For standalone testing
|
| 19 |
+
from encoder import EncoderV62
|
| 20 |
+
from decoder import DecoderV62
|
| 21 |
+
from tokenizer import ByteTokenizerV62
|
| 22 |
|
| 23 |
+
|
| 24 |
+
class IntelligentTokenizerV62(nn.Module):
|
| 25 |
"""
|
| 26 |
+
Complete v6.2.0 model with progressive splitting and optimizations
|
| 27 |
+
|
| 28 |
+
Key features:
|
| 29 |
+
- 48-byte chunks (46+2 with BOS/EOS)
|
| 30 |
+
- Progressive splitting: 48→1→N→M tokens
|
| 31 |
+
- Multi-level cross-attention
|
| 32 |
+
- KV cache optimization (8x reduction)
|
| 33 |
+
- All GPT-5 improvements integrated
|
| 34 |
"""
|
| 35 |
+
|
| 36 |
+
def __init__(self, config: Optional[Dict] = None):
|
| 37 |
super().__init__()
|
| 38 |
+
|
| 39 |
+
# Default configuration
|
| 40 |
+
self.config = config or {}
|
| 41 |
+
|
| 42 |
+
# Model components
|
| 43 |
+
self.tokenizer = ByteTokenizerV62(config)
|
| 44 |
+
self.encoder = EncoderV62(config)
|
| 45 |
+
self.decoder = DecoderV62(config)
|
| 46 |
+
|
| 47 |
+
# Training configuration
|
| 48 |
+
self.compression_weight = 0.1
|
| 49 |
+
self.reconstruction_weight = 0.1
|
| 50 |
+
self.boundary_weight = 0.1
|
| 51 |
+
|
| 52 |
+
# Monitoring
|
| 53 |
+
self.register_buffer('training_step', torch.tensor(0))
|
| 54 |
+
self.register_buffer('current_epoch', torch.tensor(0))
|
| 55 |
+
|
| 56 |
+
def forward(self,
|
| 57 |
+
input_ids: torch.Tensor = None,
|
| 58 |
+
attention_mask: torch.Tensor = None,
|
| 59 |
+
labels: torch.Tensor = None,
|
| 60 |
+
text: str = None,
|
| 61 |
+
return_loss: bool = True,
|
| 62 |
+
temperature: float = 1.0) -> Dict[str, torch.Tensor]:
|
| 63 |
"""
|
| 64 |
+
Unified forward pass
|
| 65 |
+
|
| 66 |
Args:
|
| 67 |
+
input_ids: Pre-tokenized input (optional)
|
| 68 |
+
attention_mask: Attention mask (optional)
|
| 69 |
+
labels: Target labels for training (optional)
|
| 70 |
+
text: Raw text input (alternative to input_ids)
|
| 71 |
+
return_loss: Whether to compute loss
|
| 72 |
+
temperature: Temperature for Gumbel-Softmax in encoder
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Dictionary with model outputs
|
| 76 |
"""
|
| 77 |
+
# Handle text input
|
| 78 |
+
if text is not None:
|
| 79 |
+
encoded = self.tokenizer.encode(text, add_special_tokens=True)
|
| 80 |
+
input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
|
| 81 |
+
attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
|
| 82 |
+
|
| 83 |
+
# Handle string passed as input_ids (common mistake)
|
| 84 |
+
if isinstance(input_ids, str):
|
| 85 |
+
text = input_ids
|
| 86 |
+
encoded = self.tokenizer.encode(text, add_special_tokens=True)
|
| 87 |
+
input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
|
| 88 |
+
attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
|
| 89 |
+
|
| 90 |
+
# Ensure tensors are on the right device
|
| 91 |
+
device = next(self.parameters()).device
|
| 92 |
+
if input_ids is not None and torch.is_tensor(input_ids):
|
| 93 |
+
input_ids = input_ids.to(device)
|
| 94 |
+
if attention_mask is not None and torch.is_tensor(attention_mask):
|
| 95 |
+
attention_mask = attention_mask.to(device)
|
| 96 |
+
if labels is not None and torch.is_tensor(labels):
|
| 97 |
+
labels = labels.to(device)
|
| 98 |
+
|
| 99 |
+
# Encoder forward pass with temperature for Gumbel annealing
|
| 100 |
+
encoder_outputs = self.encoder(
|
| 101 |
+
input_ids=input_ids,
|
| 102 |
+
attention_mask=attention_mask,
|
| 103 |
+
temperature=temperature
|
| 104 |
+
)
|
| 105 |
|
| 106 |
+
# Decoder forward pass
|
| 107 |
+
if labels is not None:
|
| 108 |
+
# Training mode with teacher forcing (GPT suggestion: shift by 1)
|
| 109 |
+
# Input: labels[:-1], Target: labels[1:]
|
| 110 |
+
decoder_input = labels[:, :-1] if labels.dim() > 1 else labels[:-1]
|
| 111 |
+
decoder_mask = attention_mask[:, :-1] if attention_mask is not None and attention_mask.dim() > 1 else None
|
| 112 |
+
|
| 113 |
+
decoder_outputs = self.decoder(
|
| 114 |
+
encoder_all_hidden=encoder_outputs['all_hidden_states'],
|
| 115 |
+
decoder_input_ids=decoder_input,
|
| 116 |
+
attention_mask=decoder_mask
|
| 117 |
+
)
|
| 118 |
+
else:
|
| 119 |
+
# Inference mode (without teacher forcing)
|
| 120 |
+
# For now, fallback to using input as labels for stable training
|
| 121 |
+
# TODO: Implement proper autoregressive generation
|
| 122 |
+
if return_loss and input_ids is not None:
|
| 123 |
+
labels = input_ids # Use input as both input and target
|
| 124 |
+
decoder_input = labels[:, :-1] if labels.dim() > 1 else labels[:-1]
|
| 125 |
+
decoder_mask = attention_mask[:, :-1] if attention_mask is not None and attention_mask.dim() > 1 else None
|
| 126 |
+
|
| 127 |
+
decoder_outputs = self.decoder(
|
| 128 |
+
encoder_all_hidden=encoder_outputs['all_hidden_states'],
|
| 129 |
+
decoder_input_ids=decoder_input,
|
| 130 |
+
attention_mask=decoder_mask
|
| 131 |
+
)
|
| 132 |
+
else:
|
| 133 |
+
decoder_outputs = self.decoder(
|
| 134 |
+
encoder_all_hidden=encoder_outputs['all_hidden_states'],
|
| 135 |
+
decoder_input_ids=None,
|
| 136 |
+
attention_mask=attention_mask
|
| 137 |
+
)
|
| 138 |
|
| 139 |
+
# Combine outputs with prefix to avoid key collision (GPT suggestion)
|
| 140 |
+
outputs = {}
|
| 141 |
+
for key, value in encoder_outputs.items():
|
| 142 |
+
outputs[f'enc_{key}'] = value
|
| 143 |
+
for key, value in decoder_outputs.items():
|
| 144 |
+
outputs[f'dec_{key}'] = value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
+
# Compute loss if requested
|
| 147 |
+
if return_loss and labels is not None:
|
| 148 |
+
loss = self.compute_loss(outputs, labels, attention_mask)
|
| 149 |
+
outputs['loss'] = loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
+
return outputs
|
|
|
|
|
|
|
| 152 |
|
| 153 |
+
def compute_loss(self,
|
| 154 |
+
outputs: Dict[str, torch.Tensor],
|
| 155 |
+
labels: torch.Tensor,
|
| 156 |
+
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 157 |
+
"""
|
| 158 |
+
Compute combined loss with multiple objectives
|
| 159 |
|
| 160 |
+
Components:
|
| 161 |
+
1. Reconstruction loss (cross-entropy)
|
| 162 |
+
2. Compression loss (encourage higher compression)
|
| 163 |
+
3. Boundary loss (boundary prediction accuracy)
|
| 164 |
+
"""
|
| 165 |
+
losses = {}
|
| 166 |
+
|
| 167 |
+
# 1. Reconstruction loss (GPT suggestion: use shifted targets)
|
| 168 |
+
if 'dec_logits' in outputs:
|
| 169 |
+
logits = outputs['dec_logits']
|
| 170 |
+
|
| 171 |
+
# Shift targets for next-token prediction
|
| 172 |
+
target_labels = labels[:, 1:] if labels.dim() > 1 else labels[1:]
|
| 173 |
+
target_mask = attention_mask[:, 1:] if attention_mask is not None and attention_mask.dim() > 1 else None
|
| 174 |
+
|
| 175 |
+
# Reshape for cross-entropy
|
| 176 |
+
batch_size, seq_len, vocab_size = logits.shape
|
| 177 |
+
logits_flat = logits.reshape(-1, vocab_size)
|
| 178 |
+
labels_flat = target_labels.reshape(-1)
|
| 179 |
+
|
| 180 |
+
# Mask out padding (GPT suggestion: use bool mask)
|
| 181 |
+
if target_mask is not None:
|
| 182 |
+
mask_flat = target_mask.reshape(-1).bool()
|
| 183 |
+
reconstruction_loss = F.cross_entropy(
|
| 184 |
+
logits_flat[mask_flat],
|
| 185 |
+
labels_flat[mask_flat],
|
| 186 |
+
ignore_index=self.tokenizer.PAD,
|
| 187 |
+
label_smoothing=0.1 # Added label smoothing
|
| 188 |
+
)
|
| 189 |
+
else:
|
| 190 |
+
reconstruction_loss = F.cross_entropy(
|
| 191 |
+
logits_flat,
|
| 192 |
+
labels_flat,
|
| 193 |
+
ignore_index=self.tokenizer.PAD,
|
| 194 |
+
label_smoothing=0.1
|
| 195 |
+
)
|
| 196 |
|
| 197 |
+
losses['reconstruction'] = reconstruction_loss * self.reconstruction_weight
|
|
|
|
| 198 |
|
| 199 |
+
# 2. Compression loss (GPT suggestion: use proper device tensor creation)
|
| 200 |
+
if 'enc_compression_ratio' in outputs:
|
| 201 |
+
# Target compression ratio (e.g., 24:1 as per config)
|
| 202 |
+
target_ratio = 24.0
|
| 203 |
+
current_ratio = outputs['enc_compression_ratio']
|
| 204 |
|
| 205 |
+
# Create tensors on same device (GPT suggestion)
|
| 206 |
+
if isinstance(current_ratio, (int, float)):
|
| 207 |
+
current_ratio_tensor = labels.new_tensor(current_ratio, dtype=torch.float32)
|
| 208 |
+
else:
|
| 209 |
+
current_ratio_tensor = current_ratio.float()
|
| 210 |
+
target_ratio_tensor = labels.new_tensor(target_ratio, dtype=torch.float32)
|
| 211 |
+
|
| 212 |
+
# Penalize deviation from target (use smooth L1 to avoid explosion)
|
| 213 |
+
compression_loss = F.smooth_l1_loss(
|
| 214 |
+
current_ratio_tensor,
|
| 215 |
+
target_ratio_tensor,
|
| 216 |
+
beta=2.0 # Transition point from L2 to L1
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
losses['compression'] = compression_loss * self.compression_weight
|
| 220 |
+
|
| 221 |
+
# 3. Boundary loss (GPT suggestion: more meaningful boundary learning)
|
| 222 |
+
if 'enc_boundaries' in outputs and outputs['enc_boundaries'] is not None:
|
| 223 |
+
boundary_scores = outputs['enc_boundaries']
|
| 224 |
+
|
| 225 |
+
# Boundary sparsity + smoothness (GPT suggestion)
|
| 226 |
+
# Encourage sparse but clear boundaries
|
| 227 |
+
boundary_probs = torch.sigmoid(boundary_scores)
|
| 228 |
|
| 229 |
+
# Sparsity loss (boundaries should be rare)
|
| 230 |
+
sparsity_loss = boundary_probs.mean() * 0.1
|
|
|
|
|
|
|
|
|
|
| 231 |
|
| 232 |
+
# Smoothness loss (adjacent boundaries should be different)
|
| 233 |
+
if boundary_scores.size(1) > 1:
|
| 234 |
+
diff = boundary_scores[:, 1:] - boundary_scores[:, :-1]
|
| 235 |
+
smoothness_loss = (diff ** 2).mean() * 0.01
|
| 236 |
else:
|
| 237 |
+
smoothness_loss = 0.0
|
| 238 |
+
|
| 239 |
+
boundary_loss = sparsity_loss + smoothness_loss
|
| 240 |
+
|
| 241 |
+
losses['boundary'] = boundary_loss * self.boundary_weight
|
| 242 |
+
|
| 243 |
+
# Combine all losses
|
| 244 |
+
total_loss = sum(losses.values())
|
| 245 |
+
|
| 246 |
+
# Store individual losses for monitoring
|
| 247 |
+
self.last_losses = losses
|
| 248 |
+
|
| 249 |
+
return total_loss
|
| 250 |
+
|
| 251 |
+
def generate(self,
|
| 252 |
+
text: str = None,
|
| 253 |
+
input_ids: torch.Tensor = None,
|
| 254 |
+
max_length: int = 48,
|
| 255 |
+
temperature: float = 0.1,
|
| 256 |
+
top_k: int = 10,
|
| 257 |
+
top_p: float = 0.95) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
"""
|
| 259 |
+
Generate/reconstruct text
|
| 260 |
+
|
| 261 |
Args:
|
| 262 |
+
text: Input text to encode and reconstruct
|
| 263 |
+
input_ids: Pre-encoded input
|
| 264 |
+
max_length: Maximum generation length
|
| 265 |
+
temperature: Sampling temperature
|
| 266 |
+
top_k: Top-k sampling
|
| 267 |
+
top_p: Top-p (nucleus) sampling
|
| 268 |
+
|
| 269 |
+
Returns:
|
| 270 |
+
Reconstructed/generated text
|
| 271 |
"""
|
| 272 |
+
# Encode input if text is provided (GPT suggestion: handle multi-chunk properly)
|
| 273 |
+
chunk_positions = None
|
| 274 |
+
if text is not None:
|
| 275 |
+
# Check if text needs chunking
|
| 276 |
+
if len(text.encode('utf-8')) > self.tokenizer.content_size:
|
| 277 |
+
encoded = self.tokenizer.encode(text, add_special_tokens=True, return_chunks=True)
|
| 278 |
+
chunk_positions = encoded.get('chunk_positions', None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
else:
|
| 280 |
+
encoded = self.tokenizer.encode(text, add_special_tokens=True)
|
| 281 |
+
|
| 282 |
+
input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
|
| 283 |
+
attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
else:
|
| 285 |
+
attention_mask = (input_ids != self.tokenizer.PAD).bool() # GPT suggestion: bool mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
+
# Move to device
|
| 288 |
+
device = next(self.parameters()).device
|
| 289 |
+
input_ids = input_ids.to(device)
|
| 290 |
+
attention_mask = attention_mask.to(device)
|
| 291 |
|
| 292 |
+
# Encode
|
| 293 |
+
with torch.no_grad():
|
| 294 |
+
encoder_outputs = self.encoder(
|
| 295 |
+
input_ids=input_ids,
|
| 296 |
+
attention_mask=attention_mask
|
| 297 |
+
)
|
|
|
|
|
|
|
| 298 |
|
| 299 |
+
        # Prepare all hidden states for decoder
        if 'all_hidden_states' in encoder_outputs:
            encoder_all_hidden = encoder_outputs['all_hidden_states']
        else:
            compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
            encoder_all_hidden = [compressed] * 4

        # Autoregressive generation (fixed version)
        batch_size = input_ids.size(0)

        # Start with BOS token
        generated_ids = torch.full((batch_size, 1), self.tokenizer.BOS, device=device)

        for step in range(max_length - 1):
            with torch.no_grad():
                # Decode current sequence
                decoder_outputs = self.decoder(
                    encoder_all_hidden=encoder_all_hidden,
                    decoder_input_ids=generated_ids,
                    attention_mask=torch.ones_like(generated_ids),
                    use_cache=False
                )

                # Get next token prediction
                logits = decoder_outputs['logits'][:, -1, :] / temperature

                # Top-k filtering
                if top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits[indices_to_remove] = float('-inf')

                # Sample next token
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                # Append to generated sequence
                generated_ids = torch.cat([generated_ids, next_token], dim=1)

                # Check for EOS
                if (next_token == self.tokenizer.EOS).all():
                    break

        # Decode to text (GPT suggestion: proper multi-chunk reconstruction)
        if generated_ids.dim() > 2 and chunk_positions is not None:
            # Multi-chunk output with positions
            text = self.tokenizer.reconstruct(
                generated_ids,
                positions=chunk_positions,
                overlap=self.tokenizer.chunk_overlap
            )
        elif generated_ids.dim() > 2:
            # Multi-chunk without positions (fallback)
            text = self.tokenizer.reconstruct(generated_ids)
        else:
            # Single sequence
            text = self.tokenizer.decode(generated_ids[0] if generated_ids.dim() > 1 else generated_ids)

        return text

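# --- Illustrative sketch (not part of unified_model.py) ---
# The sampling step inside generate() in isolation: temperature scaling, top-k
# masking, then multinomial sampling. Note that top_p is accepted by generate()
# but not applied in this version; only top-k filtering is used. The vocabulary
# size below is an arbitrary placeholder.
import torch
import torch.nn.functional as F

logits = torch.randn(1, 260)          # hypothetical next-token logits
temperature, top_k = 0.1, 10

scaled = logits / temperature
# Mask everything outside the top-k highest logits
kth_best = torch.topk(scaled, top_k)[0][..., -1, None]
scaled[scaled < kth_best] = float('-inf')

probs = F.softmax(scaled, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.shape)               # torch.Size([1, 1])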
    def compress(self, text: str) -> Dict[str, Union[torch.Tensor, float]]:
        """
        Compress text and return compression statistics.

        Args:
            text: Input text to compress

        Returns:
            Dictionary with compressed representation and statistics
        """
        # Encode text
        encoded = self.tokenizer.encode(text, add_special_tokens=True)
        input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
        attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']

        # Move to device
        device = next(self.parameters()).device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Get compressed representation
        with torch.no_grad():
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

        return {
            'compressed': encoder_outputs['compressed'],
            'num_tokens': encoder_outputs['num_tokens'],
            'compression_ratio': encoder_outputs['compression_ratio'],
            'original_bytes': len(text.encode('utf-8')),
            'compressed_size': encoder_outputs['num_tokens'] * 2  # Approximate bytes
        }

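# --- Illustrative sketch (not part of unified_model.py) ---
# Intended usage of compress(). Requires a trained model; the checkpoint path is
# an assumption made for the example. 'compressed_size' is only the rough
# num_tokens * 2 estimate from the method above, so the derived figure is
# likewise approximate.
model = IntelligentTokenizerV62.from_checkpoint("checkpoints/v62.pt", device="cpu")
result = model.compress("Hello, world!")
print(f"tokens: {result['num_tokens']}, "
      f"ratio: {result['compression_ratio']:.1f}:1, "
      f"~{result['original_bytes']}B -> ~{result['compressed_size']}B")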
    def update_training_state(self, epoch: int, step: int = 0, reconstruction_loss: float = None):
        """
        Update training state - adaptive, not phase-based.

        Args:
            epoch: Current epoch
            step: Current training step
            reconstruction_loss: Current reconstruction quality
        """
        self.current_epoch = torch.tensor(epoch)
        self.training_step = torch.tensor(step)

        # Update encoder warmup (gates only)
        self.encoder.set_warmup_step(step)

        # Adaptive weight adjustment based on performance
        if reconstruction_loss is not None:
            # If reconstruction is poor, increase its weight
            if reconstruction_loss > 1.0:
                self.reconstruction_weight = 1.0
                self.compression_weight = 0.1  # Less compression focus
            else:
                # Good reconstruction, can focus on compression
                self.reconstruction_weight = 0.5
                self.compression_weight = 0.1

            # Boundary weight stays moderate
            self.boundary_weight = 0.1

            # Let encoder know about reconstruction quality
            self.encoder.adaptive_compression_control(reconstruction_loss)
        else:
            # Default balanced weights
            self.reconstruction_weight = 0.5
            self.compression_weight = 0.1
            self.boundary_weight = 0.1

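# --- Illustrative sketch (not part of unified_model.py) ---
# The adaptive schedule above in isolation: a reconstruction loss above 1.0 pushes
# the reconstruction weight up, otherwise the weights relax toward the balanced
# defaults. Plain-Python restatement of the branch logic, no model required.
def adaptive_weights(reconstruction_loss=None):
    if reconstruction_loss is not None and reconstruction_loss > 1.0:
        return {'reconstruction': 1.0, 'compression': 0.1, 'boundary': 0.1}
    return {'reconstruction': 0.5, 'compression': 0.1, 'boundary': 0.1}

print(adaptive_weights(2.3))   # poor reconstruction -> weight 1.0
print(adaptive_weights(0.4))   # good reconstruction -> weight 0.5
print(adaptive_weights())      # no signal -> balanced defaults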
    def get_model_stats(self) -> Dict[str, float]:
        """
        Get model statistics for monitoring.

        Returns:
            Dictionary with various model statistics
        """
        stats = {}

        # Encoder stats (GPT suggestion: already prefixed)
        encoder_stats = self.encoder.get_monitoring_stats()
        stats.update({f'encoder_{k}': v for k, v in encoder_stats.items()})

        # Decoder memory stats
        decoder_memory = self.decoder.get_memory_usage()
        stats.update({f'decoder_{k}': v for k, v in decoder_memory.items()})

        # Loss stats (if available) - check for tensor items
        if hasattr(self, 'last_losses'):
            for k, v in self.last_losses.items():
                if isinstance(v, torch.Tensor):
                    stats[f'loss_{k}'] = v.item() if v.numel() == 1 else v.mean().item()
                else:
                    stats[f'loss_{k}'] = float(v)

        # Training info
        stats['current_epoch'] = self.current_epoch.item()
        stats['training_step'] = self.training_step.item()

        return stats

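# --- Illustrative sketch (not part of unified_model.py) ---
# How get_model_stats() flattens last_losses into plain floats: scalar tensors use
# .item(), multi-element tensors are averaged, and non-tensors are cast to float.
# The loss names and values here are made up for the example.
import torch

last_losses = {'reconstruction': torch.tensor(0.42),
               'boundary': torch.tensor([0.1, 0.3]),
               'compression': 0.05}
stats = {}
for k, v in last_losses.items():
    if isinstance(v, torch.Tensor):
        stats[f'loss_{k}'] = v.item() if v.numel() == 1 else v.mean().item()
    else:
        stats[f'loss_{k}'] = float(v)
print(stats)   # e.g. {'loss_reconstruction': 0.42, 'loss_boundary': 0.2, 'loss_compression': 0.05}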
    def save_checkpoint(self, path: str):
        """
        Save model checkpoint.

        Args:
            path: Path to save checkpoint
        """
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': self.config,
            'epoch': self.current_epoch.item(),
            'step': self.training_step.item(),
            'stats': self.get_model_stats()
        }
        torch.save(checkpoint, path)
        print(f"Checkpoint saved to {path}")

    @classmethod
    def from_checkpoint(cls, path: str, device: str = 'cuda'):
        """
        Load model from checkpoint.

        Args:
            path: Path to checkpoint
            device: Device to load model on

        Returns:
            Loaded model instance
        """
        checkpoint = torch.load(path, map_location=device)

        # Create model with saved config
        model = cls(checkpoint.get('config', {}))
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(device)

        # Restore training state
        if 'epoch' in checkpoint:
            model.current_epoch = torch.tensor(checkpoint['epoch'])
        if 'step' in checkpoint:
            model.training_step = torch.tensor(checkpoint['step'])

        print(f"Model loaded from {path} (Epoch {checkpoint.get('epoch', 0)})")
        return model

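# --- Illustrative sketch (not part of unified_model.py) ---
# Save/load round trip with the two checkpoint helpers above. The path is an
# assumption; from_checkpoint() rebuilds the model from the stored config,
# restores the weights, and re-applies the saved epoch/step counters.
model = IntelligentTokenizerV62()
model.update_training_state(epoch=3, step=1200)
model.save_checkpoint("checkpoints/v62_epoch3.pt")

restored = IntelligentTokenizerV62.from_checkpoint("checkpoints/v62_epoch3.pt", device="cpu")
assert restored.current_epoch.item() == 3
assert restored.training_step.item() == 1200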
if __name__ == "__main__":
    # Test unified model
    print("Testing Intelligent Tokenizer v6.2.0")

    # Create model
    model = IntelligentTokenizerV62()
    print(f"Model created with {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters")

    # Test texts
    test_texts = [
        "Hello, world!",
        "안녕하세요, 만나서 반갑습니다. 오늘 날씨가 좋네요!",
        "今天天气很好。",
    ]

    for text in test_texts:
        print(f"\nInput: {text}")

        # Compress
        compression = model.compress(text)
        print(f"  Compression ratio: {compression['compression_ratio']:.1f}:1")
        print(f"  Tokens: {compression['num_tokens']}")

        # Generate (reconstruct)
        reconstructed = model.generate(text, temperature=0.1)
        print(f"  Reconstructed: {reconstructed}")

    # Get model stats
    stats = model.get_model_stats()
    print(f"\nModel Statistics:")
    for key, value in stats.items():
        if isinstance(value, float):
            print(f"  {key}: {value:.4f}")
        else:
            print(f"  {key}: {value}")