ggunio committed on
Commit ff85374 · verified · 1 Parent(s): 8f9e907

Upload folder using huggingface_hub

core/__pycache__/decoder.cpython-310.pyc ADDED
Binary file (11.8 kB)

core/__pycache__/decoder.cpython-313.pyc ADDED
Binary file (20.8 kB)

core/__pycache__/encoder.cpython-310.pyc ADDED
Binary file (13.4 kB)

core/__pycache__/encoder.cpython-313.pyc ADDED
Binary file (24.8 kB)

core/__pycache__/tokenizer.cpython-310.pyc ADDED
Binary file (12.3 kB)

core/__pycache__/tokenizer.cpython-313.pyc ADDED
Binary file (19.8 kB)

core/__pycache__/unified_model.cpython-310.pyc ADDED
Binary file (12.2 kB)
core/decoder.py ADDED
@@ -0,0 +1,485 @@
1
+ """
2
+ Intelligent Tokenizer v6.2.0 - 6-Layer Decoder with Multi-Level Cross-Attention
3
+ Incorporates GPT-5 suggestions for KV cache optimization
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Dict, List, Optional, Tuple
10
+ import math
11
+
12
+
13
+ class KVCacheOptimizedAttention(nn.Module):
14
+ """
15
+ KV Cache Optimized Attention - GPT-5 suggestion
16
+ 16Q → 2K/V for 8x memory reduction
17
+ """
18
+
19
+ def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_compression: int = 8):
20
+ super().__init__()
21
+ self.hidden_dim = hidden_dim
22
+ self.num_heads = num_heads
23
+ self.kv_heads = max(2, num_heads // kv_compression) # 16/8 = 2 KV heads
24
+ self.head_dim = hidden_dim // num_heads # 80
25
+
26
+ # Query uses all heads
27
+ self.q_proj = nn.Linear(hidden_dim, hidden_dim) # 16 heads
28
+
29
+ # Key/Value use fewer heads (GPT-5 suggestion)
30
+ self.k_proj = nn.Linear(hidden_dim, self.kv_heads * self.head_dim) # 2 heads
31
+ self.v_proj = nn.Linear(hidden_dim, self.kv_heads * self.head_dim) # 2 heads
32
+
33
+ # Output projection
34
+ self.o_proj = nn.Linear(hidden_dim, hidden_dim)
35
+
36
+ # KV cache for inference
37
+ self.register_buffer('cached_keys', None)
38
+ self.register_buffer('cached_values', None)
39
+
40
+ def forward(self,
41
+ hidden_states: torch.Tensor,
42
+ encoder_hidden: Optional[torch.Tensor] = None,
43
+ attention_mask: Optional[torch.Tensor] = None,
44
+ use_cache: bool = False) -> Tuple[torch.Tensor, Optional[Tuple]]:
45
+ """
46
+ Forward pass with KV cache optimization
47
+ """
48
+ batch_size, seq_len = hidden_states.shape[:2]
49
+
50
+ # Query projection (all heads)
51
+ Q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim)
52
+ Q = Q.transpose(1, 2) # [batch, heads, seq, dim]
53
+
54
+ # Key/Value source (self or cross)
55
+ kv_source = encoder_hidden if encoder_hidden is not None else hidden_states
56
+
57
+ # Key/Value projection (fewer heads)
58
+ K = self.k_proj(kv_source).view(batch_size, -1, self.kv_heads, self.head_dim)
59
+ V = self.v_proj(kv_source).view(batch_size, -1, self.kv_heads, self.head_dim)
60
+ K = K.transpose(1, 2) # [batch, kv_heads, seq, dim]
61
+ V = V.transpose(1, 2)
62
+
63
+ # Repeat KV heads to match Q heads (broadcast)
64
+ K = K.repeat_interleave(self.num_heads // self.kv_heads, dim=1)
65
+ V = V.repeat_interleave(self.num_heads // self.kv_heads, dim=1)
66
+
67
+ # Cache management for incremental generation (GPT suggestion)
68
+ if use_cache:
69
+ # For incremental generation, only process new token
70
+ if self.cached_keys is not None and hidden_states.size(1) == 1:
71
+ # Append new K/V to cache
72
+ K = torch.cat([self.cached_keys, K], dim=2)
73
+ V = torch.cat([self.cached_values, V], dim=2)
74
+ # Update cache
75
+ self.cached_keys = K
76
+ self.cached_values = V
77
+
78
+ # Scaled dot-product attention
79
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
80
+
81
+ # Use additive mask (GPT suggestion)
82
+ if attention_mask is not None:
83
+ scores = scores + attention_mask # additive mask: -inf where masked, 0 elsewhere
84
+
85
+ attn_weights = F.softmax(scores, dim=-1)
86
+ attn_output = torch.matmul(attn_weights, V)
87
+
88
+ # Reshape and project
89
+ attn_output = attn_output.transpose(1, 2).contiguous()
90
+ attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim)
91
+ output = self.o_proj(attn_output)
92
+
93
+ return output, (K, V) if use_cache else None
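For reference, a minimal standalone sketch (not part of the committed file) of the grouped K/V sharing used above: with 16 query heads and 2 cached K/V heads, repeat_interleave broadcasts the shared heads to every query-head group, and the cached tensors shrink by a factor of 8.

import torch

batch, seq, num_heads, kv_heads, head_dim = 2, 48, 16, 2, 80

k_small = torch.randn(batch, kv_heads, seq, head_dim)             # what actually gets cached
k_full = k_small.repeat_interleave(num_heads // kv_heads, dim=1)  # broadcast to all 16 query heads
print(k_full.shape)                                               # torch.Size([2, 16, 48, 80])

cache_mha = batch * num_heads * seq * head_dim                    # full multi-head cache size
cache_mqa = batch * kv_heads * seq * head_dim                     # compressed cache size
print(cache_mha / cache_mqa)                                      # 8.0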
94
+
95
+
96
+ class SelectiveCrossAttention(nn.Module):
97
+ """
98
+ Selective cross-attention - only attend to relevant encoder layers
99
+ Reduces 24 → 8 cross-attentions for efficiency
100
+ """
101
+
102
+ def __init__(self, hidden_dim: int = 1280, layer_id: int = 0):
103
+ super().__init__()
104
+ self.hidden_dim = hidden_dim
105
+ self.layer_id = layer_id
106
+
107
+ # Define which encoder layers this decoder layer should attend to
108
+ self.encoder_connections = {
109
+ 0: [0], # Decoder L0 → Encoder L0 (byte info)
110
+ 1: [0], # Decoder L1 → Encoder L0 (byte info)
111
+ 2: [1, 2], # Decoder L2 → Encoder L1,2 (language info)
112
+ 3: [1, 2], # Decoder L3 → Encoder L1,2 (language info)
113
+ 4: [3], # Decoder L4 → Encoder L3 (semantic info)
114
+ 5: [3], # Decoder L5 → Encoder L3 (semantic info)
115
+ }
116
+
117
+ # Get connections for this layer
118
+ self.connected_layers = self.encoder_connections.get(layer_id, [0])
119
+
120
+ # Create attention modules only for connected layers
121
+ self.cross_attentions = nn.ModuleList([
122
+ KVCacheOptimizedAttention(hidden_dim, num_heads=16, kv_compression=8)
123
+ for _ in self.connected_layers
124
+ ])
125
+
126
+ # Lightweight fusion with weighted sum (GPT suggestion)
127
+ self.fusion = nn.Sequential(
128
+ nn.Linear(hidden_dim, hidden_dim),
129
+ nn.LayerNorm(hidden_dim),
130
+ nn.SiLU(),
131
+ nn.Dropout(0.1)
132
+ )
133
+
134
+ # Learnable weights for connected layers only
135
+ self.layer_weights = nn.Parameter(torch.ones(len(self.connected_layers)) / len(self.connected_layers))
136
+
137
+ def forward(self,
138
+ decoder_hidden: torch.Tensor,
139
+ encoder_all_hidden: List[torch.Tensor],
140
+ attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
141
+ """
142
+ Selectively attend to relevant encoder layers only
143
+ """
144
+ # Only attend to connected encoder layers
145
+ cross_outputs = []
146
+ for i, layer_idx in enumerate(self.connected_layers):
147
+ if layer_idx < len(encoder_all_hidden):
148
+ encoder_hidden = encoder_all_hidden[layer_idx]
149
+ cross_out, _ = self.cross_attentions[i](
150
+ hidden_states=decoder_hidden,
151
+ encoder_hidden=encoder_hidden,
152
+ attention_mask=attention_mask
153
+ )
154
+ cross_outputs.append(cross_out)
155
+
156
+ # Weighted sum fusion for connected layers only
157
+ if len(cross_outputs) > 1:
158
+ weighted_outputs = torch.stack(cross_outputs, dim=0) # [N, batch, seq, hidden]
159
+ weights = F.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
160
+ fused = (weighted_outputs * weights).sum(dim=0) # [batch, seq, hidden]
161
+ else:
162
+ # Single connection - no fusion needed
163
+ fused = cross_outputs[0] if cross_outputs else decoder_hidden
164
+
165
+ # Apply lightweight fusion layer
166
+ fused = self.fusion(fused)
167
+
168
+ return fused
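As a standalone illustration of the weighted-sum fusion above (shapes are hypothetical; two connected encoder layers, as for decoder layers 2-3):

import torch
import torch.nn.functional as F

batch, seq, hidden = 2, 6, 1280
cross_outputs = [torch.randn(batch, seq, hidden) for _ in range(2)]  # one output per connected encoder layer
layer_weights = torch.zeros(2)                                       # a learnable nn.Parameter in the module

stacked = torch.stack(cross_outputs, dim=0)                          # [N, batch, seq, hidden]
weights = F.softmax(layer_weights, dim=0).view(-1, 1, 1, 1)          # normalized mixing weights
fused = (stacked * weights).sum(dim=0)                               # [batch, seq, hidden]
print(fused.shape)                                                   # torch.Size([2, 6, 1280])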
169
+
170
+
171
+ class SwiGLU(nn.Module):
172
+ """SwiGLU activation for better convergence (GPT suggestion)"""
173
+ def __init__(self, dim: int, mult: float = 2.66):
174
+ super().__init__()
175
+ inner = int(round(dim * mult / 2)) * 2 # Even alignment
176
+ self.w1 = nn.Linear(dim, inner // 2)
177
+ self.w2 = nn.Linear(dim, inner // 2)
178
+ self.w3 = nn.Linear(inner // 2, dim)
179
+
180
+ def forward(self, x):
181
+ return self.w3(F.silu(self.w1(x)) * self.w2(x))
182
+
183
+
184
+ class DecoderLayer(nn.Module):
185
+ """
186
+ Single decoder layer with self-attention and selective cross-attention
187
+ """
188
+
189
+ def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, layer_id: int = 0):
190
+ super().__init__()
191
+ self.hidden_dim = hidden_dim
192
+ self.layer_id = layer_id
193
+
194
+ # Self-attention (with KV cache optimization)
195
+ self.self_attn = KVCacheOptimizedAttention(hidden_dim, num_heads, kv_compression=8)
196
+ self.self_attn_norm = nn.LayerNorm(hidden_dim)
197
+
198
+ # Selective cross-attention to specific encoder layers
199
+ self.cross_attn = SelectiveCrossAttention(hidden_dim, layer_id=layer_id)
200
+ self.cross_attn_norm = nn.LayerNorm(hidden_dim)
201
+
202
+ # Feed-forward network with SwiGLU (GPT suggestion)
203
+ self.ffn = SwiGLU(hidden_dim, mult=2.66)
204
+ self.ffn_norm = nn.LayerNorm(hidden_dim)
205
+
206
+ # Dropout for residual connections
207
+ self.dropout = nn.Dropout(0.1)
208
+
209
+ def forward(self,
210
+ hidden_states: torch.Tensor,
211
+ encoder_all_hidden: List[torch.Tensor],
212
+ self_attention_mask: Optional[torch.Tensor] = None,
213
+ cross_attention_mask: Optional[torch.Tensor] = None,
214
+ use_cache: bool = False) -> Tuple[torch.Tensor, Optional[Tuple]]:
215
+ """
216
+ Forward pass through decoder layer
217
+ """
218
+ # Self-attention with residual
219
+ residual = hidden_states
220
+ hidden_states = self.self_attn_norm(hidden_states)
221
+ self_attn_out, cache = self.self_attn(
222
+ hidden_states,
223
+ attention_mask=self_attention_mask,
224
+ use_cache=use_cache
225
+ )
226
+ hidden_states = residual + self.dropout(self_attn_out)
227
+
228
+ # Cross-attention with residual
229
+ residual = hidden_states
230
+ hidden_states = self.cross_attn_norm(hidden_states)
231
+ cross_attn_out = self.cross_attn(
232
+ hidden_states,
233
+ encoder_all_hidden,
234
+ attention_mask=cross_attention_mask
235
+ )
236
+ hidden_states = residual + self.dropout(cross_attn_out)
237
+
238
+ # FFN with residual
239
+ residual = hidden_states
240
+ hidden_states = self.ffn_norm(hidden_states)
241
+ ffn_out = self.ffn(hidden_states)
242
+ hidden_states = residual + self.dropout(ffn_out)
243
+
244
+ return hidden_states, cache
245
+
246
+
247
+ class DecoderV62(nn.Module):
248
+ """
249
+ 6-Layer Decoder with Multi-Level Cross-Attention
250
+ Reduced from 8 layers but compensated with better cross-attention
251
+ """
252
+
253
+ def __init__(self, config: Optional[Dict] = None):
254
+ super().__init__()
255
+
256
+ # Configuration
257
+ self.hidden_dim = 1280
258
+ self.num_heads = 16
259
+ self.num_layers = 6 # Reduced from 8
260
+ self.vocab_size = 260 # 256 bytes + special tokens
261
+ self.max_seq_len = 48
262
+
263
+ # Token constants (GPT suggestion - explicit constants)
264
+ self.PAD = 256
265
+ self.BOS = 257
266
+ self.EOS = 258
267
+ self.MASK = 259
268
+
269
+ # Token embedding and position encoding
270
+ self.token_embedding = nn.Embedding(self.vocab_size, self.hidden_dim)
271
+ self.position_embedding = nn.Embedding(self.max_seq_len, self.hidden_dim)
272
+
273
+ # 6 decoder layers with layer-specific cross-attention
274
+ self.layers = nn.ModuleList([
275
+ DecoderLayer(self.hidden_dim, self.num_heads, layer_id=i)
276
+ for i in range(self.num_layers)
277
+ ])
278
+
279
+ # Output projection
280
+ self.output_norm = nn.LayerNorm(self.hidden_dim)
281
+ self.output_projection = nn.Linear(self.hidden_dim, self.vocab_size)
282
+
283
+ # Monitoring (GPT-5 suggestion)
284
+ # Track importance of ENCODER layers (4) used by decoder
285
+ self.register_buffer('layer_importance', torch.zeros(4)) # Track importance of 4 encoder layers
286
+
287
+ def forward(self,
288
+ encoder_all_hidden: List[torch.Tensor],
289
+ decoder_input_ids: Optional[torch.Tensor] = None,
290
+ attention_mask: Optional[torch.Tensor] = None,
291
+ use_cache: bool = False,
292
+ past_key_values: Optional[List] = None) -> Dict[str, torch.Tensor]:
293
+ """
294
+ Forward pass through decoder
295
+
296
+ Args:
297
+ encoder_all_hidden: All encoder layer outputs (4 layers)
298
+ decoder_input_ids: Input token IDs for teacher forcing
299
+ attention_mask: Attention mask
300
+ use_cache: Whether to cache KV for inference
301
+ past_key_values: Cached KV from previous steps
302
+ """
303
+ batch_size = encoder_all_hidden[0].size(0)
304
+ device = encoder_all_hidden[0].device
305
+
306
+ # If no decoder input, start with compressed representation
307
+ if decoder_input_ids is None:
308
+ # Use encoder's final compressed output as starting point
309
+ hidden_states = encoder_all_hidden[-1] # [batch, M tokens, 1280]
310
+ seq_len = hidden_states.size(1)
311
+ else:
312
+ # Teacher forcing mode: use provided tokens
313
+ seq_len = decoder_input_ids.size(1)
314
+
315
+ # Embeddings
316
+ token_embeds = self.token_embedding(decoder_input_ids)
317
+ position_ids = torch.arange(seq_len, device=device).expand(batch_size, -1)
318
+ position_embeds = self.position_embedding(position_ids)
319
+
320
+ hidden_states = token_embeds + position_embeds
321
+
322
+ # Create causal mask for self-attention (additive mask - GPT suggestion)
323
+ causal_mask = torch.full((1, 1, seq_len, seq_len), float('-inf'), device=device)
324
+ causal_mask = torch.triu(causal_mask, diagonal=1) # [1, 1, seq, seq]
325
+
326
+ # Pass through decoder layers
327
+ all_hidden_states = []
328
+ all_caches = [] if use_cache else None
329
+
330
+ for i, layer in enumerate(self.layers):
331
+ # GPT final check: Create proper cross-attention mask for encoder hidden states
332
+ if encoder_all_hidden is not None and len(encoder_all_hidden) > 0:
333
+ S_enc = encoder_all_hidden[0].size(1) # Encoder sequence length
334
+ # Create additive mask (0 = attend, -inf = mask)
335
+ cross_mask = torch.zeros((batch_size, 1, 1, S_enc), device=hidden_states.device)
336
+ else:
337
+ cross_mask = None
338
+
339
+ hidden_states, cache = layer(
340
+ hidden_states,
341
+ encoder_all_hidden,
342
+ self_attention_mask=causal_mask,
343
+ cross_attention_mask=cross_mask, # Use proper cross mask
344
+ use_cache=use_cache
345
+ )
346
+
347
+ all_hidden_states.append(hidden_states)
348
+ if use_cache:
349
+ all_caches.append(cache)
350
+
351
+ # Final output projection
352
+ hidden_states = self.output_norm(hidden_states)
353
+ logits = self.output_projection(hidden_states)
354
+
355
+ # Update monitoring: track encoder layer importance
356
+ # (This would be computed based on cross-attention weights in practice)
357
+ with torch.no_grad():
358
+ # Simplified: assume equal importance for now
359
+ self.layer_importance = torch.full_like(self.layer_importance, 0.25)  # keep the buffer's device/dtype
360
+
361
+ outputs = {
362
+ 'logits': logits,
363
+ 'last_hidden_state': hidden_states,
364
+ 'all_hidden_states': all_hidden_states,
365
+ 'encoder_layer_importance': self.layer_importance
366
+ }
367
+
368
+ if use_cache:
369
+ outputs['past_key_values'] = all_caches
370
+
371
+ return outputs
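A small self-contained check of the additive causal mask convention used above (0 where attention is allowed, -inf strictly above the diagonal):

import torch

seq_len = 4
causal_mask = torch.full((1, 1, seq_len, seq_len), float('-inf'))
causal_mask = torch.triu(causal_mask, diagonal=1)

scores = torch.zeros(1, 1, seq_len, seq_len)        # stand-in attention scores
attn = torch.softmax(scores + causal_mask, dim=-1)  # each row attends only to positions <= its own
print(attn[0, 0])                                   # row 0 = [1, 0, 0, 0], row 1 = [0.5, 0.5, 0, 0], ...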
372
+
373
+ def generate(self,
374
+ encoder_all_hidden: List[torch.Tensor],
375
+ max_length: int = 48,
376
+ temperature: float = 1.0,
377
+ top_k: int = 50,
378
+ top_p: float = 0.95) -> torch.Tensor:
379
+ """
380
+ Autoregressive generation
381
+ """
382
+ batch_size = encoder_all_hidden[0].size(0)
383
+ device = encoder_all_hidden[0].device
384
+
385
+ # Start with BOS token
386
+ generated = torch.full((batch_size, 1), self.BOS, device=device)
387
+
388
+ # Generate tokens one by one
389
+ past_key_values = None
390
+ for _ in range(max_length - 1):
391
+ # GPT optimization: Only pass last token for O(T) complexity
392
+ if past_key_values is not None:
393
+ decoder_input = generated[:, -1:] # Last token only
394
+ else:
395
+ decoder_input = generated # Full sequence for first step
396
+
397
+ outputs = self.forward(
398
+ encoder_all_hidden,
399
+ decoder_input_ids=decoder_input,
400
+ use_cache=True,
401
+ past_key_values=past_key_values
402
+ )
403
+
404
+ logits = outputs['logits'][:, -1, :] / temperature
405
+
406
+ # Top-k filtering
407
+ if top_k > 0:
408
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
409
+ logits[indices_to_remove] = float('-inf')
410
+
411
+ # Top-p (nucleus) filtering
412
+ if top_p < 1.0:
413
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
414
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
415
+
416
+ # Remove tokens with cumulative probability above threshold
417
+ sorted_indices_to_remove = cumulative_probs > top_p
418
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
419
+ sorted_indices_to_remove[..., 0] = 0
420
+
421
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
422
+ logits[indices_to_remove] = float('-inf')
423
+
424
+ # Sample
425
+ probs = F.softmax(logits, dim=-1)
426
+ next_token = torch.multinomial(probs, num_samples=1)
427
+
428
+ # Append to generated sequence
429
+ generated = torch.cat([generated, next_token], dim=1)
430
+
431
+ # Check for EOS
432
+ if (next_token == self.EOS).all():
433
+ break
434
+
435
+ past_key_values = outputs.get('past_key_values')
436
+
437
+ return generated
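A minimal standalone version of the top-k / top-p filtering performed inside generate() (toy 6-way distribution; the thresholds are illustrative):

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.5, 0.5, 0.1, -1.0, -2.0]])
top_k, top_p = 3, 0.9

# Top-k: drop everything below the k-th largest logit
kth = torch.topk(logits, top_k)[0][..., -1, None]
logits = logits.masked_fill(logits < kth, float('-inf'))

# Top-p: drop the tail once cumulative probability exceeds the threshold
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
remove = cum_probs > top_p
remove[..., 1:] = remove[..., :-1].clone()   # always keep the token that crosses the threshold
remove[..., 0] = False
logits = logits.masked_fill(remove.scatter(1, sorted_idx, remove), float('-inf'))

print(F.softmax(logits, dim=-1))             # probability mass only on the surviving tokens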
438
+
439
+ def get_memory_usage(self) -> Dict[str, float]:
440
+ """
441
+ Calculate memory usage with KV cache optimization (GPT-5 metric)
442
+ """
443
+ # Standard attention: 16 heads for K and V
444
+ standard_kv_memory = 2 * 16 * self.max_seq_len * 80 * 4 # bytes
445
+
446
+ # Optimized: 2 heads for K and V
447
+ optimized_kv_memory = 2 * 2 * self.max_seq_len * 80 * 4 # bytes
448
+
449
+ return {
450
+ 'standard_kv_mb': standard_kv_memory / (1024 * 1024),
451
+ 'optimized_kv_mb': optimized_kv_memory / (1024 * 1024),
452
+ 'reduction_ratio': standard_kv_memory / optimized_kv_memory,
453
+ 'total_params_m': sum(p.numel() for p in self.parameters()) / 1e6
454
+ }
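The reported reduction can be checked by hand: with head_dim = 80, max_seq_len = 48 and fp32 (4 bytes per value),

standard = 2 * 16 * 48 * 80 * 4    # K and V, 16 heads -> 491,520 bytes (~0.47 MB)
optimized = 2 * 2 * 48 * 80 * 4    # K and V,  2 heads ->  61,440 bytes (~0.06 MB)
print(standard / optimized)        # 8.0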
455
+
456
+
457
+ if __name__ == "__main__":
458
+ # Test the decoder
459
+ decoder = DecoderV62()
460
+
461
+ # Simulate encoder outputs (4 layers, 6 tokens each)
462
+ batch_size = 2
463
+ num_tokens = 6 # After progressive splitting
464
+ hidden_dim = 1280
465
+
466
+ encoder_outputs = [
467
+ torch.randn(batch_size, num_tokens, hidden_dim)
468
+ for _ in range(4)
469
+ ]
470
+
471
+ # Test with teacher forcing
472
+ decoder_input = torch.randint(0, 256, (batch_size, 48))
473
+ output = decoder(encoder_outputs, decoder_input_ids=decoder_input)
474
+
475
+ print(f"Decoder output shape: {output['logits'].shape}")
476
+ print(f"Encoder layer importance: {output['encoder_layer_importance']}")
477
+
478
+ # Test generation
479
+ generated = decoder.generate(encoder_outputs, max_length=48)
480
+ print(f"Generated shape: {generated.shape}")
481
+
482
+ # Memory usage
483
+ memory_stats = decoder.get_memory_usage()
484
+ print(f"Memory optimization: {memory_stats['reduction_ratio']:.1f}x reduction")
485
+ print(f"Total parameters: {memory_stats['total_params_m']:.1f}M")
core/encoder.py ADDED
@@ -0,0 +1,588 @@
1
+ """
2
+ Intelligent Tokenizer v6.2.0 - Progressive Splitting Encoder
3
+ With GPT-5 suggested improvements
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Dict, List, Optional, Tuple
10
+ import math
11
+
12
+
13
+ class RoPEPositionalEncoding(nn.Module):
14
+ """
15
+ Rotary Position Embedding (RoPE) - GPT-5 suggestion
16
+ Better for handling chunk boundaries and variable sequence lengths
17
+ """
18
+
19
+ def __init__(self, dim: int, max_seq_len: int = 48, base: int = 10000):
20
+ super().__init__()
21
+ self.dim = dim
22
+ self.max_seq_len = max_seq_len
23
+ self.base = base
24
+
25
+ # Precompute sinusoidal frequencies
26
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
27
+ self.register_buffer('inv_freq', inv_freq)
28
+
29
+ # Precompute positional encodings
30
+ t = torch.arange(max_seq_len).type_as(self.inv_freq)
31
+ freqs = torch.outer(t, self.inv_freq)
32
+ self.register_buffer('cos_cached', freqs.cos())
33
+ self.register_buffer('sin_cached', freqs.sin())
34
+
35
+ def forward(self, x: torch.Tensor, seq_len: int = None) -> torch.Tensor:
36
+ """
37
+ Apply RoPE to input tensor
38
+ Handles chunk boundary corrections as suggested by GPT-5
39
+ """
40
+ if seq_len is None:
41
+ seq_len = x.shape[1]
42
+
43
+ # Get cached cos/sin values
44
+ cos = self.cos_cached[:seq_len]
45
+ sin = self.sin_cached[:seq_len]
46
+
47
+ # Apply rotary embedding
48
+ x_rot = self._apply_rotary_emb(x, cos, sin)
49
+
50
+ return x_rot
51
+
52
+ def _apply_rotary_emb(self, x, cos, sin):
53
+ """Apply rotary embedding to input"""
54
+ x1, x2 = x[..., ::2], x[..., 1::2]
55
+ x_rot = torch.stack([
56
+ x1 * cos - x2 * sin,
57
+ x1 * sin + x2 * cos
58
+ ], dim=-1).flatten(-2)
59
+ return x_rot
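A self-contained sketch of the same pairwise rotation (even/odd feature pairs rotated by position-dependent angles); since it is a pure rotation, vector norms are preserved:

import torch

dim, max_seq_len, base = 8, 16, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))    # [dim/2]
freqs = torch.outer(torch.arange(max_seq_len).float(), inv_freq)      # [seq, dim/2]
cos, sin = freqs.cos(), freqs.sin()

x = torch.randn(1, max_seq_len, dim)
x1, x2 = x[..., ::2], x[..., 1::2]
x_rot = torch.stack([x1 * cos - x2 * sin,
                     x1 * sin + x2 * cos], dim=-1).flatten(-2)

print(torch.allclose(x.norm(dim=-1), x_rot.norm(dim=-1), atol=1e-5))  # True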
60
+
61
+
62
+ class GatedCrossAttention(nn.Module):
63
+ """
64
+ Gated Cross-Attention with MQA - GPT-5 suggestion
65
+ Monitor gate values for quality assessment
66
+ 16Q → 2K/V for 8x memory reduction
67
+ """
68
+
69
+ def __init__(self, hidden_dim: int = 1280, num_heads: int = 16, kv_heads: int = 2):
70
+ super().__init__()
71
+ self.hidden_dim = hidden_dim
72
+ self.num_heads = num_heads
73
+ self.kv_heads = kv_heads # Reduced KV heads (GPT suggestion)
74
+ self.head_dim = hidden_dim // num_heads # 80
75
+
76
+ # Multi-Query Attention projections
77
+ self.q_proj = nn.Linear(hidden_dim, hidden_dim) # 16 heads
78
+ self.k_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim) # 2 heads
79
+ self.v_proj = nn.Linear(hidden_dim, kv_heads * self.head_dim) # 2 heads
80
+ self.o_proj = nn.Linear(hidden_dim, hidden_dim)
81
+
82
+ # Gating mechanism (GPT-5 suggestion)
83
+ self.gate = nn.Sequential(
84
+ nn.Linear(hidden_dim * 2, hidden_dim),
85
+ nn.Sigmoid()
86
+ )
87
+
88
+ # Gate monitoring (for analysis)
89
+ self.register_buffer('gate_values', torch.zeros(1))
90
+
91
+ # Warmup factor (GPT suggestion)
92
+ self.register_buffer('warmup_alpha', torch.tensor(1.0))
93
+
94
+ def forward(self,
95
+ query: torch.Tensor,
96
+ key: torch.Tensor,
97
+ value: torch.Tensor,
98
+ mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
99
+ """
100
+ Forward pass with gate monitoring
101
+ Returns: (output, gate_values)
102
+ """
103
+ batch_size, seq_len = query.shape[:2]
104
+
105
+ # Multi-head attention projections
106
+ Q = self.q_proj(query).view(batch_size, seq_len, self.num_heads, self.head_dim)
107
+ K = self.k_proj(key).view(batch_size, -1, self.kv_heads, self.head_dim)
108
+ V = self.v_proj(value).view(batch_size, -1, self.kv_heads, self.head_dim)
109
+
110
+ # Transpose for attention computation
111
+ Q = Q.transpose(1, 2) # [batch, heads, seq, dim]
112
+ K = K.transpose(1, 2) # [batch, kv_heads, seq, dim]
113
+ V = V.transpose(1, 2)
114
+
115
+ # Repeat KV heads to match Q heads if necessary
116
+ if self.kv_heads < self.num_heads:
117
+ repeat_factor = self.num_heads // self.kv_heads
118
+ K = K.repeat_interleave(repeat_factor, dim=1)
119
+ V = V.repeat_interleave(repeat_factor, dim=1)
120
+
121
+ # Scaled dot-product attention
122
+ scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
123
+
124
+ if mask is not None:
125
+ scores = scores.masked_fill(mask == 0, -1e9)
126
+
127
+ attn_weights = F.softmax(scores, dim=-1)
128
+ attn_output = torch.matmul(attn_weights, V)
129
+
130
+ # Reshape back
131
+ attn_output = attn_output.transpose(1, 2).contiguous()
132
+ attn_output = attn_output.view(batch_size, seq_len, self.hidden_dim)
133
+ attn_output = self.o_proj(attn_output)
134
+
135
+ # Gating mechanism
136
+ gate_input = torch.cat([query, attn_output], dim=-1)
137
+ gate_values = self.gate(gate_input)
138
+
139
+ # Store gate values for monitoring (keep tensor shape consistent)
140
+ self.gate_values[0] = gate_values.mean().detach()
141
+
142
+ # Apply gate with warmup factor (GPT suggestion)
143
+ gate_values = gate_values * self.warmup_alpha
144
+ output = gate_values * attn_output + (1 - gate_values) * query
145
+
146
+ return output, gate_values
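The gate is a learned blend between the cross-attention output and the original query; a toy version of the same mix, including the warmup scaling:

import torch

query = torch.randn(2, 6, 1280)
attn_output = torch.randn(2, 6, 1280)
gate = torch.sigmoid(torch.randn(2, 6, 1280))   # in the module this comes from Linear(2*hidden, hidden) + Sigmoid
warmup_alpha = 0.1                              # early in training the gate is mostly closed

gate = gate * warmup_alpha
output = gate * attn_output + (1 - gate) * query
print(output.shape)                             # torch.Size([2, 6, 1280]); stays close to query while alpha is small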
147
+
148
+
149
+
150
+ class ProgressiveSplittingLayer(nn.Module):
151
+ """
152
+ Core innovation: 48 bytes → 1 token → N tokens → M tokens
153
+ """
154
+
155
+ def __init__(self, hidden_dim: int = 1280, config: Optional[Dict] = None):
156
+ super().__init__()
157
+ self.hidden_dim = hidden_dim
158
+ self.config = config or {}
159
+
160
+ # Dynamic splitting: 1~4 tokens for efficiency
161
+ # 48 bytes / 4 tokens = 12:1 compression (still beats BPE's 4:1)
162
+ self.min_tokens = 1 # 48:1 compression
163
+ self.max_tokens = 4 # 12:1 compression (still 3x better than BPE)
164
+
165
+ # Initial compression: 48 bytes → 1 super token
166
+ self.byte_embed = nn.Embedding(260, 64) # Small embedding
167
+ self.initial_compressor = nn.Sequential(
168
+ nn.Linear(48 * 64, 2048),
169
+ nn.LayerNorm(2048),
170
+ nn.ReLU(),
171
+ nn.Dropout(0.1),
172
+ nn.Linear(2048, hidden_dim),
173
+ nn.LayerNorm(hidden_dim)
174
+ )
175
+
176
+ # Language-aware splitting: 1 → N tokens (config-based)
177
+ self.language_splitter = nn.ModuleDict({
178
+ 'analyzer': nn.Sequential(
179
+ nn.Linear(hidden_dim, 512),
180
+ nn.ReLU(),
181
+ nn.Linear(512, 256) # Language features
182
+ ),
183
+ 'split_predictor': nn.Linear(256, self.max_tokens), # Predict 1~4 tokens
184
+ # Single unified expander that can produce any number of tokens
185
+ 'dynamic_expander': nn.Sequential(
186
+ nn.Linear(hidden_dim, hidden_dim * 2),
187
+ nn.LayerNorm(hidden_dim * 2),
188
+ nn.GELU(), # Better than ReLU for transformers
189
+ nn.Linear(hidden_dim * 2, hidden_dim * self.max_tokens) # Can produce up to 4 tokens
190
+ ),
191
+ # Token-wise importance predictor
192
+ 'importance_predictor': nn.Sequential(
193
+ nn.Linear(hidden_dim, 256),
194
+ nn.ReLU(),
195
+ nn.Linear(256, self.max_tokens), # Importance for each potential token
196
+ nn.Softmax(dim=-1)
197
+ )
198
+ })
199
+
200
+ # Boundary refinement: N → M tokens with linguistic awareness
201
+ self.boundary_refiner = nn.ModuleDict({
202
+ 'scorer': nn.Sequential(
203
+ nn.Linear(hidden_dim, 512),
204
+ nn.ReLU(),
205
+ nn.Linear(512, 1)
206
+ ),
207
+ 'morpheme_detector': nn.Conv1d(256, 64, 3), # morpheme-level boundaries
208
+ 'word_detector': nn.Conv1d(256, 64, 5), # word-level boundaries
209
+ 'phrase_detector': nn.Conv1d(256, 64, 7), # phrase-level boundaries
210
+ 'adjuster': nn.TransformerEncoderLayer(
211
+ d_model=hidden_dim,
212
+ nhead=16,
213
+ dim_feedforward=4 * hidden_dim,
214
+ dropout=0.1,
215
+ batch_first=True
216
+ )
217
+ })
218
+
219
+ # Initialize split_predictor bias to prefer 1 token initially
220
+ # This ensures untrained model starts with maximum compression
221
+ with torch.no_grad():
222
+ self.language_splitter['split_predictor'].bias.data = torch.tensor([2.0, -1.0, -1.0, -1.0])
223
+ # High bias for 1 token, negative for others
224
+
225
+ def forward(self, input_ids: torch.Tensor, temperature: float = 1.0) -> Dict[str, torch.Tensor]:
226
+ """
227
+ Progressive splitting forward pass
228
+
229
+ Args:
230
+ input_ids: Input byte sequence [batch, seq_len]
231
+ temperature: Gumbel-Softmax temperature for annealing
232
+ """
233
+ batch_size = input_ids.size(0)
234
+
235
+ # Step 1: 48 bytes → 1 super token
236
+ byte_embeddings = self.byte_embed(input_ids) # [batch, 48, 64]
237
+ flattened = byte_embeddings.view(batch_size, -1) # [batch, 3072]
238
+ super_token = self.initial_compressor(flattened) # [batch, 1280]
239
+ super_token = super_token.unsqueeze(1) # [batch, 1, 1280]
240
+
241
+ # Step 2: Language analysis and splitting (1 → N)
242
+ lang_features = self.language_splitter['analyzer'](super_token)
243
+ split_logits = self.language_splitter['split_predictor'](lang_features)
244
+ split_weights = F.softmax(split_logits, dim=-1) # [batch, 1, 4]
245
+
246
+ # Direct transformation from super token to initial representation
247
+ # No hardcoded splits - let the model learn everything
248
+ lang_tokens = super_token # Start with compressed representation
249
+
250
+ # TRUE Adaptive expansion - Model learns optimal split (1~4 tokens)
251
+ # Analyze content to decide how many tokens needed
252
+ expansion_features = self.language_splitter['analyzer'](lang_tokens) # [batch, 1, 256]
253
+
254
+ # Dynamic expansion: generate up to 4 tokens from super token
255
+ expanded = self.language_splitter['dynamic_expander'](lang_tokens.squeeze(1)) # [batch, hidden_dim*4]
256
+ expanded = expanded.reshape(batch_size, self.max_tokens, self.hidden_dim) # [batch, 4, hidden_dim]
257
+
258
+ # Predict how many tokens we actually need (1~4)
259
+ split_logits = self.language_splitter['split_predictor'](expansion_features.squeeze(1)) # [batch, 4]
260
+ # Clamp logits to prevent extreme values that cause NaN
261
+ split_logits = torch.clamp(split_logits, min=-10, max=10)
262
+ # Ensure minimum temperature to prevent instability
263
+ safe_temperature = max(temperature, 0.5)
264
+ split_weights = F.gumbel_softmax(split_logits, tau=safe_temperature, hard=False, dim=-1) # [batch, 4]
265
+
266
+ # Predict importance for each potential token position
267
+ importance = self.language_splitter['importance_predictor'](lang_tokens.squeeze(1)) # [batch, 4]
268
+
269
+ # Dynamic token selection with importance-weighted allocation
270
+ # Create cumulative mask for progressive token usage
271
+ # If split_weights = [0.1, 0.2, 0.6, 0.1], we mainly use 3 tokens
272
+
273
+ # Create progressive masks for 1, 2, 3, 4 tokens
274
+ masks = []
275
+ for n in range(1, self.max_tokens + 1):
276
+ mask = torch.zeros(batch_size, self.max_tokens, 1, device=expanded.device)
277
+ mask[:, :n, :] = 1.0
278
+ masks.append(mask)
279
+
280
+ # Apply importance-weighted masking
281
+ # Important parts get more tokens, less important parts get fewer
282
+ weighted_outputs = []
283
+ for i, mask in enumerate(masks):
284
+ num_tokens = i + 1
285
+ # Weight by both split decision and importance
286
+ token_weight = split_weights[:, i:i+1].unsqueeze(-1) # [batch, 1, 1]
287
+
288
+ # Apply importance modulation for asymmetric splits
289
+ if num_tokens > 1:
290
+ # Redistribute tokens based on importance
291
+ importance_adjusted = importance[:, :num_tokens].unsqueeze(-1) # [batch, n, 1]
292
+ masked = expanded[:, :num_tokens] * importance_adjusted
293
+ else:
294
+ masked = expanded[:, :num_tokens]
295
+
296
+ # Pad to max length
297
+ if num_tokens < self.max_tokens:
298
+ padding = torch.zeros(batch_size, self.max_tokens - num_tokens, self.hidden_dim,
299
+ device=expanded.device)
300
+ masked = torch.cat([masked, padding], dim=1)
301
+
302
+ weighted_outputs.append(masked * token_weight)
303
+
304
+ # Sum all weighted possibilities (differentiable selection)
305
+ lang_tokens = sum(weighted_outputs)
306
+
307
+ # Determine effective number of tokens (for monitoring)
308
+ # Weighted average of token counts
309
+ token_counts = torch.arange(1, self.max_tokens + 1, device=split_weights.device, dtype=torch.float32)
310
+ avg_tokens = (split_weights * token_counts).sum(dim=-1).mean().item()
311
+
312
+ k = lang_tokens.size(1)
313
+
314
+ # Step 3: Boundary refinement (N → M)
315
+ # Calculate boundary scores for each token position
316
+ boundary_scores = self.boundary_refiner['scorer'](lang_tokens) # [batch, N, 1]
317
+
318
+ # Detect linguistic boundaries (morpheme, word, phrase)
319
+ # Extract features for boundary detection
320
+ if hasattr(lang_tokens, 'shape') and len(lang_tokens.shape) == 3:
321
+ batch_size, num_tokens, hidden_dim = lang_tokens.shape
322
+
323
+ # For boundary detection, we need to consider the original byte sequence
324
+ # But we're working with compressed tokens here
325
+ # So we detect boundaries based on learned representations
326
+
327
+ # Apply boundary adjustment with TransformerEncoderLayer
328
+ # This learns to adjust token boundaries based on context
329
+ refined_tokens = self.boundary_refiner['adjuster'](lang_tokens)
330
+
331
+ # The adjuster should learn to:
332
+ # 1. Respect UTF-8 boundaries (learned during training)
333
+ # 2. Align with word/phrase boundaries (learned from language patterns)
334
+ # 3. Maintain semantic coherence within each token
335
+ else:
336
+ refined_tokens = lang_tokens
337
+
338
+ # Determine actual number of tokens based on highest probability
339
+ # During inference, use argmax. During training, use weighted average.
340
+ if self.training:
341
+ # During training, use weighted average for differentiability
342
+ actual_num_tokens = avg_tokens
343
+ else:
344
+ # During inference, select the split with highest probability
345
+ split_decision = torch.argmax(split_weights, dim=-1) # [batch]
346
+ actual_num_tokens = (split_decision.float().mean() + 1).item() # +1 because indices are 0-3
347
+
348
+ # Calculate compression ratio based on actual tokens used
349
+ compression_ratio = 48.0 / max(1, actual_num_tokens)
350
+
351
+ return {
352
+ 'tokens': refined_tokens,
353
+ 'num_tokens': actual_num_tokens,
354
+ 'compression_ratio': torch.tensor(compression_ratio, device=refined_tokens.device),
355
+ 'gate_values': None, # Will be filled by cross-attention
356
+ 'language_features': lang_features,
357
+ 'split_weights': split_weights,
358
+ 'avg_tokens': avg_tokens,
359
+ 'split_distribution': split_weights.mean(dim=0)
360
+ }
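A standalone sketch of the soft token-count selection above: Gumbel-Softmax over the four split options, a differentiable expected token count, and the implied compression ratio for a 48-byte chunk.

import torch
import torch.nn.functional as F

split_logits = torch.tensor([[2.0, -1.0, -1.0, -1.0]])             # biased toward 1 token, as initialized above
split_weights = F.gumbel_softmax(split_logits, tau=1.0, hard=False, dim=-1)

token_counts = torch.arange(1, 5, dtype=torch.float32)             # the four options: 1..4 tokens
avg_tokens = (split_weights * token_counts).sum(dim=-1).mean()     # soft "how many tokens" estimate
compression_ratio = 48.0 / avg_tokens.clamp(min=1.0)

print(split_weights, avg_tokens.item(), compression_ratio.item())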
361
+
362
+
363
+ class EncoderV62(nn.Module):
364
+ """
365
+ 4-Layer Progressive Splitting Encoder with Cross-Attention
366
+ All layers: 1280 dimensions
367
+ """
368
+
369
+ def __init__(self, config: Optional[Dict] = None):
370
+ super().__init__()
371
+
372
+ # Store config for later use
373
+ self.config = config or {}
374
+
375
+ # Configuration
376
+ self.hidden_dim = 1280
377
+ self.num_heads = 16
378
+ self.num_layers = 4
379
+ self.max_seq_len = 48
380
+ self.dropout = 0.1
381
+
382
+ # RoPE positional encoding (GPT-5 suggestion)
383
+ self.rope = RoPEPositionalEncoding(self.hidden_dim, self.max_seq_len)
384
+
385
+ # Layer 0: Progressive Splitting (48→1→N→M) - Pass config
386
+ self.progressive_splitter = ProgressiveSplittingLayer(self.hidden_dim, config)
387
+
388
+ # Layers 1-3: Transformer encoders with cross-attention
389
+ self.encoder_layers = nn.ModuleList([
390
+ nn.TransformerEncoderLayer(
391
+ d_model=self.hidden_dim,
392
+ nhead=self.num_heads,
393
+ dim_feedforward=4 * self.hidden_dim, # 5120
394
+ dropout=self.dropout,
395
+ batch_first=True
396
+ ) for _ in range(3)
397
+ ])
398
+
399
+ # Cross-attention between layers with MQA (GPT-5 suggestion)
400
+ self.cross_attentions = nn.ModuleList([
401
+ GatedCrossAttention(self.hidden_dim, self.num_heads, kv_heads=2) # 8x memory reduction
402
+ for _ in range(3)
403
+ ])
404
+
405
+ # Output heads for different tasks
406
+ self.boundary_head = nn.Linear(self.hidden_dim, 4)
407
+ self.language_head = nn.Linear(self.hidden_dim, 128) # Reduced from 512 (GPT suggestion)
408
+ self.compression_head = nn.Linear(self.hidden_dim, self.hidden_dim)
409
+
410
+ # Monitoring metrics (GPT-5 suggestion)
411
+ self.register_buffer('compression_ratios', torch.zeros(1))
412
+ self.register_buffer('gate_averages', torch.zeros(3))
413
+
414
+ def forward(self,
415
+ input_ids: torch.Tensor,
416
+ attention_mask: Optional[torch.Tensor] = None,
417
+ temperature: float = 1.0) -> Dict[str, torch.Tensor]:
418
+ """
419
+ Forward pass through the encoder
420
+
421
+ Args:
422
+ input_ids: Input byte sequence
423
+ attention_mask: Optional attention mask
424
+ temperature: Gumbel-Softmax temperature for annealing
425
+ """
426
+ # Layer 0: Progressive splitting with temperature
427
+ split_output = self.progressive_splitter(input_ids, temperature)
428
+ x = split_output['tokens'] # [batch, M, 1280]
429
+
430
+ # Apply RoPE
431
+ x = self.rope(x, x.size(1))
432
+
433
+ # Store all hidden states for decoder
434
+ all_hidden_states = [x]
435
+ gate_values_list = []
436
+
437
+ # Layers 1-3 with cross-attention
438
+ for i, (encoder_layer, cross_attn) in enumerate(
439
+ zip(self.encoder_layers, self.cross_attentions)
440
+ ):
441
+ # Self-attention through transformer layer
442
+ # GPT final check: Don't pass mask after progressive splitting changes sequence length
443
+ x = encoder_layer(x) # No mask needed (no padding after compression)
444
+
445
+ # Cross-attention with previous layer
446
+ if i > 0:
447
+ # Cross-attention with previous layer
448
+ x, gate_values = cross_attn(
449
+ query=x,
450
+ key=all_hidden_states[-1],
451
+ value=all_hidden_states[-1],
452
+ mask=None # Mask not applicable after compression
453
+ )
454
+ gate_values_list.append(gate_values)
455
+ # Keep tensor shape consistent - store in existing buffer element
456
+ self.gate_averages[i-1] = gate_values.mean().detach().item() # Fix indexing
457
+
458
+ all_hidden_states.append(x)
459
+
460
+ # Output projections
461
+ boundaries = self.boundary_head(x)
462
+ language_clusters = self.language_head(x)
463
+ compressed = self.compression_head(x)
464
+
465
+ # Update monitoring metrics
466
+ # Ensure tensor is 1-dimensional for buffer assignment
467
+ compression_ratio = split_output['compression_ratio']
468
+ if compression_ratio.dim() == 0: # Scalar tensor
469
+ self.compression_ratios[0] = compression_ratio
470
+ else:
471
+ self.compression_ratios = compression_ratio
472
+
473
+ return {
474
+ 'last_hidden_state': x,
475
+ 'all_hidden_states': all_hidden_states,
476
+ 'boundaries': boundaries,
477
+ 'language_clusters': language_clusters,
478
+ 'compressed': compressed,
479
+ 'compression_ratio': split_output['compression_ratio'],
480
+ 'num_tokens': split_output['num_tokens'],
481
+ 'splitting_probs': split_output.get('split_weights', None), # Add for diagnostics
482
+ 'gate_values': gate_values_list,
483
+ 'gate_averages': self.gate_averages,
484
+ 'split_info': {
485
+ 'language_features': split_output['language_features'],
486
+ 'split_weights': split_output['split_weights']
487
+ }
488
+ }
489
+
490
+ def get_monitoring_stats(self) -> Dict[str, float]:
491
+ """
492
+ Get monitoring statistics (GPT-5 suggestion)
493
+ """
494
+ return {
495
+ 'avg_compression_ratio': self.compression_ratios.item(),
496
+ 'gate_layer1': self.gate_averages[0].item(),
497
+ 'gate_layer2': self.gate_averages[1].item(),
498
+ 'gate_layer3': self.gate_averages[2].item(),
499
+ }
500
+
501
+ def set_warmup_step(self, step: int, total_warmup: int = 1000):
502
+ """
503
+ Set warmup alpha for all gates (GPT suggestion)
504
+ Gradually increase gate influence from 0 to 1
505
+ """
506
+ alpha = min(1.0, step / total_warmup)
507
+ for cross_attn in self.cross_attentions:
508
+ cross_attn.warmup_alpha = torch.tensor(alpha, device=cross_attn.warmup_alpha.device)
509
+
510
+ def adaptive_compression_control(self, reconstruction_loss: float):
511
+ """
512
+ Adaptive compression based on reconstruction quality
513
+ No fixed phases - model learns optimal compression
514
+ """
515
+ # If reconstruction is poor, model will learn to use more tokens
516
+ # This happens automatically through gradient descent
517
+ # No manual phase control needed
518
+ pass # Let gradients handle it
519
+
520
+
521
+ class DualSlidingWindowEncoder(EncoderV62):
522
+ """
523
+ Extension with dual sliding window system
524
+ Handles both chunk-level and token-level boundaries
525
+ """
526
+
527
+ def __init__(self, config: Optional[Dict] = None):
528
+ super().__init__(config)
529
+
530
+ # Chunk-level sliding window
531
+ self.chunk_window = nn.Conv1d(
532
+ in_channels=1,
533
+ out_channels=1,
534
+ kernel_size=8, # 8-byte overlap
535
+ stride=40, # 48-8=40 stride
536
+ padding=4
537
+ )
538
+
539
+ # Token-level sliding window
540
+ self.token_window = nn.MultiheadAttention(
541
+ embed_dim=self.hidden_dim,
542
+ num_heads=self.num_heads,
543
+ batch_first=True
544
+ )
545
+
546
+ def process_long_sequence(self, input_ids: torch.Tensor) -> Dict[str, torch.Tensor]:
547
+ """
548
+ Handle sequences longer than 48 bytes with sliding windows
549
+ """
550
+ batch_size, seq_len = input_ids.shape
551
+
552
+ if seq_len <= 48:
553
+ return super().forward(input_ids)
554
+
555
+ # Process in chunks with overlap
556
+ chunks = []
557
+ for i in range(0, seq_len - 48 + 1, 40): # 8-byte overlap
558
+ chunk = input_ids[:, i:i+48]
559
+ chunk_output = super().forward(chunk)
560
+ chunks.append(chunk_output['last_hidden_state'])
561
+
562
+ # Combine chunks with attention
563
+ combined = torch.cat(chunks, dim=1)
564
+ attended, _ = self.token_window(combined, combined, combined)
565
+
566
+ return {
567
+ 'last_hidden_state': attended,
568
+ 'num_chunks': len(chunks),
569
+ 'total_compression': seq_len / attended.size(1)
570
+ }
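The chunking loop above uses a 48-byte window with a 40-byte stride (8-byte overlap); a quick way to list the spans that actually get encoded for a given input length:

def chunk_spans(seq_len: int, window: int = 48, stride: int = 40):
    # Start/end byte offsets of each encoder window; a trailing partial window
    # is not covered, matching the loop above.
    return [(i, i + window) for i in range(0, seq_len - window + 1, stride)]

print(chunk_spans(48))    # [(0, 48)]
print(chunk_spans(128))   # [(0, 48), (40, 88), (80, 128)]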
571
+
572
+
573
+ if __name__ == "__main__":
574
+ # Test the encoder
575
+ encoder = EncoderV62()
576
+
577
+ # Test input
578
+ batch_size = 2
579
+ input_ids = torch.randint(0, 256, (batch_size, 48))
580
+
581
+ # Forward pass
582
+ output = encoder(input_ids)
583
+
584
+ print(f"Input shape: {input_ids.shape}")
585
+ print(f"Output tokens: {output['num_tokens']}")
586
+ print(f"Compression ratio: {output['compression_ratio']:.2f}:1")
587
+ print(f"Gate averages: {output['gate_averages']}")
588
+ print(f"Monitoring stats: {encoder.get_monitoring_stats()}")
core/intelligent_loss.py ADDED
@@ -0,0 +1,589 @@
1
+ """
2
+ Intelligent Loss Functions for v6.2.0
3
+ Multi-objective loss with GPT-5 suggested improvements
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Dict, Optional, Tuple
10
+ import math
11
+
12
+
13
+ class IntelligentLoss(nn.Module):
14
+ """
15
+ Comprehensive loss function for progressive splitting tokenizer
16
+ Combines multiple objectives with dynamic weighting
17
+ """
18
+
19
+ def __init__(self, config: Optional[Dict] = None):
20
+ super().__init__()
21
+
22
+ # Default configuration
23
+ self.config = config or {}
24
+
25
+ # Special tokens (must match tokenizer)
26
+ self.PAD = 256
27
+ self.BOS = 257
28
+ self.EOS = 258
29
+ self.MASK = 259
30
+
31
+ # Loss components
32
+ self.reconstruction_loss = ReconstructionLoss(self.PAD)
33
+ self.compression_loss = CompressionLoss()
34
+ self.boundary_loss = BoundaryLoss()
35
+ self.language_loss = LanguageLoss()
36
+ self.consistency_loss = ConsistencyLoss()
37
+
38
+ # Dynamic weight adjustment
39
+ self.use_dynamic_weights = True
40
+ self.weight_history = {
41
+ 'reconstruction': [],
42
+ 'compression': [],
43
+ 'boundary': [],
44
+ 'language': [],
45
+ 'consistency': []
46
+ }
47
+
48
+ def estimate_language_difficulty(self, targets: Dict) -> float:
49
+ """Estimate language difficulty based on input characteristics"""
50
+ if 'input_ids' not in targets:
51
+ return 1.0
52
+
53
+ input_ids = targets['input_ids']
54
+ if input_ids.numel() == 0:
55
+ return 1.0
56
+
57
+ # Higher entropy = more complex language
58
+ unique_tokens = input_ids.unique().numel()
59
+ total_tokens = input_ids.numel()
60
+ diversity = min(1.0, (unique_tokens / total_tokens) * 2)
61
+
62
+ return diversity
63
+
64
+ def forward(self,
65
+ outputs: Dict[str, torch.Tensor],
66
+ targets: Dict[str, torch.Tensor],
67
+ weights: Optional[Dict[str, float]] = None) -> Dict[str, torch.Tensor]:
68
+ """
69
+ Compute combined loss with all objectives
70
+
71
+ Args:
72
+ outputs: Model outputs dictionary
73
+ targets: Target values dictionary
74
+ weights: Optional weight overrides
75
+
76
+ Returns:
77
+ Dictionary with total loss and individual components
78
+ """
79
+ losses = {}
80
+
81
+ # 1. Reconstruction loss (primary objective)
82
+ if 'logits' in outputs and 'input_ids' in targets:
83
+ losses['reconstruction'] = self.reconstruction_loss(
84
+ outputs['logits'],
85
+ targets['input_ids'],
86
+ targets.get('attention_mask')
87
+ )
88
+
89
+ # 2. Compression loss (encourage optimal compression)
90
+ if 'compression_ratio' in outputs:
91
+ losses['compression'] = self.compression_loss(
92
+ outputs['compression_ratio'],
93
+ outputs.get('num_tokens')
94
+ )
95
+
96
+ # 3. Boundary loss (learn meaningful boundaries)
97
+ if 'boundaries' in outputs and 'boundary_targets' in targets:
98
+ losses['boundary'] = self.boundary_loss(
99
+ outputs['boundaries'],
100
+ targets['boundary_targets'],
101
+ targets.get('boundary_mask')
102
+ )
103
+
104
+ # 4. Language loss (language identification/clustering)
105
+ if 'language_clusters' in outputs and 'language_targets' in targets:
106
+ losses['language'] = self.language_loss(
107
+ outputs['language_clusters'],
108
+ targets['language_targets']
109
+ )
110
+
111
+ # 5. Consistency loss (encoder-decoder consistency)
112
+ if 'encoder_hidden' in outputs and 'decoder_hidden' in outputs:
113
+ losses['consistency'] = self.consistency_loss(
114
+ outputs['encoder_hidden'],
115
+ outputs['decoder_hidden']
116
+ )
117
+
118
+ # Apply weights (either provided or dynamic)
119
+ if weights is None and self.use_dynamic_weights:
120
+ weights = self.compute_dynamic_weights(losses)
121
+ elif weights is None:
122
+ weights = {
123
+ 'reconstruction': 1.0,
124
+ 'compression': 1.0,
125
+ 'boundary': 1.0,
126
+ 'language': 0.5,
127
+ 'consistency': 0.5
128
+ }
129
+
130
+ # Weighted sum
131
+ total_loss = torch.tensor(0.0, device=next(iter(losses.values())).device)
132
+ for key, loss in losses.items():
133
+ weight = weights.get(key, 1.0)
134
+ total_loss = total_loss + weight * loss
135
+ losses[f'{key}_weighted'] = weight * loss
136
+
137
+ losses['total'] = total_loss
138
+
139
+ # Update weight history
140
+ for key in self.weight_history:
141
+ if key in losses:
142
+ self.weight_history[key].append(losses[key].item())
143
+
144
+ return losses
145
+
146
+ def compute_dynamic_weights(self, losses: Dict[str, torch.Tensor]) -> Dict[str, float]:
147
+ """
148
+ Dynamically adjust weights based on loss magnitudes and progress
149
+ GPT-5 suggestion: balance loss magnitudes for stable training
150
+ """
151
+ weights = {}
152
+ eps = 1e-8 # GPT fix: prevent division by zero
153
+
154
+ # Get loss magnitudes with NaN protection
155
+ magnitudes = {}
156
+ for k, v in losses.items():
157
+ if torch.isnan(v) or torch.isinf(v):
158
+ magnitudes[k] = 1.0 # Default safe value
159
+ else:
160
+ magnitudes[k] = v.item()
161
+
162
+ # Compute relative scales (GPT fix: add epsilon)
163
+ avg_magnitude = max(eps, sum(magnitudes.values()) / len(magnitudes))
164
+
165
+ for key, magnitude in magnitudes.items():
166
+ # Inverse scaling to balance magnitudes (GPT fix: add epsilon)
167
+ weights[key] = avg_magnitude / max(eps, magnitude)
168
+
169
+ # Dynamic adjustment based on loss ratios
170
+ if 'reconstruction' in magnitudes and 'compression' in magnitudes:
171
+ recon_loss = magnitudes['reconstruction']
172
+ comp_loss = magnitudes['compression']
173
+
174
+ # If reconstruction loss is too high relative to compression
175
+ if recon_loss > comp_loss * 10:
176
+ # Drastically reduce compression pressure
177
+ weights['compression'] *= 0.1
178
+ weights['reconstruction'] *= 5.0
179
+ elif recon_loss > comp_loss * 5:
180
+ # Moderate adjustment
181
+ weights['compression'] *= 0.5
182
+ weights['reconstruction'] *= 2.0
183
+ elif recon_loss < comp_loss * 0.5:
184
+ # Good reconstruction, can push compression
185
+ weights['compression'] *= 2.0
186
+ weights['reconstruction'] *= 0.5
187
+
188
+ # Normalize weights to prevent explosion
189
+ total_weight = sum(weights.values())
190
+ if total_weight > 0:
191
+ weights = {k: min(10.0, v / total_weight * len(weights)) for k, v in weights.items()}
192
+
193
+ return weights
194
+
195
+
196
+ class ReconstructionLoss(nn.Module):
197
+ """
198
+ Cross-entropy loss for sequence reconstruction
199
+ With label smoothing and focal loss options
200
+ """
201
+
202
+ def __init__(self, pad_token: int = 256, label_smoothing: float = 0.1):
203
+ super().__init__()
204
+ self.pad_token = pad_token
205
+ self.label_smoothing = label_smoothing
206
+ self.focal_alpha = 0.25
207
+ self.focal_gamma = 2.0
208
+ self.use_focal = False
209
+
210
+ def forward(self,
211
+ logits: torch.Tensor,
212
+ targets: torch.Tensor,
213
+ mask: Optional[torch.Tensor] = None) -> torch.Tensor:
214
+ """
215
+ Compute reconstruction loss
216
+
217
+ Args:
218
+ logits: [batch, seq_len, vocab_size]
219
+ targets: [batch, seq_len]
220
+ mask: [batch, seq_len] attention mask
221
+ """
222
+ batch_size, seq_len, vocab_size = logits.shape
223
+
224
+ # Reshape for loss computation
225
+ logits_flat = logits.reshape(-1, vocab_size)
226
+ targets_flat = targets.reshape(-1)
227
+
228
+ if self.use_focal:
229
+ # Focal loss for hard examples
230
+ ce_loss = F.cross_entropy(logits_flat, targets_flat, reduction='none')
231
+ pt = torch.exp(-ce_loss)
232
+ focal_loss = self.focal_alpha * (1 - pt) ** self.focal_gamma * ce_loss
233
+
234
+ if mask is not None:
235
+ mask_flat = mask.reshape(-1)
236
+ focal_loss = focal_loss * mask_flat
237
+ loss = focal_loss.sum() / mask_flat.sum()
238
+ else:
239
+ loss = focal_loss.mean()
240
+ else:
241
+ # Standard cross-entropy with label smoothing
242
+ if mask is not None:
243
+ mask_flat = mask.reshape(-1).bool() # GPT fix: ensure bool dtype
244
+ loss = F.cross_entropy(
245
+ logits_flat[mask_flat],
246
+ targets_flat[mask_flat],
247
+ ignore_index=self.pad_token,
248
+ label_smoothing=self.label_smoothing
249
+ )
250
+ else:
251
+ loss = F.cross_entropy(
252
+ logits_flat,
253
+ targets_flat,
254
+ ignore_index=self.pad_token,
255
+ label_smoothing=self.label_smoothing
256
+ )
257
+
258
+ return loss
259
+
260
+
261
+ class CompressionLoss(nn.Module):
262
+ """
263
+ Aggressive compression loss - push for high compression
264
+ Must beat existing tokenizers (4 bytes/token = 4:1)
265
+ """
266
+
267
+ def __init__(self):
268
+ super().__init__()
269
+ # Dynamic compression based on token count
270
+ # 1 token = 48:1, 2 = 24:1, 3 = 16:1, 4 = 12:1
271
+ self.min_ratio = 12.0 # 4 tokens (worst case, still 3x better than BPE)
272
+ self.target_ratio = 24.0 # 2 tokens (optimal balance)
273
+ self.max_ratio = 48.0 # 1 token (best compression)
274
+
275
+ def forward(self,
276
+ compression_ratio: torch.Tensor,
277
+ num_tokens: Optional[torch.Tensor] = None) -> torch.Tensor:
278
+ """
279
+ Compute compression loss (GPT fix: fully vectorized)
280
+
281
+ Args:
282
+ compression_ratio: Current compression ratio (scalar or batch)
283
+ num_tokens: Number of tokens used (for additional penalty)
284
+ """
285
+ # Ensure tensor (GPT fix: handle device properly)
286
+ if not torch.is_tensor(compression_ratio):
287
+ device = num_tokens.device if torch.is_tensor(num_tokens) else torch.device('cpu')
288
+ compression_ratio = torch.tensor(compression_ratio, dtype=torch.float32, device=device)
289
+
290
+ # Aggressive compression enforcement
291
+ # MUST achieve at least 16:1 to be viable
292
+ if compression_ratio < self.min_ratio:
293
+ # Moderate penalty for falling below minimum (reduced for stability)
294
+ under_loss = ((self.min_ratio - compression_ratio) / self.min_ratio) * 0.5
295
+ else:
296
+ under_loss = torch.tensor(0.0, dtype=compression_ratio.dtype, device=compression_ratio.device)
297
+
298
+ # Reward getting close to target (24:1)
299
+ if self.min_ratio <= compression_ratio < self.target_ratio:
300
+ # Encourage reaching target
301
+ target_loss = ((self.target_ratio - compression_ratio) / self.target_ratio) * 0.5
302
+ elif compression_ratio >= self.target_ratio:
303
+ # Excellent compression - small reward for going higher
304
+ target_loss = -0.1 * torch.log(compression_ratio / self.target_ratio + 1.0)
305
+ else:
306
+ target_loss = torch.tensor(0.0, dtype=compression_ratio.dtype, device=compression_ratio.device)
307
+
308
+ # Only mild penalty for extreme compression (>48:1)
309
+ if compression_ratio > self.max_ratio:
310
+ over_loss = ((compression_ratio - self.max_ratio) / self.max_ratio) * 0.2
311
+ else:
312
+ over_loss = torch.tensor(0.0, dtype=compression_ratio.dtype, device=compression_ratio.device)
313
+
314
+ loss = under_loss + target_loss + over_loss
315
+
316
+ # Additional penalty based on token count (GPT fix: vectorized)
317
+ if num_tokens is not None:
318
+ if not torch.is_tensor(num_tokens):
319
+ num_tokens = torch.tensor(num_tokens, dtype=torch.float32, device=compression_ratio.device)
320
+ token_penalty = 0.1 * torch.clamp(num_tokens - 8, min=0.0) ** 2
321
+ loss = loss + token_penalty
322
+
323
+ return loss.mean() if loss.dim() > 0 else loss
324
+
325
+
326
+ class BoundaryLoss(nn.Module):
327
+ """
328
+ Learn meaningful chunk boundaries
329
+ Combines multiple boundary objectives
330
+ """
331
+
332
+ def __init__(self):
333
+ super().__init__()
334
+ self.bce_loss = nn.BCEWithLogitsLoss(reduction='none')
335
+
336
+ def forward(self,
337
+ predicted: torch.Tensor,
338
+ target: torch.Tensor,
339
+ mask: Optional[torch.Tensor] = None) -> torch.Tensor:
340
+ """
341
+ Compute boundary loss
342
+
343
+ Args:
344
+ predicted: [batch, seq_len, boundary_classes] predicted boundaries
345
+ target: [batch, seq_len, boundary_classes] target boundaries
346
+ mask: [batch, seq_len] valid positions mask
347
+ """
348
+ # Binary cross-entropy for boundary prediction
349
+ loss = self.bce_loss(predicted, target.float())
350
+
351
+ if mask is not None:
352
+ # Apply mask
353
+ mask_expanded = mask.unsqueeze(-1).expand_as(loss)
354
+ loss = loss * mask_expanded
355
+ loss = loss.sum() / mask_expanded.sum()
356
+ else:
357
+ loss = loss.mean()
358
+
359
+ # Add regularization for boundary sparsity
360
+ # (boundaries should be relatively rare)
361
+ boundary_probs = torch.sigmoid(predicted)
362
+ sparsity_loss = 0.01 * boundary_probs.mean()
363
+
364
+ # Add smoothness regularization
365
+ # (boundaries should be somewhat smooth/continuous)
366
+ if predicted.size(1) > 1:
367
+ diff = predicted[:, 1:] - predicted[:, :-1]
368
+ smoothness_loss = 0.01 * (diff ** 2).mean()
369
+ else:
370
+ smoothness_loss = 0.0
371
+
372
+ total_loss = loss + sparsity_loss + smoothness_loss
373
+
374
+ return total_loss
375
+
376
+
377
+ class LanguageLoss(nn.Module):
378
+ """
379
+ Language identification/clustering loss
380
+ Supports both classification and clustering objectives
381
+ """
382
+
383
+ def __init__(self, num_languages: int = 128, temperature: float = 0.07):
384
+ super().__init__()
385
+ self.num_languages = num_languages
386
+ self.temperature = temperature
387
+
388
+ # For supervised language classification
389
+ self.ce_loss = nn.CrossEntropyLoss()
390
+
391
+ def forward(self,
392
+ predicted: torch.Tensor,
393
+ target: torch.Tensor,
394
+ mode: str = 'classification') -> torch.Tensor:
395
+ """
396
+ Compute language loss
397
+
398
+ Args:
399
+ predicted: [batch, seq_len, num_languages] or [batch, num_languages]
400
+ target: Language labels or cluster assignments
401
+ mode: 'classification' or 'clustering'
402
+ """
403
+ if mode == 'classification':
404
+ # Standard classification loss
405
+ if predicted.dim() == 3:
406
+ # Sequence-level predictions
407
+ batch_size, seq_len, _ = predicted.shape
408
+ predicted = predicted.reshape(-1, self.num_languages)
409
+ target = target.reshape(-1)
410
+
411
+ loss = self.ce_loss(predicted, target)
412
+
413
+ elif mode == 'clustering':
414
+ # Contrastive clustering loss (similar to SimCLR)
415
+ # Normalize embeddings
416
+ predicted = F.normalize(predicted, dim=-1)
417
+
418
+ # Compute similarity matrix
419
+ sim_matrix = torch.matmul(predicted, predicted.t()) / self.temperature
420
+
421
+ # Create labels (assuming batch contains similar samples)
422
+ batch_size = predicted.size(0)
423
+ labels = torch.arange(batch_size, device=predicted.device)
424
+
425
+ # Contrastive loss
426
+ loss = F.cross_entropy(sim_matrix, labels)
427
+
428
+ else:
429
+ raise ValueError(f"Unknown mode: {mode}")
430
+
431
+ return loss
432
+
433
+
434
+ class ConsistencyLoss(nn.Module):
435
+ """
436
+ Ensure consistency between encoder and decoder representations
437
+ GPT-5 suggestion: helps with training stability
438
+ """
439
+
440
+ def __init__(self, margin: float = 0.5):
441
+ super().__init__()
442
+ self.margin = margin
443
+
444
+ def forward(self,
445
+ encoder_hidden: torch.Tensor,
446
+ decoder_hidden: torch.Tensor) -> torch.Tensor:
447
+ """
448
+ Compute consistency loss between encoder and decoder
449
+
450
+ Args:
451
+ encoder_hidden: [batch, seq_len, hidden_dim]
452
+ decoder_hidden: [batch, seq_len, hidden_dim]
453
+ """
454
+ # Ensure same shape
455
+ if encoder_hidden.shape != decoder_hidden.shape:
456
+ # Align sequence lengths if different
457
+ min_len = min(encoder_hidden.size(1), decoder_hidden.size(1))
458
+ encoder_hidden = encoder_hidden[:, :min_len]
459
+ decoder_hidden = decoder_hidden[:, :min_len]
460
+
461
+ # L2 distance
462
+ l2_loss = F.mse_loss(encoder_hidden, decoder_hidden)
463
+
464
+ # Cosine similarity loss
465
+ encoder_norm = F.normalize(encoder_hidden, dim=-1)
466
+ decoder_norm = F.normalize(decoder_hidden, dim=-1)
467
+ cosine_sim = (encoder_norm * decoder_norm).sum(dim=-1)
468
+ cosine_loss = 1.0 - cosine_sim.mean()
469
+
470
+ # Combined loss
471
+ loss = l2_loss + 0.5 * cosine_loss
472
+
473
+ return loss
474
+
475
+
476
+ class AdaptiveLossScheduler:
477
+ """
478
+ Dynamically adjust loss weights during training
479
+ Based on training progress and performance
480
+ """
481
+
482
+ def __init__(self, config: Dict):
483
+ self.config = config
484
+ self.current_phase = 0
485
+ self.phase_epochs = [30, 60, 100] # Phase transition points
486
+
487
+ # Define phase-specific weights
488
+ self.phase_weights = [
489
+ # Phase 1: Boundary mastery
490
+ {
491
+ 'reconstruction': 2.0,
492
+ 'compression': 0.5,
493
+ 'boundary': 3.0,
494
+ 'language': 0.5,
495
+ 'consistency': 0.5
496
+ },
497
+ # Phase 2: Compression focus
498
+ {
499
+ 'reconstruction': 2.0,
500
+ 'compression': 3.0,
501
+ 'boundary': 1.0,
502
+ 'language': 1.0,
503
+ 'consistency': 1.0
504
+ },
505
+ # Phase 3: Balanced optimization
506
+ {
507
+ 'reconstruction': 3.0,
508
+ 'compression': 2.0,
509
+ 'boundary': 1.0,
510
+ 'language': 1.0,
511
+ 'consistency': 1.5
512
+ }
513
+ ]
514
+
515
+ def get_weights(self, epoch: int, metrics: Optional[Dict] = None) -> Dict[str, float]:
516
+ """
517
+ Get current loss weights based on training phase
518
+
519
+ Args:
520
+ epoch: Current training epoch
521
+ metrics: Optional performance metrics for adaptive adjustment
522
+ """
523
+ # Determine current phase
524
+ for i, phase_end in enumerate(self.phase_epochs):
525
+ if epoch <= phase_end:
526
+ self.current_phase = i
527
+ break
528
+
529
+ weights = self.phase_weights[self.current_phase].copy()
530
+
531
+ # Adaptive adjustments based on metrics
532
+ if metrics:
533
+ # If reconstruction is poor, increase its weight
534
+ if metrics.get('reconstruction_accuracy', 1.0) < 0.9:
535
+ weights['reconstruction'] *= 1.5
536
+
537
+ # If compression is off target, adjust weight
538
+ compression_ratio = metrics.get('compression_ratio', 16.0)
539
+ if compression_ratio < 8.0 or compression_ratio > 20.0:
540
+ weights['compression'] *= 1.5
541
+
542
+ return weights
543
+
544
+
545
+ if __name__ == "__main__":
546
+ # Test losses
547
+ print("Testing Intelligent Loss Functions")
548
+
549
+ # Create loss module
550
+ loss_fn = IntelligentLoss()
551
+
552
+ # Create dummy data
553
+ batch_size = 2
554
+ seq_len = 48
555
+ vocab_size = 260
556
+ hidden_dim = 1280
557
+
558
+ outputs = {
559
+ 'logits': torch.randn(batch_size, seq_len, vocab_size),
560
+ 'compression_ratio': torch.tensor(16.0),
561
+ 'num_tokens': torch.tensor(3),
562
+ 'boundaries': torch.randn(batch_size, seq_len, 4),
563
+ 'language_clusters': torch.randn(batch_size, 128),
564
+ 'encoder_hidden': torch.randn(batch_size, seq_len, hidden_dim),
565
+ 'decoder_hidden': torch.randn(batch_size, seq_len, hidden_dim)
566
+ }
567
+
568
+ targets = {
569
+ 'input_ids': torch.randint(0, 256, (batch_size, seq_len)),
570
+ 'attention_mask': torch.ones(batch_size, seq_len),
571
+ 'boundary_targets': torch.zeros(batch_size, seq_len, 4),
572
+ 'language_targets': torch.randint(0, 128, (batch_size,))
573
+ }
574
+
575
+ # Compute losses
576
+ losses = loss_fn(outputs, targets)
577
+
578
+ print("\nLoss components:")
579
+ for key, value in losses.items():
580
+ if isinstance(value, torch.Tensor):
581
+ print(f" {key}: {value.item():.4f}")
582
+
583
+ # Test adaptive scheduler
584
+ scheduler = AdaptiveLossScheduler({})
585
+
586
+ print("\nPhase weights:")
587
+ for epoch in [10, 40, 70]:
588
+ weights = scheduler.get_weights(epoch)
589
+ print(f" Epoch {epoch}: {weights}")
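Continuing the test above, the metric-driven adjustment in get_weights can be checked directly; the values follow from the phase-2 base weights (2.0 and 3.0) and the 1.5x boosts applied when reconstruction accuracy is below 0.9 and the compression ratio is off target:

weights = scheduler.get_weights(40, metrics={'reconstruction_accuracy': 0.85,
                                             'compression_ratio': 6.0})
print(weights['reconstruction'], weights['compression'])  # 3.0 and 4.5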
core/scheduler.py ADDED
@@ -0,0 +1,669 @@
1
+ """
2
+ Learning Rate Schedulers for v6.2.0
3
+ Advanced scheduling with warmup and phase-based adjustments
4
+ """
5
+
6
+ import torch
7
+ import math
8
+ from typing import Optional, Dict, List, Any
9
+ import numpy as np
10
+
11
+
12
+ class WarmupCosineScheduler:
13
+ """
14
+ Cosine annealing with linear warmup
15
+ GPT-5 suggested: Essential for stable progressive splitting training
16
+ """
17
+
18
+ def __init__(self,
19
+ optimizer: torch.optim.Optimizer,
20
+ warmup_steps: int,
21
+ total_steps: int,
22
+ min_lr: float = 1e-6,
23
+ max_lr: Optional[float] = None):
24
+ self.optimizer = optimizer
25
+ self.warmup_steps = warmup_steps
26
+ self.total_steps = total_steps
27
+ self.min_lr = min_lr
28
+ self.max_lr = max_lr or optimizer.param_groups[0]['lr']
29
+ self.current_step = 0
30
+
31
+ def step(self):
32
+ """Update learning rate"""
33
+ self.current_step += 1
34
+
35
+ if self.current_step <= self.warmup_steps:
36
+ # Linear warmup
37
+ lr = self.max_lr * (self.current_step / self.warmup_steps)
38
+ else:
39
+ # Cosine annealing (GPT fix: guard against division by zero)
40
+ if self.total_steps <= self.warmup_steps:
41
+ lr = self.min_lr
42
+ else:
43
+ progress = (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
44
+ progress = min(1.0, max(0.0, progress)) # Clamp to [0, 1]
45
+ lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1 + math.cos(math.pi * progress))
46
+
47
+ for param_group in self.optimizer.param_groups:
48
+ param_group['lr'] = lr
49
+
50
+ return lr
51
+
52
+ def get_lr(self):
53
+ """Get current learning rate"""
54
+ return self.optimizer.param_groups[0]['lr']
55
+
56
+
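For reference, with warmup_steps=100, total_steps=1000, max_lr=1e-3 and min_lr=1e-6 (the values exercised in the test at the end of this file), the schedule evaluates roughly as below. This is a standalone restatement of the same formula, not part of the module:

import math

min_lr, max_lr, warmup, total = 1e-6, 1e-3, 100, 1000
def lr_at(step):
    if step <= warmup:
        return max_lr * step / warmup
    progress = min(1.0, (step - warmup) / (total - warmup))
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))

print(lr_at(50), lr_at(100), lr_at(550), lr_at(1000))  # ~5e-4, 1e-3, ~5e-4, 1e-6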
57
+ class PhaseBasedScheduler:
58
+ """
59
+ Curriculum learning scheduler with phase transitions
60
+ Adjusts learning rate based on training phases
61
+ """
62
+
63
+ def __init__(self,
64
+ optimizer: torch.optim.Optimizer,
65
+ phase_configs: List[Dict],
66
+ current_epoch: int = 0):
67
+ """
68
+ Args:
69
+ optimizer: PyTorch optimizer
70
+ phase_configs: List of phase configurations
71
+ [{
72
+ 'epochs': (start, end),
73
+ 'lr': learning_rate,
74
+ 'warmup_epochs': warmup_duration
75
+ }, ...]
76
+ """
77
+ self.optimizer = optimizer
78
+ self.phase_configs = phase_configs
79
+ self.current_epoch = current_epoch
80
+ self.current_phase = 0
81
+ self.base_lr = optimizer.param_groups[0]['lr']
82
+
83
+ def step(self, epoch: Optional[int] = None):
84
+ """Update learning rate based on current phase"""
85
+ if epoch is not None:
86
+ self.current_epoch = epoch
87
+
88
+ # Find current phase
89
+ for i, phase in enumerate(self.phase_configs):
90
+ start_epoch, end_epoch = phase['epochs']
91
+ if start_epoch <= self.current_epoch <= end_epoch:
92
+ self.current_phase = i
93
+ break
94
+
95
+ phase = self.phase_configs[self.current_phase]
96
+ target_lr = phase['lr']
97
+ warmup_epochs = phase.get('warmup_epochs', 0)
98
+ start_epoch = phase['epochs'][0]
99
+
100
+ # Apply warmup if in warmup period
101
+ if self.current_epoch - start_epoch < warmup_epochs:
102
+ warmup_progress = (self.current_epoch - start_epoch + 1) / warmup_epochs
103
+ lr = target_lr * warmup_progress
104
+ else:
105
+ lr = target_lr
106
+
107
+ # Update optimizer
108
+ for param_group in self.optimizer.param_groups:
109
+ param_group['lr'] = lr
110
+
111
+ return lr
112
+
113
+
114
+ class AdaptiveScheduler:
115
+ """
116
+ Adaptive learning rate based on validation metrics
117
+ Reduces LR when metrics plateau
118
+ """
119
+
120
+ def __init__(self,
121
+ optimizer: torch.optim.Optimizer,
122
+ mode: str = 'min',
123
+ factor: float = 0.5,
124
+ patience: int = 10,
125
+ threshold: float = 1e-4,
126
+ min_lr: float = 1e-7):
127
+ """
128
+ Args:
129
+ optimizer: PyTorch optimizer
130
+ mode: 'min' or 'max' - whether to reduce LR when metric stops decreasing or increasing
131
+ factor: Factor to reduce LR by
132
+ patience: Number of epochs with no improvement to wait
133
+ threshold: Minimum change to qualify as improvement
134
+ min_lr: Minimum learning rate
135
+ """
136
+ self.optimizer = optimizer
137
+ self.mode = mode
138
+ self.factor = factor
139
+ self.patience = patience
140
+ self.threshold = threshold
141
+ self.min_lr = min_lr
142
+
143
+ self.best_score = None
144
+ self.num_bad_epochs = 0
145
+ self.last_reduction = 0
146
+
147
+ def step(self, metric: float, epoch: int = 0):
148
+ """Update learning rate based on metric"""
149
+ current_lr = self.optimizer.param_groups[0]['lr']
150
+
151
+ if self.best_score is None:
152
+ self.best_score = metric
153
+ else:
154
+ if self.mode == 'min':
155
+ improved = metric < self.best_score - self.threshold
156
+ else:
157
+ improved = metric > self.best_score + self.threshold
158
+
159
+ if improved:
160
+ self.best_score = metric
161
+ self.num_bad_epochs = 0
162
+ else:
163
+ self.num_bad_epochs += 1
164
+
165
+ # Reduce LR if patience exceeded
166
+ if self.num_bad_epochs >= self.patience:
167
+ new_lr = max(current_lr * self.factor, self.min_lr)
168
+
169
+ if new_lr < current_lr:
170
+ print(f"Reducing learning rate from {current_lr:.2e} to {new_lr:.2e}")
171
+
172
+ for param_group in self.optimizer.param_groups:
173
+ param_group['lr'] = new_lr
174
+
175
+ self.num_bad_epochs = 0
176
+ self.last_reduction = epoch
177
+
178
+ return current_lr
179
+
180
+
181
+ class ProgressiveSplittingScheduler:
182
+ """
183
+ Adaptive scheduler for progressive splitting
184
+ No fixed targets - adjusts based on quality feedback
185
+ """
186
+
187
+ def __init__(self,
188
+ optimizer: torch.optim.Optimizer,
189
+ initial_lr: float = 1e-4,
190
+ min_reconstruction: float = 0.85,
191
+ ema: float = 0.98,
192
+ min_lr: float = 1e-7):
193
+ self.optimizer = optimizer
194
+ self.initial_lr = initial_lr
195
+ self.min_reconstruction = min_reconstruction # Quality threshold
196
+ self.ema = ema
197
+ self.min_lr = min_lr
198
+
199
+ # Adaptive multipliers based on performance
200
+ self.quality_multiplier = 1.0 # Adjusts with reconstruction quality
201
+
202
+ # No phases - continuous adaptation
203
+ self.current_state = 'learning'
204
+
205
+ # EMA tracking for smooth transitions
206
+ self._ema_comp = None
207
+ self._ema_recon = None
208
+
209
+ def step(self, metrics: Dict[str, float]):
210
+ """
211
+ Update learning rate based on current metrics
212
+ GPT fix: EMA smoothing and minimum floor
213
+
214
+ Args:
215
+ metrics: Dictionary containing:
216
+ - compression_ratio: Current compression ratio
217
+ - reconstruction_acc: Reconstruction accuracy
218
+ """
219
+ compression_ratio = float(metrics.get('compression_ratio', 0.0))
220
+ reconstruction_acc = float(metrics.get('reconstruction_acc', 0.0))
221
+
222
+ # Update EMA (GPT fix: smooth transitions)
223
+ if self._ema_comp is None:
224
+ self._ema_comp = compression_ratio
225
+ self._ema_recon = reconstruction_acc
226
+ else:
227
+ self._ema_comp = self.ema * self._ema_comp + (1 - self.ema) * compression_ratio
228
+ self._ema_recon = self.ema * self._ema_recon + (1 - self.ema) * reconstruction_acc
229
+
230
+ # Adaptive adjustment based on reconstruction quality only
231
+ # No fixed compression target - emerges from quality
232
+ if self._ema_recon < self.min_reconstruction:
233
+ # Poor reconstruction - reduce LR for careful learning
234
+ self.quality_multiplier = max(0.5, self._ema_recon)
235
+ else:
236
+ # Good reconstruction - normal learning
237
+ self.quality_multiplier = 1.0
238
+
239
+ # Smooth LR changes
240
+ reconstruction_factor = max(0.1, self._ema_recon)
241
+
242
+ # Combined learning rate (adaptive, no phase multiplier)
243
+ lr = self.initial_lr * self.quality_multiplier * reconstruction_factor
244
+ lr = max(lr, self.min_lr) # Ensure minimum LR
245
+
246
+ # Update optimizer
247
+ for param_group in self.optimizer.param_groups:
248
+ param_group['lr'] = lr
249
+
250
+ return lr
251
+
252
+
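The EMA update above reacts slowly by design: with ema=0.98, a sudden change in reconstruction accuracy moves the smoothed value by only 2% per step. A quick standalone illustration of the same update rule:

ema = 0.98
smoothed = 0.90            # previous _ema_recon
for _ in range(10):        # ten steps after accuracy suddenly drops to 0.70
    smoothed = ema * smoothed + (1 - ema) * 0.70
print(round(smoothed, 4))  # ~0.8634, still above a 0.85 quality threshold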
253
+ class GumbelTemperatureScheduler:
254
+ """
255
+ Temperature annealing for Gumbel-Softmax
256
+ GPT-5 suggestion: Critical for progressive splitting
257
+ """
258
+
259
+ def __init__(self,
260
+ initial_temp: float = 1.0,
261
+ final_temp: float = 0.1,
262
+ anneal_rate: float = 0.99995,
263
+ anneal_steps: Optional[int] = None):
264
+ self.initial_temp = initial_temp
265
+ self.final_temp = final_temp
266
+ self.anneal_rate = anneal_rate
267
+ self.anneal_steps = anneal_steps
268
+ self.current_step = 0
269
+ self.current_temp = initial_temp
270
+
271
+ def step(self):
272
+ """Update temperature"""
273
+ self.current_step += 1
274
+
275
+ if self.anneal_steps:
276
+ # Linear annealing
277
+ progress = min(1.0, self.current_step / self.anneal_steps)
278
+ self.current_temp = self.initial_temp + (self.final_temp - self.initial_temp) * progress
279
+ else:
280
+ # Exponential annealing
281
+ self.current_temp = max(
282
+ self.final_temp,
283
+ self.initial_temp * (self.anneal_rate ** self.current_step)
284
+ )
285
+
286
+ return self.current_temp
287
+
288
+ def get_temperature(self):
289
+ """Get current temperature"""
290
+ return self.current_temp
291
+
292
+
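With the default anneal_rate of 0.99995 and no anneal_steps, the exponential schedule above decays slowly; after 10,000 steps the temperature is still far above final_temp, and the 0.1 floor is only reached around 46,000 steps:

print(max(0.1, 1.0 * 0.99995 ** 10_000))  # ~0.6065
print(max(0.1, 1.0 * 0.99995 ** 50_000))  # 0.1 (raw value ~0.082, clipped by the floor)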
293
+ class CompressionRatioScheduler:
294
+ """
295
+ Schedule target compression ratio during training
296
+ Gradually increase compression requirements
297
+ """
298
+
299
+ def __init__(self,
300
+ initial_ratio: float = 8.0,
301
+ target_ratio: float = 24.0,
302
+ warmup_epochs: int = 10,
303
+ total_epochs: int = 100):
304
+ self.initial_ratio = initial_ratio
305
+ self.target_ratio = target_ratio
306
+ self.warmup_epochs = warmup_epochs
307
+ self.total_epochs = total_epochs
308
+ self.current_epoch = 0
309
+
310
+ def step(self, epoch: Optional[int] = None):
311
+ """Update target compression ratio"""
312
+ if epoch is not None:
313
+ self.current_epoch = epoch
314
+ else:
315
+ self.current_epoch += 1
316
+
317
+ if self.current_epoch < self.warmup_epochs:
318
+ # Start with lower compression requirement
319
+ ratio = self.initial_ratio
320
+ else:
321
+ # Gradually increase to target
322
+ progress = (self.current_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs)
323
+ progress = min(1.0, progress)
324
+ ratio = self.initial_ratio + (self.target_ratio - self.initial_ratio) * progress
325
+
326
+ return ratio
327
+
328
+
329
+ class MultiScheduler:
330
+ """
331
+ Combine multiple schedulers for comprehensive training control
332
+ """
333
+
334
+ def __init__(self, schedulers: Dict):
335
+ """
336
+ Args:
337
+ schedulers: Dictionary of schedulers
338
+ {
339
+ 'lr': learning_rate_scheduler,
340
+ 'gumbel': gumbel_temperature_scheduler,
341
+ 'compression': compression_ratio_scheduler,
342
+ ...
343
+ }
344
+ """
345
+ self.schedulers = schedulers
346
+
347
+ def step(self, **kwargs):
348
+ """
349
+ Step all schedulers
350
+ GPT fix: unified input convention
351
+
352
+ Returns:
353
+ Dictionary with all scheduler outputs
354
+ """
355
+ results = {}
356
+
357
+ for name, scheduler in self.schedulers.items():
358
+ try:
359
+ # Check scheduler type and pass appropriate arguments
360
+ if hasattr(scheduler, '__class__'):
361
+ class_name = scheduler.__class__.__name__
362
+
363
+ if class_name == 'AdaptiveScheduler' and 'metric' in kwargs:
364
+ results[name] = scheduler.step(kwargs['metric'], kwargs.get('epoch', 0))
365
+ elif class_name == 'PhaseBasedScheduler' and 'epoch' in kwargs:
366
+ results[name] = scheduler.step(kwargs['epoch'])
367
+ elif class_name == 'CompressionRatioScheduler' and 'epoch' in kwargs:
368
+ results[name] = scheduler.step(kwargs['epoch'])
369
+ elif class_name == 'ProgressiveSplittingScheduler' and 'metrics' in kwargs:
370
+ results[name] = scheduler.step(kwargs['metrics'])
371
+ elif hasattr(scheduler, 'step'):
372
+ # Generic step (no arguments)
373
+ results[name] = scheduler.step()
374
+ else:
375
+ if hasattr(scheduler, 'step'):
376
+ results[name] = scheduler.step()
377
+ except Exception as e:
378
+ print(f"Warning: Scheduler '{name}' step failed: {e}")
379
+ results[name] = None
380
+
381
+ return results
382
+
383
+ def get_current_values(self):
384
+ """Get current values from all schedulers"""
385
+ values = {}
386
+
387
+ for name, scheduler in self.schedulers.items():
388
+ if hasattr(scheduler, 'get_lr'):
389
+ values[name] = scheduler.get_lr()
390
+ elif hasattr(scheduler, 'get_temperature'):
391
+ values[name] = scheduler.get_temperature()
392
+ elif hasattr(scheduler, 'current_temp'):
393
+ values[name] = scheduler.current_temp
394
+ elif hasattr(scheduler, 'current_epoch'):
395
+ values[name] = scheduler.current_epoch
396
+
397
+ return values
398
+
399
+
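A minimal usage sketch for the dispatcher above; it assumes an optimizer, the current epoch and a validation metric already exist in the training loop, and the dictionary keys are arbitrary names:

multi = MultiScheduler({
    'lr': WarmupCosineScheduler(optimizer, warmup_steps=100, total_steps=1000),
    'gumbel': GumbelTemperatureScheduler(),
    'compression': CompressionRatioScheduler(),
})
out = multi.step(epoch=epoch, metric=val_loss,
                 metrics={'compression_ratio': 12.0, 'reconstruction_acc': 0.9})
print(out['lr'], out['gumbel'], out['compression'])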
400
+ class GateWarmupScheduler:
401
+ """Gate parameter warmup scheduler
402
+
403
+ Early: all layers used equally (gate=1.0)
404
+ Warmup: gradual gate learning begins
405
+ Late: converges to the optimal gate values
406
+ """
407
+
408
+ def __init__(
409
+ self,
410
+ optimizer: torch.optim.Optimizer,
411
+ warmup_steps: int = 1000,
412
+ gate_param_group_name: str = 'gates',
413
+ importance_param_group_name: str = 'importance'
414
+ ):
415
+ """
416
+ Args:
417
+ optimizer: the optimizer whose parameter groups are scheduled
418
+ warmup_steps: number of warmup steps
419
+ gate_param_group_name: name of the gate parameter group
420
+ importance_param_group_name: name of the importance parameter group
421
+ """
422
+ self.optimizer = optimizer
423
+ self.warmup_steps = warmup_steps
424
+ self.gate_group_name = gate_param_group_name
425
+ self.importance_group_name = importance_param_group_name
426
+
427
+ # Save initial learning rates
428
+ self.base_lrs = {}
429
+ for group in optimizer.param_groups:
430
+ if 'name' in group:
431
+ self.base_lrs[group['name']] = group['lr']
432
+
433
+ def get_gate_factor(self, step: int) -> float:
434
+ """Compute the gate learning-rate factor
435
+
436
+ Uses a low learning rate during the warmup period,
437
+ then switches to the normal learning rate
438
+ """
439
+ if step < self.warmup_steps:
440
+ # Warmup period: linear increase
441
+ return step / self.warmup_steps
442
+ else:
443
+ # Normal learning rate
444
+ return 1.0
445
+
446
+ def get_importance_factor(self, step: int) -> float:
447
+ """Compute the importance learning-rate factor
448
+
449
+ Starts learning more slowly than the gates
450
+ """
451
+ delayed_warmup = self.warmup_steps * 1.5
452
+ if step < delayed_warmup:
453
+ return step / delayed_warmup * 0.5
454
+ else:
455
+ return 1.0
456
+
457
+ def step(self, current_step: int):
458
+ """Scheduler step
459
+
460
+ Args:
461
+ current_step: current global step
462
+ """
463
+ # Adjust learning rates of the gate parameter groups
464
+ gate_factor = self.get_gate_factor(current_step)
465
+ importance_factor = self.get_importance_factor(current_step)
466
+
467
+ for group in self.optimizer.param_groups:
468
+ if 'name' not in group:
469
+ continue
470
+
471
+ if group['name'] == self.gate_group_name:
472
+ # Adjust the gate learning rate
473
+ group['lr'] = self.base_lrs[self.gate_group_name] * gate_factor
474
+
475
+ elif group['name'] == self.importance_group_name:
476
+ # Adjust the importance learning rate
477
+ group['lr'] = self.base_lrs[self.importance_group_name] * importance_factor
478
+
479
+ def get_lr(self) -> Dict[str, float]:
480
+ """Return the current learning rates"""
481
+ lrs = {}
482
+ for group in self.optimizer.param_groups:
483
+ if 'name' in group:
484
+ lrs[group['name']] = group['lr']
485
+ return lrs
486
+
487
+
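GateWarmupScheduler assumes the optimizer was built with named parameter groups matching the defaults above; a sketch of that setup (the parameter selection below is illustrative, not taken from the model code):

param_groups = [
    {'name': 'gates',      'params': [model.layer_gates],      'lr': 1e-3},
    {'name': 'importance', 'params': [model.layer_importance], 'lr': 1e-3},
    {'name': 'base',       'params': other_params,             'lr': 1e-4},
]
optimizer = torch.optim.AdamW(param_groups)
gate_sched = GateWarmupScheduler(optimizer, warmup_steps=1000)
gate_sched.step(current_step=500)  # gates at 50% of their base LR, importance at ~16.7%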
488
+ class UniversalCosineScheduler:
489
+ """Universal cosine annealing scheduler
490
+
491
+ Applies the same schedule to all languages
492
+ """
493
+
494
+ def __init__(
495
+ self,
496
+ optimizer: torch.optim.Optimizer,
497
+ warmup_steps: int = 1000,
498
+ total_steps: int = 10000,
499
+ min_lr_ratio: float = 0.1
500
+ ):
501
+ self.optimizer = optimizer
502
+ self.warmup_steps = warmup_steps
503
+ self.total_steps = total_steps
504
+ self.min_lr_ratio = min_lr_ratio
505
+ self.current_step = 0
506
+
507
+ # Save initial learning rates
508
+ self.base_lrs = [group['lr'] for group in optimizer.param_groups]
509
+
510
+ def step(self):
511
+ """Scheduler step"""
512
+ self.current_step += 1
513
+
514
+ for idx, param_group in enumerate(self.optimizer.param_groups):
515
+ if self.current_step < self.warmup_steps:
516
+ # Warmup phase
517
+ lr = self.base_lrs[idx] * (self.current_step / self.warmup_steps)
518
+ else:
519
+ # Cosine annealing
520
+ if self.total_steps <= self.warmup_steps:
521
+ # Case where warmup_steps is greater than or equal to total_steps
522
+ lr = self.base_lrs[idx] * self.min_lr_ratio
523
+ else:
524
+ progress = min(1.0, (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps))
525
+ lr = self.base_lrs[idx] * (
526
+ self.min_lr_ratio + (1 - self.min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress))
527
+ )
528
+
529
+ param_group['lr'] = lr
530
+
531
+ def get_last_lr(self) -> List[float]:
532
+ """Return the most recent learning rates"""
533
+ return [group['lr'] for group in self.optimizer.param_groups]
534
+
535
+ def state_dict(self) -> Dict[str, Any]:
536
+ """Return the scheduler state dict (for saving checkpoints)"""
537
+ return {
538
+ 'current_step': self.current_step,
539
+ 'warmup_steps': self.warmup_steps,
540
+ 'total_steps': self.total_steps,
541
+ 'min_lr_ratio': self.min_lr_ratio,
542
+ 'base_lrs': self.base_lrs
543
+ }
544
+
545
+ def load_state_dict(self, state_dict: Dict[str, Any]):
546
+ """Load scheduler state (for resuming from a checkpoint)"""
547
+ self.current_step = state_dict['current_step']
548
+ self.warmup_steps = state_dict['warmup_steps']
549
+ self.total_steps = state_dict['total_steps']
550
+ self.min_lr_ratio = state_dict['min_lr_ratio']
551
+ self.base_lrs = state_dict['base_lrs']
552
+
553
+
554
+ class AdaptiveLayerScheduler:
555
+ """Per-layer adaptive scheduler
556
+
557
+ Dynamically adjusts according to each layer's learning progress
558
+ """
559
+
560
+ def __init__(
561
+ self,
562
+ layer_builder,
563
+ threshold_active: float = 0.7,
564
+ threshold_skip: float = 0.3
565
+ ):
566
+ """
567
+ Args:
568
+ layer_builder: a LayerBuilder instance
569
+ threshold_active: threshold above which a layer counts as active
570
+ threshold_skip: threshold below which a layer is skipped
571
+ """
572
+ self.layer_builder = layer_builder
573
+ self.threshold_active = threshold_active
574
+ self.threshold_skip = threshold_skip
575
+
576
+ # Per-layer statistics
577
+ self.layer_stats = {
578
+ 'usage_count': torch.zeros(5),
579
+ 'contribution': torch.zeros(5)
580
+ }
581
+
582
+ def update_stats(self, batch_output):
583
+ """Update statistics from a batch output"""
584
+ with torch.no_grad():
585
+ gates = torch.sigmoid(self.layer_builder.layer_gates)
586
+
587
+ # Update usage counts
588
+ self.layer_stats['usage_count'] += (gates > self.threshold_skip).float()
589
+
590
+ # Estimate contribution (simple version)
591
+ importance = torch.nn.functional.softmax(
592
+ self.layer_builder.layer_importance, dim=0
593
+ )
594
+ self.layer_stats['contribution'] += importance.detach()
595
+
596
+ def get_layer_status(self) -> Dict[int, str]:
597
+ """Return the status of each layer"""
598
+ gates = torch.sigmoid(self.layer_builder.layer_gates)
599
+ status = {}
600
+
601
+ for i in range(5):
602
+ if gates[i] > self.threshold_active:
603
+ status[i] = "ACTIVE"
604
+ elif gates[i] > self.threshold_skip:
605
+ status[i] = "PARTIAL"
606
+ else:
607
+ status[i] = "SKIP"
608
+
609
+ return status
610
+
611
+ def suggest_pruning(self) -> List[int]:
612
+ """Suggest layers that could be pruned"""
613
+ gates = torch.sigmoid(self.layer_builder.layer_gates)
614
+ prunable = []
615
+
616
+ for i in range(5):
617
+ if gates[i] < self.threshold_skip:
618
+ # Low gate value + low contribution
619
+ if self.layer_stats['contribution'][i] < 0.1:
620
+ prunable.append(i)
621
+
622
+ return prunable
623
+
624
+
625
+ if __name__ == "__main__":
626
+ # Test schedulers
627
+ print("Testing Schedulers")
628
+
629
+ # Create dummy optimizer
630
+ model = torch.nn.Linear(10, 10)
631
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
632
+
633
+ # Test WarmupCosineScheduler
634
+ print("\n1. WarmupCosineScheduler:")
635
+ scheduler = WarmupCosineScheduler(optimizer, warmup_steps=100, total_steps=1000)
636
+ lrs = []
637
+ for step in range(200):
638
+ lr = scheduler.step()
639
+ if step % 20 == 0:
640
+ print(f" Step {step}: LR = {lr:.6f}")
641
+ lrs.append(lr)
642
+
643
+ # Test PhaseBasedScheduler
644
+ print("\n2. PhaseBasedScheduler:")
645
+ phase_configs = [
646
+ {'epochs': (0, 30), 'lr': 1e-4, 'warmup_epochs': 5},
647
+ {'epochs': (31, 60), 'lr': 5e-5, 'warmup_epochs': 2},
648
+ {'epochs': (61, 100), 'lr': 1e-5, 'warmup_epochs': 0}
649
+ ]
650
+ scheduler = PhaseBasedScheduler(optimizer, phase_configs)
651
+ for epoch in [0, 5, 31, 35, 61, 80]:
652
+ lr = scheduler.step(epoch)
653
+ print(f" Epoch {epoch}: LR = {lr:.6f}")
654
+
655
+ # Test GumbelTemperatureScheduler
656
+ print("\n3. GumbelTemperatureScheduler:")
657
+ scheduler = GumbelTemperatureScheduler()
658
+ for step in [0, 100, 500, 1000, 5000]:
659
+ for _ in range(step - scheduler.current_step):
660
+ scheduler.step()
661
+ temp = scheduler.get_temperature()
662
+ print(f" Step {step}: Temperature = {temp:.4f}")
663
+
664
+ # Test CompressionRatioScheduler
665
+ print("\n4. CompressionRatioScheduler:")
666
+ scheduler = CompressionRatioScheduler()
667
+ for epoch in [0, 5, 10, 30, 50, 80, 100]:
668
+ ratio = scheduler.step(epoch)
669
+ print(f" Epoch {epoch}: Target ratio = {ratio:.1f}:1")
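Because UniversalCosineScheduler exposes state_dict()/load_state_dict(), it can be checkpointed alongside the model; a brief sketch, assuming model and optimizer already exist (the file name is illustrative):

sched = UniversalCosineScheduler(optimizer, warmup_steps=1000, total_steps=10000)
torch.save({'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': sched.state_dict()}, 'checkpoint.pt')

ckpt = torch.load('checkpoint.pt')
sched.load_state_dict(ckpt['scheduler'])  # resumes from the saved current_step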
core/tokenizer.py ADDED
@@ -0,0 +1,477 @@
1
+ """
2
+ Intelligent Tokenizer v6.2.0 - Byte Tokenizer with 46+2 Configuration
3
+ Handles chunking, sliding windows, and boundary adjustments
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from typing import Dict, List, Optional, Tuple, Union
10
+ import numpy as np
11
+
12
+
13
+ def _trim_utf8_boundary(byte_seq: List[int], limit: int) -> int:
14
+ """
15
+ Trim byte sequence to valid UTF-8 boundary (GPT suggestion)
16
+ """
17
+ end = min(limit, len(byte_seq))
18
+ while end > 0:
19
+ try:
20
+ bytes(byte_seq[:end]).decode('utf-8')
21
+ return end
22
+ except UnicodeDecodeError:
23
+ end -= 1
24
+ return limit
25
+
26
+
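The helper above backs off until the prefix decodes cleanly, so multi-byte characters are never split. For example, '한' is 3 bytes in UTF-8:

seq = list("한글".encode('utf-8'))     # 6 bytes in total
print(_trim_utf8_boundary(seq, 4))     # 3 -- cutting at byte 4 would split the second character
print(bytes(seq[:3]).decode('utf-8'))  # '한'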
27
+ class ByteTokenizerV62:
28
+ """
29
+ Pure byte-level tokenizer
30
+ 46 content bytes + 2 special tokens (BOS/EOS) = 48 total
31
+ """
32
+
33
+ def __init__(self, config: Optional[Dict] = None):
34
+ # Configuration
35
+ self.content_size = 46 # Actual content bytes
36
+ self.max_seq_len = 48 # Total with BOS/EOS
37
+ self.chunk_overlap = 8 # Overlap for sliding window
38
+
39
+ # Special tokens
40
+ self.PAD = 256
41
+ self.BOS = 257
42
+ self.EOS = 258
43
+ self.MASK = 259
44
+ self.vocab_size = 260 # 256 bytes + 4 special
45
+
46
+ def encode(self,
47
+ text: str,
48
+ add_special_tokens: bool = True,
49
+ return_chunks: bool = False) -> Dict[str, torch.Tensor]:
50
+ """
51
+ Encode text to byte sequences
52
+
53
+ Args:
54
+ text: Input text
55
+ add_special_tokens: Whether to add BOS/EOS
56
+ return_chunks: Return multiple chunks for long sequences
57
+ """
58
+ # Convert to UTF-8 bytes
59
+ byte_sequence = list(text.encode('utf-8'))
60
+
61
+ if return_chunks and len(byte_sequence) > self.content_size:
62
+ # Handle long sequences with sliding window
63
+ return self._encode_with_chunks(byte_sequence, add_special_tokens)
64
+
65
+ # Single chunk processing with UTF-8 boundary (GPT suggestion)
66
+ if len(byte_sequence) > self.content_size:
67
+ cut_point = _trim_utf8_boundary(byte_sequence, self.content_size)
68
+ byte_sequence = byte_sequence[:cut_point]
69
+
70
+ # Add special tokens (GPT suggestion: cleaner padding order)
71
+ if add_special_tokens:
72
+ byte_sequence = [self.BOS] + byte_sequence + [self.EOS]
73
+
74
+ # Pad to max_seq_len (after special tokens for cleaner structure)
75
+ if len(byte_sequence) < self.max_seq_len:
76
+ padding_length = self.max_seq_len - len(byte_sequence)
77
+ byte_sequence = byte_sequence + [self.PAD] * padding_length
78
+
79
+ input_ids = torch.tensor(byte_sequence, dtype=torch.long)
80
+ attention_mask = (input_ids != self.PAD) # bool type (GPT suggestion)
81
+
82
+ return {
83
+ 'input_ids': input_ids,
84
+ 'attention_mask': attention_mask,
85
+ 'length': len(byte_sequence),
86
+ 'original_length': len(text.encode('utf-8'))
87
+ }
88
+
89
+ def _encode_with_chunks(self,
90
+ byte_sequence: List[int],
91
+ add_special_tokens: bool) -> Dict[str, torch.Tensor]:
92
+ """
93
+ Encode long sequences with sliding window chunks
94
+ """
95
+ chunks = []
96
+ positions = []
97
+
98
+ # Calculate stride (content_size - overlap)
99
+ stride = self.content_size - self.chunk_overlap
100
+
101
+ for i in range(0, len(byte_sequence), stride):
102
+ # Extract chunk
103
+ chunk = byte_sequence[i:i + self.content_size]
104
+
105
+ # Skip if chunk is too small (last chunk)
106
+ if len(chunk) < self.content_size // 2:
107
+ if chunks: # Merge with previous chunk if exists
108
+ last_chunk = chunks[-1]['input_ids'].tolist()
109
+ # Remove padding and special tokens from last chunk (GPT final check)
110
+ last_chunk = [b for b in last_chunk if b not in [self.PAD, self.BOS, self.EOS]]
111
+ # Add current chunk
112
+ merged = last_chunk + chunk + [self.EOS]
113
+ # Repad
114
+ if len(merged) < self.max_seq_len:
115
+ merged += [self.PAD] * (self.max_seq_len - len(merged))
116
+ merged_ids = torch.tensor(merged[:self.max_seq_len], dtype=torch.long)
117
+ merged_mask = (merged_ids != self.PAD) # Recalculate mask (GPT suggestion)
118
+ chunks[-1]['input_ids'] = merged_ids
119
+ chunks[-1]['attention_mask'] = merged_mask
120
+ break
121
+
122
+ # Pad chunk if necessary
123
+ if len(chunk) < self.content_size:
124
+ chunk += [self.PAD] * (self.content_size - len(chunk))
125
+
126
+ # Add special tokens
127
+ if add_special_tokens:
128
+ chunk_with_special = [self.BOS] + chunk + [self.EOS]
129
+ else:
130
+ chunk_with_special = chunk
131
+
132
+ # Create tensors
133
+ input_ids = torch.tensor(chunk_with_special, dtype=torch.long)
134
+ attention_mask = (input_ids != self.PAD) # bool type (GPT suggestion)
135
+
136
+ chunks.append({
137
+ 'input_ids': input_ids,
138
+ 'attention_mask': attention_mask,
139
+ 'position': (i, min(i + self.content_size, len(byte_sequence)))
140
+ })
141
+ positions.append((i, min(i + self.content_size, len(byte_sequence))))
142
+
143
+ # Stack all chunks
144
+ all_input_ids = torch.stack([c['input_ids'] for c in chunks])
145
+ all_attention_masks = torch.stack([c['attention_mask'] for c in chunks])
146
+
147
+ return {
148
+ 'input_ids': all_input_ids, # [num_chunks, seq_len]
149
+ 'attention_mask': all_attention_masks,
150
+ 'num_chunks': len(chunks),
151
+ 'chunk_positions': positions,
152
+ 'original_length': len(byte_sequence)
153
+ }
154
+
155
+ def reconstruct(self,
156
+ input_ids: torch.Tensor,
157
+ positions: List[Tuple[int, int]] = None,
158
+ skip_special_tokens: bool = True,
159
+ overlap: int = 8) -> str:
160
+ """
161
+ Reconstruct text from multiple chunks (GPT suggestion)
162
+
163
+ Args:
164
+ input_ids: [num_chunks, seq_len] for multi-chunk
165
+ positions: List of (start, end) positions for each chunk
166
+ skip_special_tokens: Whether to skip special tokens
167
+ overlap: Overlap size between chunks
168
+ """
169
+ if input_ids.dim() == 1:
170
+ # Single sequence, use regular decode
171
+ return self.decode(input_ids, skip_special_tokens)
172
+
173
+ # Multi-chunk reconstruction
174
+ pieces = []
175
+ for i, chunk_ids in enumerate(input_ids):
176
+ chunk_ids = chunk_ids.cpu().numpy().tolist()
177
+
178
+ # Remove special tokens and padding
179
+ if skip_special_tokens:
180
+ chunk_ids = [
181
+ b for b in chunk_ids
182
+ if b not in [self.PAD, self.BOS, self.EOS, self.MASK] and b < 256
183
+ ]
184
+
185
+ pieces.append(chunk_ids)
186
+
187
+ # Merge chunks with overlap handling
188
+ output = []
189
+ for i, chunk in enumerate(pieces):
190
+ if i == 0:
191
+ output.extend(chunk)
192
+ else:
193
+ # Skip overlap bytes from current chunk
194
+ output.extend(chunk[overlap:] if len(chunk) > overlap else chunk)
195
+
196
+ # Convert to string
197
+ try:
198
+ text = bytes(output).decode('utf-8', errors='replace')
199
+ except:
200
+ text = ""
201
+
202
+ return text
203
+
204
+ def decode(self,
205
+ input_ids: torch.Tensor,
206
+ skip_special_tokens: bool = True) -> str:
207
+ """
208
+ Decode byte sequences back to text
209
+ """
210
+ if isinstance(input_ids, torch.Tensor):
211
+ input_ids = input_ids.cpu().numpy().tolist()
212
+
213
+ # Handle batch dimension
214
+ if isinstance(input_ids[0], list):
215
+ input_ids = input_ids[0]
216
+
217
+ # Remove special tokens and padding
218
+ if skip_special_tokens:
219
+ input_ids = [
220
+ b for b in input_ids
221
+ if b not in [self.PAD, self.BOS, self.EOS, self.MASK] and b < 256
222
+ ]
223
+
224
+ # Convert bytes to string
225
+ try:
226
+ text = bytes(input_ids).decode('utf-8', errors='replace')
227
+ except:
228
+ text = ""
229
+
230
+ return text
231
+
232
+ def batch_encode(self,
233
+ texts: List[str],
234
+ add_special_tokens: bool = True) -> Dict[str, torch.Tensor]:
235
+ """
236
+ Encode multiple texts as a batch
237
+ """
238
+ encoded = [self.encode(text, add_special_tokens) for text in texts]
239
+
240
+ # Find max length
241
+ max_len = max(e['length'] for e in encoded)
242
+ max_len = min(max_len, self.max_seq_len)
243
+
244
+ # Create batch tensors
245
+ batch_size = len(texts)
246
+ input_ids = torch.full((batch_size, max_len), self.PAD, dtype=torch.long)
247
+ attention_mask = torch.zeros((batch_size, max_len), dtype=torch.bool) # bool type (GPT suggestion)
248
+
249
+ for i, enc in enumerate(encoded):
250
+ seq_len = min(enc['length'], max_len)
251
+ if enc['input_ids'].dim() == 0: # Handle scalar
252
+ enc['input_ids'] = enc['input_ids'].unsqueeze(0)
253
+ input_ids[i, :seq_len] = enc['input_ids'][:seq_len]
254
+ attention_mask[i, :seq_len] = True
255
+
256
+ return {
257
+ 'input_ids': input_ids,
258
+ 'attention_mask': attention_mask,
259
+ 'lengths': [e['length'] for e in encoded]
260
+ }
261
+
262
+
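With content_size=46 and chunk_overlap=8, the stride used by _encode_with_chunks is 38 bytes, so a 100-byte input yields three chunks covering [0, 46), [38, 84) and [76, 100); the short-tail merge only triggers when the final piece is under 23 bytes. A quick check:

tok = ByteTokenizerV62()
enc = tok.encode("A" * 100, return_chunks=True)
print(enc['num_chunks'], enc['chunk_positions'])  # 3, [(0, 46), (38, 84), (76, 100)]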
263
+ class ChunkBoundaryAdjuster(nn.Module):
264
+ """
265
+ Neural network for adjusting chunk boundaries
266
+ Learns optimal splitting points
267
+ """
268
+
269
+ def __init__(self, hidden_dim: int = 256):
270
+ super().__init__()
271
+
272
+ # Boundary scoring network
273
+ self.boundary_scorer = nn.Sequential(
274
+ nn.Linear(256, hidden_dim), # Input: byte embeddings
275
+ nn.ReLU(),
276
+ nn.Dropout(0.1),
277
+ nn.Linear(hidden_dim, hidden_dim // 2),
278
+ nn.ReLU(),
279
+ nn.Linear(hidden_dim // 2, 1), # Output: boundary score
280
+ nn.Sigmoid()
281
+ )
282
+
283
+ # UTF-8 boundary detector
284
+ self.utf8_detector = nn.Sequential(
285
+ nn.Conv1d(1, 16, kernel_size=4, padding=2), # Detect multi-byte patterns
286
+ nn.ReLU(),
287
+ nn.Conv1d(16, 1, kernel_size=1),
288
+ nn.Sigmoid()
289
+ )
290
+
291
+ def forward(self, byte_sequence: torch.Tensor) -> torch.Tensor:
292
+ """
293
+ Find optimal chunk boundaries
294
+
295
+ Args:
296
+ byte_sequence: [batch, seq_len, embedding_dim]
297
+
298
+ Returns:
299
+ boundary_scores: [batch, seq_len] - probability of boundary at each position
300
+ """
301
+ batch_size, seq_len = byte_sequence.shape[:2]
302
+
303
+ # Score each position as potential boundary
304
+ boundary_scores = self.boundary_scorer(byte_sequence).squeeze(-1)
305
+
306
+ # Detect UTF-8 boundaries (avoid splitting multi-byte characters)
307
+ byte_values = byte_sequence[..., 0].unsqueeze(1) # [batch, 1, seq_len]
308
+ utf8_scores = self.utf8_detector(byte_values).squeeze(1)[:, :seq_len]  # trim to [batch, seq_len]; the kernel-4/padding-2 conv produces seq_len + 1 positions
309
+
310
+ # Combine scores (prefer boundaries at valid UTF-8 positions)
311
+ combined_scores = boundary_scores * utf8_scores
312
+
313
+ # Apply constraints: boundaries should be ~46 bytes apart
314
+ for i in range(0, seq_len, 46):
315
+ if i < seq_len:
316
+ # Boost score at expected positions
317
+ combined_scores[:, i] = combined_scores[:, i] * 1.5
318
+
319
+ return combined_scores
320
+
321
+
322
+ class SlidingWindowProcessor(nn.Module):
323
+ """
324
+ Process sequences with sliding windows at multiple scales
325
+ """
326
+
327
+ def __init__(self, window_sizes: List[int] = [8, 16, 32, 46]):
328
+ super().__init__()
329
+ self.window_sizes = window_sizes
330
+
331
+ # Multi-scale convolutions for different window sizes
332
+ self.convs = nn.ModuleList([
333
+ nn.Conv1d(256, 128, kernel_size=ws, stride=ws//2, padding=ws//4)
334
+ for ws in window_sizes
335
+ ])
336
+
337
+ # Fusion layer
338
+ self.fusion = nn.Sequential(
339
+ nn.Linear(128 * len(window_sizes), 256),
340
+ nn.ReLU(),
341
+ nn.Dropout(0.1),
342
+ nn.Linear(256, 256)
343
+ )
344
+
345
+ def forward(self, byte_embeddings: torch.Tensor) -> torch.Tensor:
346
+ """
347
+ Apply multi-scale sliding windows
348
+
349
+ Args:
350
+ byte_embeddings: [batch, seq_len, embedding_dim]
351
+
352
+ Returns:
353
+ processed: [batch, seq_len, embedding_dim]
354
+ """
355
+ # Transpose for conv1d
356
+ x = byte_embeddings.transpose(1, 2) # [batch, embed, seq]
357
+
358
+ # Apply multi-scale convolutions
359
+ multi_scale_features = []
360
+ for conv in self.convs:
361
+ features = conv(x) # Different seq lengths
362
+ # Global average pooling to fixed size
363
+ pooled = F.adaptive_avg_pool1d(features, byte_embeddings.size(1))
364
+ multi_scale_features.append(pooled)
365
+
366
+ # Concatenate and transpose back
367
+ concat = torch.cat(multi_scale_features, dim=1) # [batch, 128*scales, seq]
368
+ concat = concat.transpose(1, 2) # [batch, seq, 128*scales]
369
+
370
+ # Fuse multi-scale features
371
+ fused = self.fusion(concat) # [batch, seq, 256]
372
+
373
+ # Residual connection
374
+ output = fused + byte_embeddings
375
+
376
+ return output
377
+
378
+
379
+ class AdaptiveChunker:
380
+ """
381
+ Adaptive chunking based on content complexity
382
+ Simple heuristic-based chunker for inference
383
+ """
384
+
385
+ def __init__(self):
386
+ self.min_chunk = 32
387
+ self.max_chunk = 46
388
+ self.target_chunk = 46
389
+
390
+ def determine_chunk_size(self, text: str) -> int:
391
+ """
392
+ Determine optimal chunk size based on text characteristics
393
+ """
394
+ byte_seq = text.encode('utf-8')
395
+
396
+ # Check character types
397
+ has_cjk = any(b >= 0x80 for b in byte_seq[:100]) # Non-ASCII
398
+ has_arabic = any(0x0600 <= ord(c) <= 0x06FF for c in text[:100])
399
+
400
+ # Adjust chunk size based on content
401
+ if has_cjk:
402
+ # CJK characters need smaller chunks (multi-byte)
403
+ return self.min_chunk
404
+ elif has_arabic:
405
+ # Arabic also benefits from smaller chunks
406
+ return 40
407
+ else:
408
+ # ASCII/Latin can use larger chunks
409
+ return self.target_chunk
410
+
411
+ def chunk_text(self, text: str) -> List[str]:
412
+ """
413
+ Split text into adaptive chunks
414
+ """
415
+ chunk_size = self.determine_chunk_size(text)
416
+ byte_seq = text.encode('utf-8')
417
+ chunks = []
418
+
419
+ i = 0
420
+ while i < len(byte_seq):
421
+ # Find chunk boundary (don't split UTF-8 sequences)
422
+ end = min(i + chunk_size, len(byte_seq))
423
+
424
+ # Backtrack to valid UTF-8 boundary if needed
425
+ while end > i and end < len(byte_seq):
426
+ try:
427
+ _ = byte_seq[i:end].decode('utf-8')
428
+ break
429
+ except:
430
+ end -= 1
431
+
432
+ chunk_bytes = byte_seq[i:end]
433
+ chunks.append(chunk_bytes.decode('utf-8', errors='replace'))
434
+ i = end
435
+
436
+ return chunks
437
+
438
+
439
+ if __name__ == "__main__":
440
+ # Test the tokenizer
441
+ tokenizer = ByteTokenizerV62()
442
+
443
+ # Test texts
444
+ test_texts = [
445
+ "Hello, world!",
446
+ "안녕하세요, 세계!",
447
+ "今天天气很好。",
448
+ "مرحبا بالعالم",
449
+ "A" * 100 # Long text
450
+ ]
451
+
452
+ for text in test_texts:
453
+ print(f"\nText: {text[:50]}...")
454
+
455
+ # Single chunk encoding
456
+ encoded = tokenizer.encode(text)
457
+ print(f" Encoded shape: {encoded['input_ids'].shape}")
458
+ print(f" Original length: {encoded['original_length']} bytes")
459
+
460
+ # Decode back
461
+ decoded = tokenizer.decode(encoded['input_ids'])
462
+ print(f" Decoded: {decoded[:50]}...")
463
+
464
+ # Check multi-chunk for long text
465
+ if encoded['original_length'] > 46:
466
+ multi_encoded = tokenizer.encode(text, return_chunks=True)
467
+ print(f" Chunks: {multi_encoded['num_chunks']}")
468
+
469
+ # Test batch encoding
470
+ batch = tokenizer.batch_encode(test_texts[:3])
471
+ print(f"\nBatch shape: {batch['input_ids'].shape}")
472
+
473
+ # Test adaptive chunker
474
+ chunker = AdaptiveChunker()
475
+ for text in test_texts[:3]:
476
+ chunk_size = chunker.determine_chunk_size(text)
477
+ print(f"\n{text[:30]}... → Chunk size: {chunk_size}")
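For multi-chunk inputs, reconstruct() drops the first `overlap` bytes of every chunk after the first, so encode/reconstruct should round-trip when the overlap matches chunk_overlap (8 by default) and the tail chunk is long enough to avoid the merge path. A short addition to the test above:

long_text = "0123456789" * 14  # 140 ASCII bytes -> 4 chunks with an 8-byte overlap
enc = tokenizer.encode(long_text, return_chunks=True)
restored = tokenizer.reconstruct(enc['input_ids'], overlap=tokenizer.chunk_overlap)
print(restored == long_text)  # expected: True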
core/unified_model.py CHANGED
@@ -1,755 +1,541 @@
1
  """
2
- Unified Intelligent Tokenizer Model v6.1.2
3
- Compression-First Learning with Adaptive Splitting
4
- - 64 byte chunks for aggressive compression
5
- - 50 epoch checkpoints with automatic splitting
6
- - Group relation learning for reconstruction
7
- - Boundary adjustment for semantic units
8
  """
9
 
10
  import torch
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
- import math
14
  from typing import Dict, List, Optional, Tuple, Union
 
15
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- class PositionalEncoding(nn.Module):
 
18
  """
19
- Sinusoidal Positional Encoding (original Transformer formulation)
20
- Uses fixed sin/cos patterns instead of learnable position embeddings
 
 
 
 
 
 
21
  """
22
-
23
- def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):
24
  super().__init__()
25
- self.dropout = nn.Dropout(dropout)
26
-
27
- # Create sinusoidal position encodings
28
- pe = torch.zeros(max_len, d_model)
29
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
30
-
31
- div_term = torch.exp(torch.arange(0, d_model, 2).float() *
32
- -(math.log(10000.0) / d_model))
33
-
34
- pe[:, 0::2] = torch.sin(position * div_term) # Even dimensions
35
- pe[:, 1::2] = torch.cos(position * div_term) # Odd dimensions
36
-
37
- # Register as buffer (not trainable)
38
- self.register_buffer('pe', pe.unsqueeze(0))
39
-
40
- def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
41
  """
42
- Add positional encoding to input
 
43
  Args:
44
- x: (batch_size, seq_len, d_model)
 
 
 
 
 
 
 
 
45
  """
46
- x = x + self.pe[:, :x.size(1)]
47
- return self.dropout(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
- class ByteTokenizer:
51
- """
52
- Pure byte-level tokenizer - no language rules
53
- """
54
-
55
- def __init__(self, max_seq_len: int = 64): # v6.1.2: 64 bytes for compression-first approach
56
- self.max_seq_len = max_seq_len
57
- self.PAD = 256
58
- self.BOS = 257
59
- self.EOS = 258
60
- self.MASK = 259
61
-
62
- def encode(self, text: str, add_special_tokens: bool = True) -> Dict[str, torch.Tensor]:
63
- # Convert to UTF-8 bytes
64
- byte_seq = list(text.encode('utf-8'))
65
-
66
- # Truncate if needed
67
- if len(byte_seq) > self.max_seq_len - 2:
68
- byte_seq = byte_seq[:self.max_seq_len - 2]
69
-
70
- # Add special tokens
71
- if add_special_tokens:
72
- byte_seq = [self.BOS] + byte_seq + [self.EOS]
73
-
74
- input_ids = torch.tensor(byte_seq, dtype=torch.long)
75
- attention_mask = torch.ones_like(input_ids)
76
-
77
- return {
78
- 'input_ids': input_ids,
79
- 'attention_mask': attention_mask,
80
- 'length': len(input_ids)
81
- }
82
-
83
- def encode_batch(self, texts: List[str]) -> Dict[str, torch.Tensor]:
84
- encoded = [self.encode(text) for text in texts]
85
- max_len = min(max(e['length'] for e in encoded), self.max_seq_len)
86
-
87
- batch_size = len(texts)
88
- input_ids = torch.full((batch_size, max_len), self.PAD, dtype=torch.long)
89
- attention_mask = torch.zeros((batch_size, max_len), dtype=torch.float32)
90
-
91
- for i, enc in enumerate(encoded):
92
- seq_len = min(enc['length'], max_len)
93
- input_ids[i, :seq_len] = enc['input_ids'][:seq_len]
94
- attention_mask[i, :seq_len] = 1.0
95
-
96
- return {
97
- 'input_ids': input_ids,
98
- 'attention_mask': attention_mask
99
- }
100
-
101
- def decode(self, input_ids: torch.Tensor, skip_special_tokens: bool = True) -> str:
102
- if isinstance(input_ids, torch.Tensor):
103
- input_ids = input_ids.cpu().numpy().tolist()
104
-
105
- if skip_special_tokens:
106
- input_ids = [b for b in input_ids if b < 256]
107
-
108
- try:
109
- byte_array = bytes([min(b, 255) for b in input_ids if b != self.PAD])
110
- return byte_array.decode('utf-8', errors='replace')
111
- except:
112
- return "".join([chr(b) if b < 128 else '?' for b in input_ids if b < 256])
113
-
114
-
115
- class ByteEncoderV61(nn.Module):
116
- """
117
- v6.1: 5-Layer Encoder with Layer-Specialized Architecture
118
- Layer 0: 768d - Byte to character (with curriculum learning)
119
- Layer 1: 896d - Language pattern discovery (no labels)
120
- Layer 2: 1024d - Eojeol/Word formation (+ eojeol PE)
121
- Layer 3: 1152d - Small phrase grouping (2-3 eojeols)
122
- Layer 4: 1280d - Final refinement (+ context PE)
123
-
124
- Target: eojeol (word-unit) to phrase level compression (3:1 ratio)
125
- """
126
 
127
- def __init__(
128
- self,
129
- vocab_size: int = 260,
130
- hidden_dims: List[int] = [768, 896, 1024, 1152, 1280], # v6.1 dimensions
131
- num_heads: List[int] = [12, 14, 16, 18, 20], # v6.1: Progressive heads per layer
132
- dropout: float = 0.1,
133
- max_seq_len: int = 64 # v6.1.2: 64 chunk for compression-first
134
- ):
135
- super().__init__()
136
-
137
- # Layer 0: Byte to Character with Curriculum Learning
138
- self.byte_embedding = nn.Embedding(vocab_size, hidden_dims[0])
139
 
140
- # v6.1: Multi-level boundary predictors for hierarchical segmentation
141
- # Level 1: Character boundaries (UTF-8 multi-byte)
142
- self.char_boundary_predictor = nn.Linear(hidden_dims[0], 3) # 0: continue, 1: start, 2: end
143
 
144
- # Level 2: Eojeol boundaries (space + particle analysis)
145
- self.eojeol_boundary_predictor = nn.Linear(hidden_dims[2], 4) # 0: inside, 1: space, 2: particle, 3: punct
 
 
 
 
146
 
147
- # Level 3: Phrase boundaries (syntactic chunks)
148
- self.phrase_boundary_predictor = nn.Linear(hidden_dims[3], 3) # 0: inside, 1: weak boundary, 2: strong boundary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
- # v6.1: Positional encoding ONLY for Layer 0
151
- self.pos_encoding = PositionalEncoding(hidden_dims[0], max_seq_len, dropout)
152
 
153
- # v6.1: Layer 1 - Language pattern discovery (no labels!)
154
- self.pattern_discoverer = nn.Linear(hidden_dims[1], 256) # Discover patterns autonomously (from 896d)
155
- self.lang_signal_generator = nn.Linear(hidden_dims[1], 128) # Generate language signals (from 896d)
 
 
156
 
157
- # v6.1: Group-aware relative position encodings for Layer 2-4
158
- self.group_pe_layer2 = nn.Embedding(max_seq_len, hidden_dims[2]) # For eojeol/word units
159
- self.group_pe_layer3 = nn.Embedding(max_seq_len, hidden_dims[3]) # For small phrases (2-3 eojeols)
160
- self.group_pe_layer4 = nn.Embedding(max_seq_len, hidden_dims[4]) # For context/discourse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
- # 5 Transformer layers with dimension changes
163
- self.layers = nn.ModuleList()
164
- for i in range(len(hidden_dims)):
165
- input_dim = hidden_dims[i-1] if i > 0 else hidden_dims[0]
166
- output_dim = hidden_dims[i]
167
 
168
- # Projection layer if dimension changes
169
- if input_dim != output_dim:
170
- proj = nn.Linear(input_dim, output_dim)
 
171
  else:
172
- proj = None
173
-
174
- # v6.1: Layer-specific head count for optimal dimension per head
175
- # Target: 64-80 dim per head
176
- layer_heads = num_heads[i] if isinstance(num_heads, list) else num_heads
177
-
178
- # Transformer encoder layer
179
- layer = nn.TransformerEncoderLayer(
180
- d_model=output_dim,
181
- nhead=layer_heads,
182
- dim_feedforward=output_dim * 4,
183
- dropout=dropout,
184
- activation='gelu',
185
- batch_first=True,
186
- norm_first=True
187
- )
188
-
189
- self.layers.append(nn.ModuleDict({
190
- 'projection': proj,
191
- 'transformer': layer,
192
- 'norm': nn.LayerNorm(output_dim)
193
- }))
194
-
195
- self.dropout = nn.Dropout(dropout)
196
-
197
- def forward(
198
- self,
199
- input_ids: torch.Tensor,
200
- attention_mask: Optional[torch.Tensor] = None,
201
- boundary_labels: Optional[torch.Tensor] = None,
202
- epoch: int = 0
203
- ) -> Dict[str, torch.Tensor]:
204
  """
205
- v6.1 Forward pass with curriculum learning
 
206
  Args:
207
- boundary_labels: UTF-8 boundary labels for curriculum learning (training only)
208
- epoch: Current epoch for curriculum schedule
 
 
 
 
 
 
 
209
  """
210
- batch_size, seq_len = input_ids.shape
211
-
212
- # Layer 0: Byte embedding with curriculum learning
213
- x = self.byte_embedding(input_ids)
214
-
215
- # v6.1: Positional encoding ONLY at Layer 0
216
- x = self.pos_encoding(x)
217
-
218
- # v6.1: Predict character boundaries (Layer 0)
219
- char_boundaries = self.char_boundary_predictor(x)
220
-
221
- # v6.1: Curriculum learning for character boundaries
222
- # Note: boundary_labels are eojeol boundaries (4 classes), not char boundaries (3 classes)
223
- # So we don't mix them with char_boundaries - they serve different purposes
224
- char_boundary_weights = F.softmax(char_boundaries, dim=-1)
225
-
226
- # Prepare attention mask
227
- if attention_mask is not None:
228
- # Keep attention mask as is for TransformerEncoderLayer
229
- # It expects shape (batch_size, seq_len) and handles masking internally
230
- pass
231
-
232
- # v6.1: Process through 5 specialized layers
233
- all_hidden_states = []
234
- discovered_patterns = None
235
- eojeol_boundaries = None
236
- phrase_boundaries = None
237
-
238
- for i, layer_dict in enumerate(self.layers):
239
- # Project if needed (before layer-specific processing)
240
- if layer_dict['projection'] is not None:
241
- x = layer_dict['projection'](x)
242
-
243
- # Layer 1: Add language signals (autonomous discovery)
244
- if i == 1:
245
- # Discover language patterns WITHOUT labels (x is now 896d)
246
- discovered_patterns = self.pattern_discoverer(x)
247
- lang_signals = self.lang_signal_generator(x)
248
-
249
- # Layer 2: Predict eojeol boundaries and add position encoding
250
- elif i == 2:
251
- # Predict eojeol boundaries (spaces, particles, punctuation)
252
- eojeol_boundaries = self.eojeol_boundary_predictor(x)
253
- positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
254
- group_pe = self.group_pe_layer2(positions)
255
- x = x + group_pe * 0.1 # Mild addition to preserve main signal
256
-
257
- # Layer 3: Predict phrase boundaries and add position encoding
258
- elif i == 3:
259
- # Predict phrase boundaries (weak/strong syntactic breaks)
260
- phrase_boundaries = self.phrase_boundary_predictor(x)
261
- positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
262
- group_pe = self.group_pe_layer3(positions)
263
- x = x + group_pe * 0.1
264
-
265
- elif i == 4:
266
- positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(batch_size, -1)
267
- group_pe = self.group_pe_layer4(positions)
268
- x = x + group_pe * 0.1
269
-
270
- # Transformer layer - properly handle mask
271
- if attention_mask is not None:
272
- key_padding_mask = (attention_mask == 0)
273
- x = layer_dict['transformer'](x, src_key_padding_mask=key_padding_mask)
274
  else:
275
- x = layer_dict['transformer'](x)
276
- x = layer_dict['norm'](x)
277
- all_hidden_states.append(x)
278
-
279
- # Pool for sequence representation
280
- if attention_mask is not None:
281
- # Masked mean pooling - attention_mask is (batch, seq)
282
- mask = attention_mask.unsqueeze(-1) # (batch, seq, 1)
283
- pooled = (x * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
284
  else:
285
- pooled = x.mean(dim=1)
286
-
287
- return {
288
- 'last_hidden_state': x,
289
- 'pooled_output': pooled,
290
- 'all_hidden_states': all_hidden_states,
291
- # v6.1 boundary predictions
292
- 'char_boundaries': char_boundaries,
293
- 'char_boundary_weights': char_boundary_weights,
294
- 'eojeol_boundaries': eojeol_boundaries,
295
- 'phrase_boundaries': phrase_boundaries,
296
- 'discovered_patterns': discovered_patterns
297
- }
298
 
 
 
 
 
299
 
300
- class CrossAttention(nn.Module):
301
- """
302
- Enhanced Cross-attention for relation learning between sequences
303
- Strengthened relation learning for linking to reasoning layers
304
- """
305
-
306
- def __init__(self, hidden_dim: int = 1280, num_heads: int = 20, dropout: float = 0.1):
307
- super().__init__()
308
 
309
- # v6.1: Adjusted for 1280d (64 per head with 20 heads)
310
- self.cross_attn = nn.MultiheadAttention(
311
- hidden_dim, num_heads, dropout, batch_first=True
312
- )
313
-
314
- # v6.1: Enhanced relation classifier with reconstruction focus
315
- # 0: identity (perfect reconstruction), 1: similar, 2: different, 3: continuation
316
- # 4: translation, 5: summary, 6: expansion, 7: contradiction
317
- self.relation_head = nn.Sequential(
318
- nn.Linear(hidden_dim * 2, hidden_dim),
319
- nn.GELU(),
320
- nn.Dropout(dropout),
321
- nn.Linear(hidden_dim, hidden_dim // 2),
322
- nn.GELU(),
323
- nn.Dropout(dropout),
324
- nn.Linear(hidden_dim // 2, 8)
325
- )
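For illustration, a minimal sketch of how the 8-way relation logits produced by this head could be mapped to named probabilities; the RELATION_NAMES list simply mirrors the class order documented in the comment above and is not part of the original code.

import torch
import torch.nn.functional as F

# Class order follows the comment above (0: identity ... 7: contradiction); illustrative only.
RELATION_NAMES = ["identity", "similar", "different", "continuation",
                  "translation", "summary", "expansion", "contradiction"]

def label_relations(relation_logits: torch.Tensor) -> dict:
    """Turn [batch, 8] relation logits into a {name: probability} dict for the first item."""
    probs = F.softmax(relation_logits, dim=-1)[0]
    return {name: round(float(p), 3) for name, p in zip(RELATION_NAMES, probs)}

print(label_relations(torch.randn(1, 8)))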
326
 
327
- # v6.1: Reconstruction-specific attention (dedicated to restoration)
328
- # Use 10 heads for reconstruction (128 per head)
329
- self.reconstruction_attn = nn.MultiheadAttention(
330
- hidden_dim, 10, dropout * 0.5, batch_first=True
331
- )
332
-
333
- # Gating mechanism for adaptive fusion
334
- self.gate = nn.Sequential(
335
- nn.Linear(hidden_dim * 2, hidden_dim),
336
- nn.Sigmoid()
337
- )
338
-
339
- self.norm1 = nn.LayerNorm(hidden_dim)
340
- self.norm2 = nn.LayerNorm(hidden_dim)
341
-
342
- def forward(
343
- self,
344
- query: torch.Tensor,
345
- key: torch.Tensor,
346
- query_mask: Optional[torch.Tensor] = None,
347
- key_mask: Optional[torch.Tensor] = None
348
- ) -> Dict[str, torch.Tensor]:
349
- # Normalize inputs
350
- query_norm = self.norm1(query)
351
- key_norm = self.norm2(key)
352
-
353
- # Fix key_mask dimension if needed
354
- if key_mask is not None:
355
- # Ensure key_mask matches key sequence length
356
- if key_mask.dim() == 2 and key_mask.size(1) != key.size(1):
357
- # Create new mask with correct dimensions
358
- batch_size = key.size(0)
359
- seq_len = key.size(1)
360
- key_mask = torch.ones(batch_size, seq_len, dtype=key_mask.dtype, device=key_mask.device)
361
-
362
- # Cross attention
363
- attn_output, attn_weights = self.cross_attn(
364
- query_norm, key_norm, key_norm,
365
- key_padding_mask=(key_mask == 0) if key_mask is not None else None
366
- )
367
-
368
- # Residual connection
369
- attn_output = attn_output + query
370
-
371
- # v6.1: Reconstruction-focused attention (optimized for restoration)
372
- recon_output, recon_weights = self.reconstruction_attn(
373
- query_norm, query_norm, query_norm, # Self-attention for consistency
374
- key_padding_mask=(query_mask == 0) if query_mask is not None else None
375
- )
376
 
377
- # Combine cross and reconstruction attention
378
- combined_attn = attn_output * 0.7 + recon_output * 0.3
379
-
380
- # Adaptive gating for fusion
381
- gate_input = torch.cat([query.mean(dim=1), key.mean(dim=1)], dim=-1)
382
- gate_weights = self.gate(gate_input).unsqueeze(1)
383
-
384
- # Gated fusion: adaptively modulate the attention output
385
- fused_output = gate_weights * combined_attn + (1 - gate_weights) * query
386
-
387
- # Pool for relation classification
388
- query_pooled = query.mean(dim=1) if query_mask is None else \
389
- (query * query_mask.unsqueeze(-1)).sum(1) / query_mask.sum(1, keepdim=True).clamp(min=1e-9)
390
- key_pooled = key.mean(dim=1) if key_mask is None else \
391
- (key * key_mask.unsqueeze(-1)).sum(1) / key_mask.sum(1, keepdim=True).clamp(min=1e-9)
392
-
393
- # Classify relations with enhanced head
394
- combined = torch.cat([query_pooled, key_pooled], dim=-1)
395
- relation_logits = self.relation_head(combined)
396
-
397
- return {
398
- 'cross_attention': fused_output, # Gated fusion output
399
- 'attention_weights': attn_weights,
400
- 'reconstruction_weights': recon_weights, # v6.1: reconstruction attention weights
401
- 'relation_logits': relation_logits,
402
- 'gate_weights': gate_weights.squeeze(1), # For analysis
403
- 'reconstruction_score': F.softmax(relation_logits, dim=-1)[:, 0] # identity probability (reconstruction fidelity)
404
- }
405
 
 
 
 
406
 
407
- class TransformerDecoder(nn.Module):
408
- """
409
- Transformer Decoder with Positional Encoding
410
- """
411
 
412
- def __init__(
413
- self,
414
- vocab_size: int = 260,
415
- hidden_dim: int = 1280, # v6.1: Match final encoder dim
416
- num_heads: int = 16, # v6.1: 1280/16 = 80 per head
417
- num_layers: int = 8, # v6.1 FINAL: 8 layers for better reconstruction
418
- dropout: float = 0.1,
419
- max_seq_len: int = 64 # v6.1.2: 64 chunk for compression-first
420
- ):
421
- super().__init__()
422
-
423
- # Token embedding
424
- self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
425
-
426
- # Positional encoding
427
- self.pos_encoding = PositionalEncoding(hidden_dim, max_seq_len, dropout)
428
-
429
- # Transformer decoder
430
- decoder_layer = nn.TransformerDecoderLayer(
431
- d_model=hidden_dim,
432
- nhead=num_heads,
433
- dim_feedforward=hidden_dim * 4,
434
- dropout=dropout,
435
- activation='gelu',
436
- batch_first=True,
437
- norm_first=True
438
- )
439
-
440
- self.transformer = nn.TransformerDecoder(decoder_layer, num_layers)
441
-
442
- # Output projection
443
- self.output_projection = nn.Linear(hidden_dim, vocab_size)
444
-
445
- self.hidden_dim = hidden_dim
446
- self.vocab_size = vocab_size
447
-
448
- def forward(
449
- self,
450
- encoder_hidden: torch.Tensor,
451
- decoder_input_ids: Optional[torch.Tensor] = None,
452
- encoder_mask: Optional[torch.Tensor] = None,
453
- decoder_mask: Optional[torch.Tensor] = None
454
- ) -> Dict[str, torch.Tensor]:
455
- batch_size = encoder_hidden.size(0)
456
-
457
- # Start with BOS if no input
458
- if decoder_input_ids is None:
459
- decoder_input_ids = torch.full((batch_size, 1), 257, device=encoder_hidden.device)
460
-
461
- # Embed and add positional encoding
462
- dec_seq_len = decoder_input_ids.size(1)
463
- x = self.token_embedding(decoder_input_ids)
464
- x = self.pos_encoding(x)
465
-
466
- # Create causal mask
467
- causal_mask = torch.triu(
468
- torch.ones(dec_seq_len, dec_seq_len, device=x.device) * float('-inf'),
469
- diagonal=1
470
- )
471
-
472
- # Decoder forward - handle variable-length encoder outputs
473
- # The encoder may compress the sequence, so memory (encoder_hidden) might be shorter
474
- # than the decoder sequence. This is expected and correct behavior.
475
- enc_seq_len = encoder_hidden.size(1)
476
-
477
- # Adjust encoder mask if needed
478
- if encoder_mask is not None:
479
- if encoder_mask.size(1) != enc_seq_len:
480
- # Encoder compressed the sequence, create new mask for compressed length
481
- # All compressed positions are valid (not masked)
482
- memory_key_padding_mask = torch.zeros(
483
- encoder_hidden.size(0), enc_seq_len,
484
- dtype=torch.bool, device=encoder_hidden.device
485
- )
486
- else:
487
- memory_key_padding_mask = (encoder_mask == 0)
488
  else:
489
- memory_key_padding_mask = None
490
-
491
- # Decoder attends to compressed encoder states via cross-attention
492
- # This naturally handles different sequence lengths
493
- decoder_output = self.transformer(
494
- tgt=x, # Decoder sequence (original length)
495
- memory=encoder_hidden, # Encoder sequence (possibly compressed)
496
- tgt_mask=causal_mask,
497
- memory_key_padding_mask=memory_key_padding_mask,
498
- tgt_key_padding_mask=(decoder_mask == 0) if decoder_mask is not None else None
499
- )
500
-
501
- # Project to vocabulary
502
- logits = self.output_projection(decoder_output)
503
-
504
- return {
505
- 'logits': logits,
506
- 'hidden_states': decoder_output
507
- }
508
-
509
- @torch.no_grad()
510
- def generate(
511
- self,
512
- encoder_hidden: torch.Tensor,
513
- encoder_mask: Optional[torch.Tensor] = None,
514
- max_length: int = 128,
515
- temperature: float = 0.1, # the tokenizer generates conservatively for exact reconstruction
516
- top_k: int = 10, # consider only the top 10 candidates
517
- top_p: float = 0.95
518
- ) -> torch.Tensor:
519
- batch_size = encoder_hidden.size(0)
520
- device = encoder_hidden.device
521
 
522
- # Start with BOS
523
- decoder_input_ids = torch.full((batch_size, 1), 257, device=device)
524
 
525
- # Track which sequences are done
526
- finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
 
527
 
528
- for _ in range(max_length - 1):
529
- # Forward pass
530
- outputs = self.forward(encoder_hidden, decoder_input_ids, encoder_mask)
531
- next_token_logits = outputs['logits'][:, -1, :] / temperature
532
 
533
- # Top-k filtering
534
- if top_k > 0:
535
- indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
536
- next_token_logits[indices_to_remove] = float('-inf')
 
 
 
537
 
538
- # Top-p filtering
539
- if top_p < 1.0:
540
- sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
541
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
542
 
543
- sorted_indices_to_remove = cumulative_probs > top_p
544
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
545
- sorted_indices_to_remove[..., 0] = 0
 
 
 
546
 
547
- indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
548
- next_token_logits[indices_to_remove] = float('-inf')
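The two filtering steps above can be read as one standalone helper; the sketch below reproduces the same top-k and nucleus (top-p) masking on a single step of logits and is illustrative rather than repository code.

import torch
import torch.nn.functional as F

def filter_logits(logits: torch.Tensor, top_k: int = 10, top_p: float = 0.95) -> torch.Tensor:
    """Mask logits outside the top-k set and outside the top-p nucleus with -inf."""
    logits = logits.clone()
    if top_k > 0:
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_best] = float('-inf')
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cumulative > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # keep the first token that crosses the threshold
        remove[..., 0] = False
        logits[remove.scatter(-1, sorted_indices, remove)] = float('-inf')
    return logits

probs = F.softmax(filter_logits(torch.randn(1, 260)), dim=-1)  # 260 = byte vocabulary size used here
next_token = torch.multinomial(probs, 1)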
549
 
550
- # Sample
551
- probs = F.softmax(next_token_logits, dim=-1)
552
- next_tokens = torch.multinomial(probs, 1)
553
 
554
- # For finished sequences, force PAD token
555
- next_tokens[finished] = 256 # PAD token
556
 
557
- decoder_input_ids = torch.cat([decoder_input_ids, next_tokens], dim=-1)
 
558
 
559
- # Update finished status
560
- finished = finished | (next_tokens.squeeze(-1) == 258) # Mark as finished if EOS
 
 
 
 
 
561
 
562
- # Stop when all sequences are done
563
- if finished.all():
564
- break
565
 
566
- return decoder_input_ids
 
 
 
567
 
 
 
 
568
 
569
- class IntelligentTokenizerModelV61(nn.Module):
570
- """
571
- Complete Intelligent Tokenizer Model v6.1
572
- Pure learning-based with curriculum learning
573
- - No language labels during training
574
- - Curriculum learning for boundaries
575
- - Group-aware position encodings
576
- """
577
 
578
- def __init__(
579
- self,
580
- vocab_size: int = 260,
581
- encoder_dims: List[int] = [768, 896, 1024, 1152, 1280], # v6.1 dimensions
582
- encoder_heads: List[int] = [12, 14, 16, 18, 20], # v6.1: Optimal heads per layer
583
- decoder_hidden: int = 1280, # Match final encoder dim
584
- decoder_heads: int = 16, # v6.1: 80 per head for decoder
585
- num_decoder_layers: int = 8, # v6.1 FINAL: 8 layers for better reconstruction
586
- dropout: float = 0.1,
587
- max_seq_len: int = 64 # v6.1.2: 64 chunk for compression-first
588
- ):
589
- super().__init__()
590
 
591
- # v6.1 Components with optimized head counts
592
- self.tokenizer = ByteTokenizer(max_seq_len)
593
- self.encoder = ByteEncoderV61(vocab_size, encoder_dims, encoder_heads, dropout, max_seq_len)
594
- self.decoder = TransformerDecoder(vocab_size, decoder_hidden, decoder_heads, num_decoder_layers, dropout, max_seq_len)
595
- self.cross_attention = CrossAttention(encoder_dims[-1], 20, dropout) # 20 heads for 1280d
596
-
597
- def forward(
598
- self,
599
- input_texts: Optional[List[str]] = None,
600
- input_ids: Optional[torch.Tensor] = None,
601
- attention_mask: Optional[torch.Tensor] = None,
602
- decoder_input_ids: Optional[torch.Tensor] = None,
603
- labels: Optional[torch.Tensor] = None,
604
- boundary_labels: Optional[torch.Tensor] = None, # v6.1: for curriculum learning
605
- epoch: int = 0, # v6.1: for curriculum schedule
606
- use_cross_attention: bool = True
607
- ) -> Dict[str, torch.Tensor]:
608
- # Tokenize if text input
609
- if input_texts is not None:
610
- tokenized = self.tokenizer.encode_batch(input_texts)
611
- input_ids = tokenized['input_ids']
612
- attention_mask = tokenized['attention_mask']
613
-
614
- # Check and adjust the sequence length
615
- batch_size, seq_len = input_ids.shape
616
- device = input_ids.device
617
-
618
- # v6.1: Encode with curriculum learning
619
- encoder_outputs = self.encoder(input_ids, attention_mask, boundary_labels, epoch)
620
- encoder_hidden = encoder_outputs['last_hidden_state'] # v6.1: [batch, seq, 1280]
621
-
622
- # v6.1: dimension check - the final encoder dim must be 1280
623
- assert encoder_hidden.size(-1) == 1280, f"Encoder dim mismatch: {encoder_hidden.size(-1)}"
624
-
625
- # Prepare decoder input for teacher forcing during training
626
- if decoder_input_ids is None:
627
- if labels is not None:
628
- # During training, use shifted labels as decoder input (teacher forcing)
629
- # Add BOS at the beginning and remove last token
630
- bos_tokens = torch.full((batch_size, 1), self.tokenizer.BOS, device=labels.device, dtype=labels.dtype)
631
- decoder_input_ids = torch.cat([bos_tokens, labels[:, :-1]], dim=1)
632
- else:
633
- # For inference/test, start with BOS token
634
- decoder_input_ids = torch.full((batch_size, 1), self.tokenizer.BOS, device=device, dtype=torch.long)
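To make the teacher-forcing shift above concrete, a small self-contained sketch; BOS = 257 is an assumption matching the special-token ids used in this file, and the byte values spell "Hello".

import torch

BOS = 257  # assumed BOS id, matching the default used elsewhere in this file

def make_teacher_forcing_input(labels: torch.Tensor) -> torch.Tensor:
    """Prepend BOS and drop the last target so position t predicts labels[:, t]."""
    bos = torch.full((labels.size(0), 1), BOS, dtype=labels.dtype, device=labels.device)
    return torch.cat([bos, labels[:, :-1]], dim=1)

labels = torch.tensor([[72, 101, 108, 108, 111, 258]])   # "Hello" as bytes, then EOS (258)
print(make_teacher_forcing_input(labels))                 # tensor([[257,  72, 101, 108, 108, 111]])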
635
-
636
- # Decode
637
- decoder_outputs = self.decoder(
638
- encoder_hidden,
639
- decoder_input_ids,
640
- attention_mask
641
- )
642
- decoder_hidden = decoder_outputs['hidden_states'] # [batch, seq, 1280]
643
-
644
- # Cross-attention (relation learning at the final layer)
645
- cross_attn_outputs = None
646
- relation_logits = None
647
-
648
- if use_cross_attention and decoder_hidden is not None:
649
- # Cross-attention between the decoder outputs and the encoder outputs
650
- cross_attn_outputs = self.cross_attention(
651
- query=decoder_hidden, # decoder provides the queries
652
- key=encoder_hidden, # encoder provides the keys/values
653
- query_mask=None, # the decoder mask is causal, so it is handled separately
654
- key_mask=attention_mask
655
- )
656
-
657
- # Relation-learning outputs
658
- relation_logits = cross_attn_outputs['relation_logits']
659
-
660
- # Decoder representation enhanced by cross-attention
661
- enhanced_decoder = decoder_hidden + cross_attn_outputs['cross_attention']
662
-
663
- # Recompute the final logits (after applying cross-attention)
664
- if hasattr(self.decoder, 'output_projection'):
665
- decoder_outputs['logits'] = self.decoder.output_projection(enhanced_decoder)
666
-
667
- # Calculate loss if labels provided
668
- loss = None
669
- if labels is not None:
670
- # Reconstruction loss
671
- loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer.PAD)
672
- recon_loss = loss_fct(
673
- decoder_outputs['logits'].reshape(-1, decoder_outputs['logits'].size(-1)),
674
- labels.reshape(-1)
675
- )
676
 
677
- # Boundary loss (if boundary labels provided)
678
- boundary_loss = 0
679
- if boundary_labels is not None and encoder_outputs.get('eojeol_boundaries') is not None:
680
- # Eojeol boundary loss
681
- eojeol_boundaries = encoder_outputs['eojeol_boundaries'] # [batch, seq, 4]
682
- if eojeol_boundaries.size(1) == boundary_labels.size(1):
683
- # Ensure boundary labels are in valid range (0-3)
684
- # Clamp to valid range to prevent CUDA errors
685
- boundary_labels_clamped = torch.clamp(boundary_labels, min=0, max=3)
686
-
687
- boundary_loss_fct = nn.CrossEntropyLoss(ignore_index=-1) # Use -1 for padding
688
- boundary_loss = boundary_loss_fct(
689
- eojeol_boundaries.reshape(-1, 4),
690
- boundary_labels_clamped.reshape(-1)
691
- ) * 0.5 # Weight for boundary loss
692
-
693
- # Relation loss (if cross-attention used)
694
- relation_loss = 0
695
- if relation_logits is not None:
696
- # The self-relation should be identity (class 0)
697
- batch_identity = torch.zeros(batch_size, dtype=torch.long, device=device)
698
- relation_loss = F.cross_entropy(relation_logits, batch_identity) * 0.1
699
-
700
- loss = recon_loss + boundary_loss + relation_loss
701
-
702
- return {
703
- 'loss': loss,
704
- 'logits': decoder_outputs['logits'],
705
- 'decoder_logits': decoder_outputs['logits'], # Add for compatibility
706
- 'encoder_hidden_states': encoder_hidden,
707
- 'decoder_hidden_states': decoder_hidden,
708
- 'pooled_output': encoder_outputs['pooled_output'],
709
- 'cross_attention': cross_attn_outputs['cross_attention'] if cross_attn_outputs else None,
710
- 'relation_logits': relation_logits,
711
- 'all_encoder_states': encoder_outputs.get('all_hidden_states', None),
712
- # Add boundary predictions for visualization
713
- 'char_boundaries': encoder_outputs.get('char_boundaries'),
714
- 'eojeol_boundaries': encoder_outputs.get('eojeol_boundaries'),
715
- 'phrase_boundaries': encoder_outputs.get('phrase_boundaries'),
716
- 'discovered_patterns': encoder_outputs.get('discovered_patterns')
717
  }
718
-
719
- def encode_text(self, text: str) -> torch.Tensor:
720
- """Encode single text to representation"""
721
- tokenized = self.tokenizer.encode(text)
722
- # Move to same device as model
723
- device = next(self.parameters()).device
724
- input_ids = tokenized['input_ids'].unsqueeze(0).to(device)
725
- attention_mask = tokenized['attention_mask'].unsqueeze(0).to(device)
726
-
727
- with torch.no_grad():
728
- outputs = self.encoder(input_ids, attention_mask)
729
-
730
- return outputs['pooled_output'].squeeze(0)
731
-
732
- def decode_representation(self, representation: torch.Tensor, max_length: int = 128) -> str:
733
- """Decode representation back to text"""
734
- if representation.dim() == 1:
735
- representation = representation.unsqueeze(0).unsqueeze(0)
736
- elif representation.dim() == 2:
737
- representation = representation.unsqueeze(1)
738
-
739
- with torch.no_grad():
740
- output_ids = self.decoder.generate(representation, max_length=max_length)
741
-
742
- text = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
743
- return text
744
-
745
- def compute_relation(self, text1: str, text2: str) -> torch.Tensor:
746
- """Compute relation between two texts"""
747
- # Encode both texts
748
- enc1 = self.encode_text(text1).unsqueeze(0).unsqueeze(0)
749
- enc2 = self.encode_text(text2).unsqueeze(0).unsqueeze(0)
750
-
751
- # Compute cross-attention and relations
752
- with torch.no_grad():
753
- outputs = self.cross_attention(enc1, enc2)
754
-
755
- return F.softmax(outputs['relation_logits'], dim=-1)
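A usage sketch for this (now removed) v6.1 relation API; it assumes the classes above are importable and initialized with untrained weights, so the probabilities themselves are meaningless.

model_v61 = IntelligentTokenizerModelV61()
model_v61.eval()
relation_probs = model_v61.compute_relation("The cat sat.", "A cat was sitting.")
print(relation_probs.shape)  # torch.Size([1, 8]): one probability per relation class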
 
1
  """
2
+ Intelligent Tokenizer v6.2.0 - Unified Model
3
+ Integrates encoder, decoder, and tokenizer with all GPT improvements
 
 
 
 
4
  """
5
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
9
  from typing import Dict, List, Optional, Tuple, Union
10
+ import math
11
 
12
+ # Import our components
13
+ try:
14
+ from .encoder import EncoderV62
15
+ from .decoder import DecoderV62
16
+ from .tokenizer import ByteTokenizerV62
17
+ except ImportError:
18
+ # For standalone testing
19
+ from encoder import EncoderV62
20
+ from decoder import DecoderV62
21
+ from tokenizer import ByteTokenizerV62
22
 
23
+
24
+ class IntelligentTokenizerV62(nn.Module):
25
  """
26
+ Complete v6.2.0 model with progressive splitting and optimizations
27
+
28
+ Key features:
29
+ - 48-byte chunks (46+2 with BOS/EOS)
30
+ - Progressive splitting: 48→1→N→M tokens
31
+ - Multi-level cross-attention
32
+ - KV cache optimization (8x reduction)
33
+ - All GPT-5 improvements integrated
34
  """
35
+
36
+ def __init__(self, config: Optional[Dict] = None):
37
  super().__init__()
38
+
39
+ # Default configuration
40
+ self.config = config or {}
41
+
42
+ # Model components
43
+ self.tokenizer = ByteTokenizerV62(config)
44
+ self.encoder = EncoderV62(config)
45
+ self.decoder = DecoderV62(config)
46
+
47
+ # Training configuration
48
+ self.compression_weight = 0.1
49
+ self.reconstruction_weight = 0.1
50
+ self.boundary_weight = 0.1
51
+
52
+ # Monitoring
53
+ self.register_buffer('training_step', torch.tensor(0))
54
+ self.register_buffer('current_epoch', torch.tensor(0))
55
+
56
+ def forward(self,
57
+ input_ids: torch.Tensor = None,
58
+ attention_mask: torch.Tensor = None,
59
+ labels: torch.Tensor = None,
60
+ text: str = None,
61
+ return_loss: bool = True,
62
+ temperature: float = 1.0) -> Dict[str, torch.Tensor]:
63
  """
64
+ Unified forward pass
65
+
66
  Args:
67
+ input_ids: Pre-tokenized input (optional)
68
+ attention_mask: Attention mask (optional)
69
+ labels: Target labels for training (optional)
70
+ text: Raw text input (alternative to input_ids)
71
+ return_loss: Whether to compute loss
72
+ temperature: Temperature for Gumbel-Softmax in encoder
73
+
74
+ Returns:
75
+ Dictionary with model outputs
76
  """
77
+ # Handle text input
78
+ if text is not None:
79
+ encoded = self.tokenizer.encode(text, add_special_tokens=True)
80
+ input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
81
+ attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
82
+
83
+ # Handle string passed as input_ids (common mistake)
84
+ if isinstance(input_ids, str):
85
+ text = input_ids
86
+ encoded = self.tokenizer.encode(text, add_special_tokens=True)
87
+ input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
88
+ attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
89
+
90
+ # Ensure tensors are on the right device
91
+ device = next(self.parameters()).device
92
+ if input_ids is not None and torch.is_tensor(input_ids):
93
+ input_ids = input_ids.to(device)
94
+ if attention_mask is not None and torch.is_tensor(attention_mask):
95
+ attention_mask = attention_mask.to(device)
96
+ if labels is not None and torch.is_tensor(labels):
97
+ labels = labels.to(device)
98
+
99
+ # Encoder forward pass with temperature for Gumbel annealing
100
+ encoder_outputs = self.encoder(
101
+ input_ids=input_ids,
102
+ attention_mask=attention_mask,
103
+ temperature=temperature
104
+ )
105
 
106
+ # Decoder forward pass
107
+ if labels is not None:
108
+ # Training mode with teacher forcing (GPT suggestion: shift by 1)
109
+ # Input: labels[:-1], Target: labels[1:]
110
+ decoder_input = labels[:, :-1] if labels.dim() > 1 else labels[:-1]
111
+ decoder_mask = attention_mask[:, :-1] if attention_mask is not None and attention_mask.dim() > 1 else None
112
+
113
+ decoder_outputs = self.decoder(
114
+ encoder_all_hidden=encoder_outputs['all_hidden_states'],
115
+ decoder_input_ids=decoder_input,
116
+ attention_mask=decoder_mask
117
+ )
118
+ else:
119
+ # Inference mode (without teacher forcing)
120
+ # For now, fall back to using the input as labels for stable training
121
+ # TODO: Implement proper autoregressive generation
122
+ if return_loss and input_ids is not None:
123
+ labels = input_ids # Use input as both input and target
124
+ decoder_input = labels[:, :-1] if labels.dim() > 1 else labels[:-1]
125
+ decoder_mask = attention_mask[:, :-1] if attention_mask is not None and attention_mask.dim() > 1 else None
126
+
127
+ decoder_outputs = self.decoder(
128
+ encoder_all_hidden=encoder_outputs['all_hidden_states'],
129
+ decoder_input_ids=decoder_input,
130
+ attention_mask=decoder_mask
131
+ )
132
+ else:
133
+ decoder_outputs = self.decoder(
134
+ encoder_all_hidden=encoder_outputs['all_hidden_states'],
135
+ decoder_input_ids=None,
136
+ attention_mask=attention_mask
137
+ )
138
 
139
+ # Combine outputs with prefix to avoid key collision (GPT suggestion)
140
+ outputs = {}
141
+ for key, value in encoder_outputs.items():
142
+ outputs[f'enc_{key}'] = value
143
+ for key, value in decoder_outputs.items():
144
+ outputs[f'dec_{key}'] = value
 
145
 
146
+ # Compute loss if requested
147
+ if return_loss and labels is not None:
148
+ loss = self.compute_loss(outputs, labels, attention_mask)
149
+ outputs['loss'] = loss
150
 
151
+ return outputs
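A minimal usage sketch for the forward pass above, assuming the model builds with its default config and that raw text is passed via the text argument; the key prefixes are the enc_/dec_ ones added by the loop above.

import torch

model = IntelligentTokenizerV62()
model.eval()
with torch.no_grad():
    outputs = model(text="Hello, world!", return_loss=False)

print(sorted(k for k in outputs if k.startswith('enc_'))[:3])   # a few encoder keys
if 'dec_logits' in outputs:
    print(outputs['dec_logits'].shape)                          # [batch, seq, vocab]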
 
 
152
 
153
+ def compute_loss(self,
154
+ outputs: Dict[str, torch.Tensor],
155
+ labels: torch.Tensor,
156
+ attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
157
+ """
158
+ Compute combined loss with multiple objectives
159
 
160
+ Components:
161
+ 1. Reconstruction loss (cross-entropy)
162
+ 2. Compression loss (encourage higher compression)
163
+ 3. Boundary loss (boundary prediction accuracy)
164
+ """
165
+ losses = {}
166
+
167
+ # 1. Reconstruction loss (GPT suggestion: use shifted targets)
168
+ if 'dec_logits' in outputs:
169
+ logits = outputs['dec_logits']
170
+
171
+ # Shift targets for next-token prediction
172
+ target_labels = labels[:, 1:] if labels.dim() > 1 else labels[1:]
173
+ target_mask = attention_mask[:, 1:] if attention_mask is not None and attention_mask.dim() > 1 else None
174
+
175
+ # Reshape for cross-entropy
176
+ batch_size, seq_len, vocab_size = logits.shape
177
+ logits_flat = logits.reshape(-1, vocab_size)
178
+ labels_flat = target_labels.reshape(-1)
179
+
180
+ # Mask out padding (GPT suggestion: use bool mask)
181
+ if target_mask is not None:
182
+ mask_flat = target_mask.reshape(-1).bool()
183
+ reconstruction_loss = F.cross_entropy(
184
+ logits_flat[mask_flat],
185
+ labels_flat[mask_flat],
186
+ ignore_index=self.tokenizer.PAD,
187
+ label_smoothing=0.1 # Added label smoothing
188
+ )
189
+ else:
190
+ reconstruction_loss = F.cross_entropy(
191
+ logits_flat,
192
+ labels_flat,
193
+ ignore_index=self.tokenizer.PAD,
194
+ label_smoothing=0.1
195
+ )
196
 
197
+ losses['reconstruction'] = reconstruction_loss * self.reconstruction_weight
 
198
 
199
+ # 2. Compression loss (GPT suggestion: use proper device tensor creation)
200
+ if 'enc_compression_ratio' in outputs:
201
+ # Target compression ratio (e.g., 24:1 as per config)
202
+ target_ratio = 24.0
203
+ current_ratio = outputs['enc_compression_ratio']
204
 
205
+ # Create tensors on same device (GPT suggestion)
206
+ if isinstance(current_ratio, (int, float)):
207
+ current_ratio_tensor = labels.new_tensor(current_ratio, dtype=torch.float32)
208
+ else:
209
+ current_ratio_tensor = current_ratio.float()
210
+ target_ratio_tensor = labels.new_tensor(target_ratio, dtype=torch.float32)
211
+
212
+ # Penalize deviation from target (use smooth L1 to avoid explosion)
213
+ compression_loss = F.smooth_l1_loss(
214
+ current_ratio_tensor,
215
+ target_ratio_tensor,
216
+ beta=2.0 # Transition point from L2 to L1
217
+ )
218
+
219
+ losses['compression'] = compression_loss * self.compression_weight
220
+
221
+ # 3. Boundary loss (GPT suggestion: more meaningful boundary learning)
222
+ if 'enc_boundaries' in outputs and outputs['enc_boundaries'] is not None:
223
+ boundary_scores = outputs['enc_boundaries']
224
+
225
+ # Boundary sparsity + smoothness (GPT suggestion)
226
+ # Encourage sparse but clear boundaries
227
+ boundary_probs = torch.sigmoid(boundary_scores)
228
 
229
+ # Sparsity loss (boundaries should be rare)
230
+ sparsity_loss = boundary_probs.mean() * 0.1
 
 
 
231
 
232
+ # Smoothness loss (adjacent boundary scores should change gradually)
233
+ if boundary_scores.size(1) > 1:
234
+ diff = boundary_scores[:, 1:] - boundary_scores[:, :-1]
235
+ smoothness_loss = (diff ** 2).mean() * 0.01
236
  else:
237
+ smoothness_loss = 0.0
238
+
239
+ boundary_loss = sparsity_loss + smoothness_loss
240
+
241
+ losses['boundary'] = boundary_loss * self.boundary_weight
242
+
243
+ # Combine all losses
244
+ total_loss = sum(losses.values())
245
+
246
+ # Store individual losses for monitoring
247
+ self.last_losses = losses
248
+
249
+ return total_loss
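A toy illustration of the main loss terms above on made-up tensors; the shapes, byte values, and weights are assumptions (the weights are the 0.1 defaults from __init__), and only the reconstruction and boundary terms are reproduced.

import torch
import torch.nn.functional as F

# Shifted reconstruction loss: position t predicts labels[:, t + 1].
logits = torch.randn(1, 5, 260)                          # decoder output for labels[:, :-1]
labels = torch.tensor([[72, 101, 108, 108, 111, 258]])   # "Hello" bytes + EOS
keep = torch.ones_like(labels)[:, 1:].reshape(-1).bool()
target = labels[:, 1:].reshape(-1)
recon = F.cross_entropy(logits.reshape(-1, 260)[keep], target[keep], label_smoothing=0.1)

# Boundary regularizer: sparse boundaries whose scores change gradually.
scores = torch.randn(1, 46)
sparsity = torch.sigmoid(scores).mean() * 0.1
smoothness = ((scores[:, 1:] - scores[:, :-1]) ** 2).mean() * 0.01

total = 0.1 * recon + 0.1 * (sparsity + smoothness)
print(float(total))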
250
+
251
+ def generate(self,
252
+ text: str = None,
253
+ input_ids: torch.Tensor = None,
254
+ max_length: int = 48,
255
+ temperature: float = 0.1,
256
+ top_k: int = 10,
257
+ top_p: float = 0.95) -> str:
258
  """
259
+ Generate/reconstruct text
260
+
261
  Args:
262
+ text: Input text to encode and reconstruct
263
+ input_ids: Pre-encoded input
264
+ max_length: Maximum generation length
265
+ temperature: Sampling temperature
266
+ top_k: Top-k sampling
267
+ top_p: Top-p (nucleus) sampling
268
+
269
+ Returns:
270
+ Reconstructed/generated text
271
  """
272
+ # Encode input if text is provided (GPT suggestion: handle multi-chunk properly)
273
+ chunk_positions = None
274
+ if text is not None:
275
+ # Check if text needs chunking
276
+ if len(text.encode('utf-8')) > self.tokenizer.content_size:
277
+ encoded = self.tokenizer.encode(text, add_special_tokens=True, return_chunks=True)
278
+ chunk_positions = encoded.get('chunk_positions', None)
279
  else:
280
+ encoded = self.tokenizer.encode(text, add_special_tokens=True)
281
+
282
+ input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
283
+ attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
284
  else:
285
+ attention_mask = (input_ids != self.tokenizer.PAD).bool() # GPT suggestion: bool mask
286
 
287
+ # Move to device
288
+ device = next(self.parameters()).device
289
+ input_ids = input_ids.to(device)
290
+ attention_mask = attention_mask.to(device)
291
 
292
+ # Encode
293
+ with torch.no_grad():
294
+ encoder_outputs = self.encoder(
295
+ input_ids=input_ids,
296
+ attention_mask=attention_mask
297
+ )
 
 
298
 
299
+ # Prepare all hidden states for decoder
300
+ if 'all_hidden_states' in encoder_outputs:
301
+ encoder_all_hidden = encoder_outputs['all_hidden_states']
302
+ else:
303
+ compressed = encoder_outputs.get('compressed', encoder_outputs.get('hidden_states'))
304
+ encoder_all_hidden = [compressed] * 4
305
+
306
+ # Autoregressive generation (fixed version)
307
+ batch_size = input_ids.size(0)
308
+
309
+ # Start with BOS token
310
+ generated_ids = torch.full((batch_size, 1), self.tokenizer.BOS, device=device)
311
+
312
+ for step in range(max_length - 1):
313
+ with torch.no_grad():
314
+ # Decode current sequence
315
+ decoder_outputs = self.decoder(
316
+ encoder_all_hidden=encoder_all_hidden,
317
+ decoder_input_ids=generated_ids,
318
+ attention_mask=torch.ones_like(generated_ids),
319
+ use_cache=False
320
+ )
321
 
322
+ # Get next token prediction
323
+ logits = decoder_outputs['logits'][:, -1, :] / temperature
324
 
325
+ # Top-k filtering
326
+ if top_k > 0:
327
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
328
+ logits[indices_to_remove] = float('-inf')
329
 
330
+ # Sample next token
331
+ probs = F.softmax(logits, dim=-1)
332
+ next_token = torch.multinomial(probs, num_samples=1)
333
 
334
+ # Append to generated sequence
335
+ generated_ids = torch.cat([generated_ids, next_token], dim=1)
 
 
336
 
337
+ # Check for EOS
338
+ if (next_token == self.tokenizer.EOS).all():
339
+ break
340
+
341
+ # Decode to text (GPT suggestion: proper multi-chunk reconstruction)
342
+ if generated_ids.dim() > 2 and chunk_positions is not None:
343
+ # Multi-chunk output with positions
344
+ text = self.tokenizer.reconstruct(
345
+ generated_ids,
346
+ positions=chunk_positions,
347
+ overlap=self.tokenizer.chunk_overlap
348
+ )
349
+ elif generated_ids.dim() > 2:
350
+ # Multi-chunk without positions (fallback)
351
+ text = self.tokenizer.reconstruct(generated_ids)
352
  else:
353
+ # Single sequence
354
+ text = self.tokenizer.decode(generated_ids[0] if generated_ids.dim() > 1 else generated_ids)
355
 
356
+ return text
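Usage sketch for generate(); it assumes a constructed model (ideally loaded from a trained checkpoint), since with random weights the reconstruction will not resemble the input.

model = IntelligentTokenizerV62()
model.eval()
reconstructed = model.generate(text="Hello, world!", max_length=48, temperature=0.1, top_k=10)
print(reconstructed)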
 
357
 
358
+ def compress(self, text: str) -> Dict[str, Union[torch.Tensor, float]]:
359
+ """
360
+ Compress text and return compression statistics
361
 
362
+ Args:
363
+ text: Input text to compress
 
 
364
 
365
+ Returns:
366
+ Dictionary with compressed representation and statistics
367
+ """
368
+ # Encode text
369
+ encoded = self.tokenizer.encode(text, add_special_tokens=True)
370
+ input_ids = encoded['input_ids'].unsqueeze(0) if encoded['input_ids'].dim() == 1 else encoded['input_ids']
371
+ attention_mask = encoded['attention_mask'].unsqueeze(0) if encoded['attention_mask'].dim() == 1 else encoded['attention_mask']
372
 
373
+ # Move to device
374
+ device = next(self.parameters()).device
375
+ input_ids = input_ids.to(device)
376
+ attention_mask = attention_mask.to(device)
377
 
378
+ # Get compressed representation
379
+ with torch.no_grad():
380
+ encoder_outputs = self.encoder(
381
+ input_ids=input_ids,
382
+ attention_mask=attention_mask
383
+ )
384
 
385
+ return {
386
+ 'compressed': encoder_outputs['compressed'],
387
+ 'num_tokens': encoder_outputs['num_tokens'],
388
+ 'compression_ratio': encoder_outputs['compression_ratio'],
389
+ 'original_bytes': len(text.encode('utf-8')),
390
+ 'compressed_size': encoder_outputs['num_tokens'] * 2 # Approximate bytes
391
+ }
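Continuing the sketch above, compress() exposes the statistics in the dictionary just constructed; float() is used in case the ratio comes back as a tensor.

stats = model.compress("Hello, world!")
ratio = float(stats['compression_ratio'])
print(f"{stats['original_bytes']} bytes -> {stats['num_tokens']} tokens (ratio {ratio:.1f}:1)")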
392
 
393
+ def update_training_state(self, epoch: int, step: int = 0, reconstruction_loss: float = None):
394
+ """
395
+ Update training state - adaptive, not phase-based
396
 
397
+ Args:
398
+ epoch: Current epoch
399
+ step: Current training step
400
+ reconstruction_loss: Current reconstruction quality
401
+ """
402
+ self.current_epoch = torch.tensor(epoch)
403
+ self.training_step = torch.tensor(step)
404
+
405
+ # Update encoder warmup (gates only)
406
+ self.encoder.set_warmup_step(step)
407
+
408
+ # Adaptive weight adjustment based on performance
409
+ if reconstruction_loss is not None:
410
+ # If reconstruction is poor, increase its weight
411
+ if reconstruction_loss > 1.0:
412
+ self.reconstruction_weight = 1.0
413
+ self.compression_weight = 0.1 # keep compression focus low
414
+ else:
415
+ # Good reconstruction, can focus on compression
416
+ self.reconstruction_weight = 0.5
417
+ self.compression_weight = 0.1
418
 
419
+ # Boundary weight stays moderate
420
+ self.boundary_weight = 0.1
421
 
422
+ # Let encoder know about reconstruction quality
423
+ self.encoder.adaptive_compression_control(reconstruction_loss)
424
+ else:
425
+ # Default balanced weights
426
+ self.reconstruction_weight = 0.5
427
+ self.compression_weight = 0.1
428
+ self.boundary_weight = 0.1
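A sketch of how an outer training loop might drive the adaptive weighting above; the dataloader, optimizer, and inner step are placeholders and only the call pattern is shown.

for epoch in range(3):
    last_recon = None
    # for batch in dataloader:                                   # hypothetical dataloader
    #     outputs = model(input_ids=batch['input_ids'],
    #                     attention_mask=batch['attention_mask'],
    #                     labels=batch['input_ids'])
    #     outputs['loss'].backward(); optimizer.step(); optimizer.zero_grad()
    #     last_recon = float(model.last_losses['reconstruction'])
    model.update_training_state(epoch, step=epoch * 1000, reconstruction_loss=last_recon)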
429
 
430
+ def get_model_stats(self) -> Dict[str, float]:
431
+ """
432
+ Get model statistics for monitoring
433
 
434
+ Returns:
435
+ Dictionary with various model statistics
436
+ """
437
+ stats = {}
438
 
439
+ # Encoder stats (GPT suggestion: already prefixed)
440
+ encoder_stats = self.encoder.get_monitoring_stats()
441
+ stats.update({f'encoder_{k}': v for k, v in encoder_stats.items()})
442
 
443
+ # Decoder memory stats
444
+ decoder_memory = self.decoder.get_memory_usage()
445
+ stats.update({f'decoder_{k}': v for k, v in decoder_memory.items()})
 
 
 
 
 
446
 
447
+ # Loss stats (if available) - check for tensor items
448
+ if hasattr(self, 'last_losses'):
449
+ for k, v in self.last_losses.items():
450
+ if isinstance(v, torch.Tensor):
451
+ stats[f'loss_{k}'] = v.item() if v.numel() == 1 else v.mean().item()
452
+ else:
453
+ stats[f'loss_{k}'] = float(v)
 
 
 
 
 
454
 
455
+ # Training info
456
+ stats['current_epoch'] = self.current_epoch.item()
457
+ stats['training_step'] = self.training_step.item()
458
 
459
+ return stats
460
+
461
+ def save_checkpoint(self, path: str):
462
+ """
463
+ Save model checkpoint
464
+
465
+ Args:
466
+ path: Path to save checkpoint
467
+ """
468
+ checkpoint = {
469
+ 'model_state_dict': self.state_dict(),
470
+ 'config': self.config,
471
+ 'epoch': self.current_epoch.item(),
472
+ 'step': self.training_step.item(),
473
+ 'stats': self.get_model_stats()
474
  }
475
+ torch.save(checkpoint, path)
476
+ print(f"Checkpoint saved to {path}")
477
+
478
+ @classmethod
479
+ def from_checkpoint(cls, path: str, device: str = 'cuda'):
480
+ """
481
+ Load model from checkpoint
482
+
483
+ Args:
484
+ path: Path to checkpoint
485
+ device: Device to load model on
486
+
487
+ Returns:
488
+ Loaded model instance
489
+ """
490
+ checkpoint = torch.load(path, map_location=device)
491
+
492
+ # Create model with saved config
493
+ model = cls(checkpoint.get('config', {}))
494
+ model.load_state_dict(checkpoint['model_state_dict'])
495
+ model.to(device)
496
+
497
+ # Restore training state
498
+ if 'epoch' in checkpoint:
499
+ model.current_epoch = torch.tensor(checkpoint['epoch'])
500
+ if 'step' in checkpoint:
501
+ model.training_step = torch.tensor(checkpoint['step'])
502
+
503
+ print(f"Model loaded from {path} (Epoch {checkpoint.get('epoch', 0)})")
504
+ return model
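Checkpoint round-trip sketch using the two methods above; the path is a placeholder.

model.save_checkpoint("checkpoints/v62_step1000.pt")        # hypothetical path
restored = IntelligentTokenizerV62.from_checkpoint("checkpoints/v62_step1000.pt", device="cpu")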
505
+
506
+
507
+ if __name__ == "__main__":
508
+ # Test unified model
509
+ print("Testing Intelligent Tokenizer v6.2.0")
510
+
511
+ # Create model
512
+ model = IntelligentTokenizerV62()
513
+ print(f"Model created with {sum(p.numel() for p in model.parameters())/1e6:.1f}M parameters")
514
+
515
+ # Test texts
516
+ test_texts = [
517
+ "Hello, world!",
518
+ "안녕하세요, 만나서 반갑습니다. 오늘 날씨가 좋네요!",
519
+ "今天天气很好。",
520
+ ]
521
+
522
+ for text in test_texts:
523
+ print(f"\nInput: {text}")
524
+
525
+ # Compress
526
+ compression = model.compress(text)
527
+ print(f" Compression ratio: {compression['compression_ratio']:.1f}:1")
528
+ print(f" Tokens: {compression['num_tokens']}")
529
+
530
+ # Generate (reconstruct)
531
+ reconstructed = model.generate(text, temperature=0.1)
532
+ print(f" Reconstructed: {reconstructed}")
533
+
534
+ # Get model stats
535
+ stats = model.get_model_stats()
536
+ print(f"\nModel Statistics:")
537
+ for key, value in stats.items():
538
+ if isinstance(value, float):
539
+ print(f" {key}: {value:.4f}")
540
+ else:
541
+ print(f" {key}: {value}")