teszenofficial committed on
Commit
77d62e8
·
verified ·
1 Parent(s): 1bf463e

Upload 4 files

Files changed (4)
  1. config.yaml +64 -0
  2. dataset.py +98 -0
  3. model.py +291 -0
  4. tokenizer.py +138 -0
config.yaml ADDED
@@ -0,0 +1,64 @@
+ # MTP Mini - improved configuration for coherent generation
+
+ model:
+   vocab_size: 4000
+   d_model: 512           # increased for more capacity
+   n_layers: 8            # more layers
+   n_heads: 8             # more attention heads
+   d_ff: 2048             # 4x d_model
+   max_seq_len: 512       # longer context
+   dropout: 0.2           # more dropout to reduce overfitting
+   use_swiglu: true       # improved activation
+
+ training:
+   batch_size: 4          # smaller batch for a small corpus
+   accumulation_steps: 4  # effective batch = 16
+   epochs: 20             # FEWER epochs
+   learning_rate: 0.0003  # higher LR for faster convergence
+   min_lr: 0.00001
+   weight_decay: 0.1      # MORE weight decay for regularization
+   max_grad_norm: 1.0
+   num_threads: 4
+   save_every: 5
+
+   # Early stopping
+   patience: 5            # stop if no improvement in 5 epochs
+   min_delta: 0.001       # minimum required improvement
+
+   # Learning rate schedule
+   warmup_steps: 100
+   use_lr_scheduler: true
+
+   # Additional regularization
+   label_smoothing: 0.1
+   use_eos_loss_weight: true  # give the EOS token extra weight
+
+ data:
+   corpus_path: corpus/mtp_mini_corpus.jsonl
+   min_text_length: 50    # longer texts
+   max_text_length: 2000  # allow long responses
+   validation_split: 0.15
+
+   # Data augmentation
+   use_augmentation: true
+   augmentation_prob: 0.3
+
+ generation:
+   # Improved generation parameters
+   default_max_tokens: 150
+   default_temperature: 0.7
+   default_top_k: 40
+   default_top_p: 0.92
+   default_repetition_penalty: 1.15
+   min_response_length: 20
+   use_length_penalty: true
+
+   # Coherence control
+   use_perplexity_filter: true
+   max_perplexity: 100.0
+
+   # Stop sequences
+   stop_sequences:
+     - "###"
+     - "\n\n\n"
+     - "Instrucción:"
dataset.py ADDED
@@ -0,0 +1,98 @@
+ import torch
+ from torch.utils.data import Dataset
+ import json
+ import random
+
+
+ class MTPDataset(Dataset):
+     """Improved dataset with optional data augmentation"""
+
+     def __init__(self, corpus_path, tokenizer, max_seq_len=512,
+                  use_augmentation=False, augmentation_prob=0.3):
+         self.tokenizer = tokenizer
+         self.max_seq_len = max_seq_len
+         self.use_augmentation = use_augmentation
+         self.augmentation_prob = augmentation_prob
+         self.data = []
+
+         # Load corpus
+         with open(corpus_path, 'r', encoding='utf-8') as f:
+             for line in f:
+                 entry = json.loads(line)
+                 if 'instruction' in entry and 'response' in entry:
+                     self.data.append(entry)
+
+         print(f"✓ Loaded {len(self.data)} examples from corpus")
+         if use_augmentation:
+             print(f"✓ Data augmentation enabled (prob={augmentation_prob})")
+
+     def __len__(self):
+         return len(self.data)
+
+     def augment_text(self, text):
+         """Simple text augmentation"""
+         if not self.use_augmentation or random.random() > self.augmentation_prob:
+             return text
+
+         # Variation 1: strip surrounding whitespace (simulates formatting variations)
+         if random.random() < 0.3:
+             text = text.strip()
+
+         # Variation 2: vary the final punctuation
+         if random.random() < 0.2:
+             if text.endswith('.'):
+                 text = text[:-1]
+             elif not text.endswith(('.', '!', '?')):
+                 text = text + '.'
+
+         return text
+
+     def __getitem__(self, idx):
+         entry = self.data[idx]
+
+         instruction = entry['instruction']
+         response = entry['response']
+
+         # Apply augmentation
+         instruction = self.augment_text(instruction)
+         response = self.augment_text(response)
+
+         # Improved prompt format (kept in Spanish to match the corpus and the stop sequences in config.yaml)
+         full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"
+
+         # Tokenize
+         tokens = self.tokenizer.encode(full_text)
+
+         # Add BOS and EOS
+         tokens = [self.tokenizer.bos_id()] + tokens + [self.tokenizer.eos_id()]
+
+         # Truncate if too long
+         if len(tokens) > self.max_seq_len:
+             # Truncate while keeping BOS and EOS
+             tokens = [tokens[0]] + tokens[1:self.max_seq_len - 1] + [self.tokenizer.eos_id()]
+
+         # Convert to tensors, shifted by one position for next-token prediction
+         input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
+         target_ids = torch.tensor(tokens[1:], dtype=torch.long)
+
+         return input_ids, target_ids
+
+
+ def collate_fn(batch, pad_id=0):
+     """Custom collate function that pads each batch to its longest sequence"""
+     input_ids = [item[0] for item in batch]
+     target_ids = [item[1] for item in batch]
+
+     # Find max length in batch
+     max_len = max(len(ids) for ids in input_ids)
+
+     # Pad sequences (targets are padded with pad_id too; the loss ignores index 0, see model.py)
+     input_ids_padded = []
+     target_ids_padded = []
+
+     for inp, tgt in zip(input_ids, target_ids):
+         pad_len = max_len - len(inp)
+         input_ids_padded.append(torch.cat([inp, torch.full((pad_len,), pad_id, dtype=torch.long)]))
+         target_ids_padded.append(torch.cat([tgt, torch.full((pad_len,), pad_id, dtype=torch.long)]))
+
+     return torch.stack(input_ids_padded), torch.stack(target_ids_padded)
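A minimal usage sketch (not part of the commit): wiring MTPDataset and collate_fn into a DataLoader. The tokenizer model path is hypothetical; batch_size matches config.yaml.

```python
from functools import partial

from torch.utils.data import DataLoader

from dataset import MTPDataset, collate_fn
from tokenizer import MTPTokenizer

# Hypothetical tokenizer path, as produced by MTPTokenizer.train()
tokenizer = MTPTokenizer("mtp_tokenizer.model")

dataset = MTPDataset("corpus/mtp_mini_corpus.jsonl", tokenizer,
                     max_seq_len=512, use_augmentation=True, augmentation_prob=0.3)

# batch_size=4 as in config.yaml; bind the real pad id into the collate function
loader = DataLoader(dataset, batch_size=4, shuffle=True,
                    collate_fn=partial(collate_fn, pad_id=tokenizer.pad_id()))

input_ids, target_ids = next(iter(loader))
print(input_ids.shape, target_ids.shape)  # both (4, longest_sequence_in_batch)
```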
model.py ADDED
@@ -0,0 +1,291 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+
+ class RotaryPositionalEmbedding(nn.Module):
+     """RoPE - Rotary Position Embedding"""
+
+     def __init__(self, dim, max_seq_len=2048, base=10000):
+         super().__init__()
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer('inv_freq', inv_freq)
+         self.max_seq_len = max_seq_len
+
+     def forward(self, seq_len, device):
+         t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+         freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         return emb.cos(), emb.sin()
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     """Apply RoPE to queries and keys"""
+     def rotate_half(x):
+         x1, x2 = x.chunk(2, dim=-1)
+         return torch.cat((-x2, x1), dim=-1)
+
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class MultiHeadSelfAttention(nn.Module):
+     """Multi-head self-attention with RoPE and optional fused attention"""
+
+     def __init__(self, d_model, n_heads, dropout=0.1, max_seq_len=2048):
+         super().__init__()
+         assert d_model % n_heads == 0
+
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.d_k = d_model // n_heads
+
+         self.q_linear = nn.Linear(d_model, d_model, bias=False)
+         self.k_linear = nn.Linear(d_model, d_model, bias=False)
+         self.v_linear = nn.Linear(d_model, d_model, bias=False)
+         self.out_linear = nn.Linear(d_model, d_model, bias=False)
+
+         self.dropout = nn.Dropout(dropout)
+         self.attn_dropout = nn.Dropout(dropout)
+         self.rope = RotaryPositionalEmbedding(self.d_k, max_seq_len)
+
+         # Use PyTorch's fused attention kernel when available
+         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+
+     def forward(self, x, mask=None):
+         batch_size, seq_len, d_model = x.size()
+
+         # Project and reshape to (batch, heads, seq, d_k)
+         Q = self.q_linear(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
+         K = self.k_linear(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
+         V = self.v_linear(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
+
+         cos, sin = self.rope(seq_len, x.device)
+         cos = cos[None, None, :, :]
+         sin = sin[None, None, :, :]
+         Q, K = apply_rotary_pos_emb(Q, K, cos, sin)
+
+         # The fused path is only taken when no explicit mask is passed;
+         # is_causal=True applies the causal mask internally
+         if self.flash and mask is None:
+             context = F.scaled_dot_product_attention(
+                 Q, K, V,
+                 attn_mask=None,
+                 dropout_p=self.dropout.p if self.training else 0.0,
+                 is_causal=True
+             )
+         else:
+             scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
+             if mask is not None:
+                 scores = scores.masked_fill(mask == 0, float('-inf'))
+             attn_weights = F.softmax(scores, dim=-1)
+             attn_weights = self.attn_dropout(attn_weights)
+             context = torch.matmul(attn_weights, V)
+
+         context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
+         output = self.out_linear(context)
+         return self.dropout(output)
+
+
+ class SwiGLU(nn.Module):
+     """SwiGLU activation"""
+
+     def __init__(self, d_model, d_ff, dropout=0.1):
+         super().__init__()
+         self.w1 = nn.Linear(d_model, d_ff, bias=False)
+         self.w2 = nn.Linear(d_ff, d_model, bias=False)
+         self.w3 = nn.Linear(d_model, d_ff, bias=False)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         return self.w2(self.dropout(F.silu(self.w1(x)) * self.w3(x)))
+
+
+ class FeedForward(nn.Module):
+     """Standard feed-forward block"""
+
+     def __init__(self, d_model, d_ff, dropout=0.1):
+         super().__init__()
+         self.linear1 = nn.Linear(d_model, d_ff)
+         self.linear2 = nn.Linear(d_ff, d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         return self.linear2(self.dropout(F.gelu(self.linear1(x))))
+
+
+ class RMSNorm(nn.Module):
+     """RMSNorm"""
+
+     def __init__(self, dim, eps=1e-6):
+         super().__init__()
+         self.eps = eps
+         self.weight = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x):
+         norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+         return x * norm * self.weight
+
+
+ class TransformerBlock(nn.Module):
+     """Improved pre-norm transformer block"""
+
+     def __init__(self, d_model, n_heads, d_ff, dropout=0.1, max_seq_len=2048, use_swiglu=True):
+         super().__init__()
+         self.attention = MultiHeadSelfAttention(d_model, n_heads, dropout, max_seq_len)
+
+         if use_swiglu:
+             self.feed_forward = SwiGLU(d_model, d_ff, dropout)
+         else:
+             self.feed_forward = FeedForward(d_model, d_ff, dropout)
+
+         self.norm1 = RMSNorm(d_model)
+         self.norm2 = RMSNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, mask=None):
+         x = x + self.attention(self.norm1(x), mask)
+         x = x + self.feed_forward(self.norm2(x))
+         return x
+
+
+ class MTPMiniModel(nn.Module):
+     """MTP Mini - improved model for coherent generation"""
+
+     def __init__(self, vocab_size, d_model=512, n_layers=8, n_heads=8,
+                  d_ff=2048, max_seq_len=512, dropout=0.2, use_swiglu=True):
+         super().__init__()
+
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.max_seq_len = max_seq_len
+
+         self.token_embedding = nn.Embedding(vocab_size, d_model)
+         self.dropout = nn.Dropout(dropout)
+
+         self.blocks = nn.ModuleList([
+             TransformerBlock(d_model, n_heads, d_ff, dropout, max_seq_len, use_swiglu)
+             for _ in range(n_layers)
+         ])
+
+         self.norm_f = RMSNorm(d_model)
+         self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
+
+         # Weight tying between embedding and output head
+         self.lm_head.weight = self.token_embedding.weight
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, module):
+         if isinstance(module, nn.Linear):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+             if module.bias is not None:
+                 torch.nn.init.zeros_(module.bias)
+         elif isinstance(module, nn.Embedding):
+             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+
+     def forward(self, input_ids, targets=None, use_eos_weight=False):
+         batch_size, seq_len = input_ids.size()
+
+         # Explicit causal mask; note this routes attention through the masked (non-fused) path
+         mask = torch.tril(torch.ones(seq_len, seq_len, device=input_ids.device)).view(1, 1, seq_len, seq_len)
+
+         x = self.dropout(self.token_embedding(input_ids))
+
+         for block in self.blocks:
+             x = block(x, mask)
+
+         x = self.norm_f(x)
+         logits = self.lm_head(x)
+
+         loss = None
+         if targets is not None:
+             # ignore_index=0 skips padded positions (pad_id=0 in the tokenizer)
+             if use_eos_weight:
+                 # Weight the EOS token more heavily so the model learns to stop
+                 weights = torch.ones(self.vocab_size, device=logits.device)
+                 weights[3] = 2.0  # EOS token (eos_id=3 in the tokenizer)
+                 loss = F.cross_entropy(
+                     logits.view(-1, self.vocab_size),
+                     targets.view(-1),
+                     weight=weights,
+                     ignore_index=0,
+                     label_smoothing=0.1
+                 )
+             else:
+                 loss = F.cross_entropy(
+                     logits.view(-1, self.vocab_size),
+                     targets.view(-1),
+                     ignore_index=0,
+                     label_smoothing=0.1
+                 )
+
+         return logits, loss
+
+     def generate(self, input_ids, max_new_tokens=150, temperature=0.7,
+                  top_k=40, top_p=0.92, repetition_penalty=1.15,
+                  min_length=20, eos_token_id=3, stop_sequences=None):
+         """Improved generation with length and coherence control (assumes batch size 1)"""
+         self.eval()
+
+         generated = input_ids.clone()
+         generated_text_tokens = 0
+
+         with torch.no_grad():
+             for step in range(max_new_tokens):
+                 # Crop the context window if it exceeds max_seq_len
+                 input_ids_cond = generated if generated.size(1) <= self.max_seq_len else generated[:, -self.max_seq_len:]
+
+                 logits, _ = self(input_ids_cond)
+                 logits = logits[:, -1, :].clone()
+
+                 # Repetition penalty over all previously seen tokens
+                 if repetition_penalty != 1.0:
+                     for token_id in set(generated[0].tolist()):
+                         if logits[0, token_id] < 0:
+                             logits[0, token_id] *= repetition_penalty
+                         else:
+                             logits[0, token_id] /= repetition_penalty
+
+                 # Penalize recently repeated tokens more strongly
+                 if generated.size(1) > 10:
+                     recent_tokens = generated[0, -10:].tolist()
+                     for token_id in set(recent_tokens):
+                         count = recent_tokens.count(token_id)
+                         if count > 2:
+                             logits[0, token_id] -= count * 2.0
+
+                 # Disallow EOS until the minimum length is reached
+                 if generated_text_tokens < min_length:
+                     logits[0, eos_token_id] = float('-inf')
+                 else:
+                     # Gradually boost the probability of EOS
+                     eos_boost = (generated_text_tokens - min_length) * 0.1
+                     logits[0, eos_token_id] += eos_boost
+
+                 # Temperature
+                 logits = logits / temperature
+
+                 # Top-k filtering
+                 if top_k > 0:
+                     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+                     logits[logits < v[:, [-1]]] = float('-inf')
+
+                 # Top-p (nucleus) filtering
+                 if top_p < 1.0:
+                     sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                     cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+                     sorted_indices_to_remove = cumulative_probs > top_p
+                     sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
+                     sorted_indices_to_remove[:, 0] = 0
+                     indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                     logits[indices_to_remove] = float('-inf')
+
+                 # Sample the next token
+                 probs = F.softmax(logits, dim=-1)
+                 next_token = torch.multinomial(probs, num_samples=1)
+
+                 # Stop on EOS once the minimum length is satisfied
+                 if next_token.item() == eos_token_id and generated_text_tokens >= min_length:
+                     break
+
+                 generated = torch.cat([generated, next_token], dim=1)
+                 generated_text_tokens += 1
+
+         return generated
+
+     def count_parameters(self):
+         return sum(p.numel() for p in self.parameters() if p.requires_grad)
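A minimal end-to-end sketch (not part of the commit): building the model with the config.yaml hyperparameters and sampling from it. The prompt follows the `### Instrucción:` format that dataset.py trains on; the tokenizer path is hypothetical.

```python
import torch

from model import MTPMiniModel
from tokenizer import MTPTokenizer

tokenizer = MTPTokenizer("mtp_tokenizer.model")  # hypothetical path

model = MTPMiniModel(vocab_size=tokenizer.vocab_size(), d_model=512,
                     n_layers=8, n_heads=8, d_ff=2048,
                     max_seq_len=512, dropout=0.2, use_swiglu=True)
print(f"{model.count_parameters():,} trainable parameters")

# Prompt in the same format the dataset produces, up to the response marker
prompt = "### Instrucción:\n¿Qué es un transformador?\n\n### Respuesta:\n"
ids = [tokenizer.bos_id()] + tokenizer.encode(prompt)
input_ids = torch.tensor([ids], dtype=torch.long)

out = model.generate(input_ids, max_new_tokens=150, temperature=0.7,
                     top_k=40, top_p=0.92, repetition_penalty=1.15,
                     min_length=20, eos_token_id=tokenizer.eos_id())
print(tokenizer.decode(out[0].tolist()))
```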
tokenizer.py ADDED
@@ -0,0 +1,138 @@
+ import sentencepiece as spm
+ import os
+ import json
+
+
+ class MTPTokenizer:
+     """Tokenizer using SentencePiece BPE"""
+
+     def __init__(self, model_path=None):
+         self.sp = None
+         self.model_path = model_path
+
+         if model_path and os.path.exists(model_path):
+             self.load(model_path)
+
+     def train(self, corpus_path, vocab_size=4000, model_prefix='mtp_tokenizer'):
+         """Train SentencePiece BPE tokenizer on corpus"""
+
+         # Extract text from JSONL corpus
+         texts = []
+         with open(corpus_path, 'r', encoding='utf-8') as f:
+             for line in f:
+                 data = json.loads(line)
+                 if 'instruction' in data:
+                     texts.append(data['instruction'])
+                 if 'response' in data:
+                     texts.append(data['response'])
+
+         # Save temporary text file
+         temp_file = 'temp_corpus.txt'
+         with open(temp_file, 'w', encoding='utf-8') as f:
+             f.write('\n'.join(texts))
+
+         # Calculate optimal vocab size based on corpus
+         total_chars = sum(len(text) for text in texts)
+         max_vocab = min(vocab_size, int(total_chars * 0.15))  # heuristic: ~15% of chars
+
+         print(f" → Corpus stats: {len(texts)} texts, {total_chars} characters")
+         print(f" → Adjusted vocab size: {max_vocab} (requested: {vocab_size})")
+
+         # Shared trainer arguments, so the retry below stays in sync
+         train_args = dict(
+             input=temp_file,
+             model_prefix=model_prefix,
+             model_type='bpe',
+             pad_id=0,
+             unk_id=1,
+             bos_id=2,
+             eos_id=3,
+             character_coverage=1.0,
+             normalization_rule_name='identity',
+             num_threads=4,
+             split_digits=True,
+             allow_whitespace_only_pieces=False,
+             byte_fallback=False,
+             max_sentencepiece_length=16
+         )
+
+         # Train SentencePiece with adjusted parameters
+         try:
+             spm.SentencePieceTrainer.train(vocab_size=max_vocab, **train_args)
+         except RuntimeError as e:
+             if "Vocabulary size too high" in str(e):
+                 # Extract suggested max from error and retry
+                 import re
+                 match = re.search(r'value <= (\d+)', str(e))
+                 if match:
+                     suggested_max = int(match.group(1))
+                     print(f" → Retrying with vocab size: {suggested_max}")
+                     spm.SentencePieceTrainer.train(vocab_size=suggested_max, **train_args)
+                 else:
+                     raise
+             else:
+                 raise
+
+         # Clean up
+         os.remove(temp_file)
+
+         # Load the trained model
+         self.model_path = f"{model_prefix}.model"
+         self.load(self.model_path)
+
+         print(f"✓ Tokenizer trained: {self.vocab_size()} tokens")
+         print(f"✓ Model saved: {self.model_path}")
+
+     def load(self, model_path):
+         """Load trained tokenizer"""
+         self.sp = spm.SentencePieceProcessor()
+         self.sp.load(model_path)
+         self.model_path = model_path
+
+     def encode(self, text):
+         """Encode text to token IDs"""
+         if self.sp is None:
+             raise ValueError("Tokenizer not loaded. Train or load a model first.")
+         return self.sp.encode_as_ids(text)
+
+     def decode(self, ids):
+         """Decode token IDs to text"""
+         if self.sp is None:
+             raise ValueError("Tokenizer not loaded. Train or load a model first.")
+         return self.sp.decode_ids(ids)
+
+     def vocab_size(self):
+         """Get vocabulary size"""
+         if self.sp is None:
+             return 0
+         return self.sp.get_piece_size()
+
+     def bos_id(self):
+         """Beginning-of-sentence token ID"""
+         return self.sp.bos_id()
+
+     def eos_id(self):
+         """End-of-sentence token ID"""
+         return self.sp.eos_id()
+
+     def pad_id(self):
+         """Padding token ID"""
+         return self.sp.pad_id()
+
+     def unk_id(self):
+         """Unknown token ID"""
+         return self.sp.unk_id()
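And a short round trip for the tokenizer (again a sketch, not part of the commit), assuming the corpus file from config.yaml exists:

```python
from tokenizer import MTPTokenizer

# One-time training; writes mtp_tokenizer.model / mtp_tokenizer.vocab
tok = MTPTokenizer()
tok.train("corpus/mtp_mini_corpus.jsonl", vocab_size=4000)

# Encode/decode round trip; the special ids match the training arguments
ids = tok.encode("Hola, ¿cómo estás?")
print(ids)
print(tok.decode(ids))
print(tok.pad_id(), tok.unk_id(), tok.bos_id(), tok.eos_id())  # 0 1 2 3
```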