Adjoumani commited on
Commit
155a79d
·
verified ·
1 Parent(s): e15419b

Upload baoule_tokenizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. baoule_tokenizer.py +414 -0
baoule_tokenizer.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import unicodedata
3
+ import regex as re
4
+ from datasets import load_dataset
5
+ import time
6
+ import os
7
+
8
+ def get_stats(ids, stats=None):
9
+ """
10
+ Calcule la fréquence des paires d'ids consécutifs.
11
+ Conserve la même logique que la version originale car cette fonction est indépendante
12
+ des spécificités de la langue.
13
+ """
14
+ stats = {} if stats is None else stats
15
+ for pair in zip(ids, ids[1:]):
16
+ stats[pair] = stats.get(pair, 0) + 1
17
+ return stats
18
+
19
+ def merge(ids, pair, idx):
20
+ """
21
+ Fusionne les paires d'ids identifiées.
22
+ Conserve la même logique que la version originale car cette fonction gère
23
+ uniquement la fusion des tokens déjà identifiés.
24
+ """
25
+ newids = []
26
+ i = 0
27
+ while i < len(ids):
28
+ if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
29
+ newids.append(idx)
30
+ i += 2
31
+ else:
32
+ newids.append(ids[i])
33
+ i += 1
34
+ return newids
35
+
36
+ def replace_control_characters(s: str) -> str:
37
+ """
38
+ Remplace les caractères de contrôle, avec une attention particulière aux
39
+ caractères spéciaux du baoulé.
40
+ """
41
+ chars = []
42
+ for ch in s:
43
+ # Gestion spéciale des caractères baoulé
44
+ if ch in ['ɛ', 'ɔ', 'ŋ', 'ɲ']: # Caractères spéciaux baoulé
45
+ chars.append(ch)
46
+ # Gestion standard des caractères de contrôle
47
+ elif unicodedata.category(ch)[0] != "C":
48
+ chars.append(ch)
49
+ else:
50
+ chars.append(f"\u{ord(ch):04x}")
51
+ return "".join(chars)
52
+
53
+ def render_token(t: bytes) -> str:
54
+ """
55
+ Décode les tokens en gérant les caractères spéciaux du baoulé.
56
+ """
57
+ try:
58
+ # Tentative de décodage standard
59
+ s = t.decode('utf-8', errors='replace')
60
+ # Gestion des caractères spéciaux baoulé
61
+ s = replace_control_characters(s)
62
+ return s
63
+ except UnicodeDecodeError:
64
+ # En cas d'échec, retourne le caractère de remplacement
65
+ return '�'
66
+
67
+
68
+
69
+ class BaouleTokenizer:
70
+ def __init__(self):
71
+ # Initialisation des attributs obligatoires
72
+ self.special_chars = {
73
+ 'ɛ': 256,
74
+ 'ɔ': 257,
75
+ 'ŋ': 258,
76
+ 'ɲ': 259
77
+ }
78
+
79
+ self.pattern = r"(?i:'n|gb|kp|ny|[ɛɔ]n)|(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^
80
+ \p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[
81
+ ]*|\s*[
82
+ ]+|\s+(?!\S)|\s+"
83
+ self.compiled_pattern = re.compile(self.pattern)
84
+
85
+ self.special_tokens = {
86
+ #'<|baoule|>': 1101,
87
+ #'<|french|>': 1102,
88
+ #'<|end|>': 1103,
89
+ #'<|unknown|>': 1104,
90
+ #'<|pad|>': 1105,
91
+ # or
92
+ '<|begin_of_text|>': 1101,
93
+ '<|end_of_text|>': 1102,
94
+ '<|start_header_id|>': 1103,
95
+ '<|end_header_id|>': 1104,
96
+ '<|eot_id|>': 1105
97
+ }
98
+
99
+ self.merges = {}
100
+ self.vocab = self._build_vocab()
101
+
102
+ def train(self, dataset, vocab_size):
103
+ assert vocab_size >= 260 # 256 ASCII + 4 caractères spéciaux baoulé minimum
104
+
105
+ # Extraction des textes baoulé du dataset HuggingFace
106
+ text_chunks = []
107
+ for item in dataset['train']:
108
+ chunks = re.findall(self.compiled_pattern, item['baoule'])
109
+ text_chunks.extend(chunks)
110
+
111
+ # Conversion en ids avec gestion spéciale des caractères baoulé
112
+ ids = []
113
+ for chunk in text_chunks:
114
+ chunk_ids = []
115
+ i = 0
116
+ while i < len(chunk):
117
+ # Vérification des digraphes
118
+ if i < len(chunk) - 1:
119
+ digraph = chunk[i:i+2]
120
+ if digraph in ['gb', 'kp', 'ny']:
121
+ chunk_ids.append(ord(digraph[0]))
122
+ chunk_ids.append(ord(digraph[1]))
123
+ i += 2
124
+ continue
125
+
126
+ # Gestion des caractères spéciaux baoulé
127
+ char = chunk[i]
128
+ if char in self.special_chars:
129
+ chunk_ids.append(self.special_chars[char])
130
+ else:
131
+ # Encodage UTF-8 standard pour les autres caractères
132
+ chunk_ids.extend(list(char.encode("utf-8")))
133
+ i += 1
134
+ ids.append(chunk_ids)
135
+
136
+ # Calcul des fusions
137
+ num_merges = vocab_size - (260 + len(self.special_tokens))
138
+ merges = {}
139
+ vocab = {idx: bytes([idx]) for idx in range(256)}
140
+ vocab.update({idx: char.encode('utf-8') for char, idx in self.special_chars.items()})
141
+
142
+ for i in range(num_merges):
143
+ stats = {}
144
+ for chunk_ids in ids:
145
+ get_stats(chunk_ids, stats)
146
+ if not stats:
147
+ break
148
+ pair = max(stats, key=stats.get)
149
+ idx = 260 + len(self.special_tokens) + i
150
+ ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
151
+ merges[pair] = idx
152
+ vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
153
+
154
+ self.merges = merges
155
+ self.vocab = vocab
156
+
157
+ def _build_vocab(self):
158
+ # Vocabulaire de base incluant ASCII et caractères spéciaux baoulé
159
+ vocab = {idx: bytes([idx]) for idx in range(256)}
160
+ vocab.update({idx: char.encode('utf-8') for char, idx in self.special_chars.items()})
161
+
162
+ # Ajout des fusions
163
+ for (p0, p1), idx in self.merges.items():
164
+ vocab[idx] = vocab[p0] + vocab[p1]
165
+
166
+ # Ajout des tokens spéciaux
167
+ for special, idx in self.special_tokens.items():
168
+ vocab[idx] = special.encode("utf-8")
169
+
170
+ return vocab
171
+
172
+ def save(self, file_prefix):
173
+ # Sauvegarde du modèle
174
+ model_file = file_prefix + ".model"
175
+ with open(model_file, 'w') as f:
176
+ f.write("baoule tokenizer v1.0
177
+ ")
178
+ f.write(f"{self.pattern}
179
+ ")
180
+
181
+ # Sauvegarde des caractères spéciaux baoulé
182
+ f.write(f"{len(self.special_chars)}
183
+ ")
184
+ for char, idx in self.special_chars.items():
185
+ f.write(f"{char} {idx}
186
+ ")
187
+
188
+ # Sauvegarde des tokens spéciaux
189
+ f.write(f"{len(self.special_tokens)}
190
+ ")
191
+ for token, idx in self.special_tokens.items():
192
+ f.write(f"{token} {idx}
193
+ ")
194
+
195
+ # Sauvegarde des fusions
196
+ for idx1, idx2 in self.merges:
197
+ f.write(f"{idx1} {idx2}
198
+ ")
199
+
200
+ # Sauvegarde du vocabulaire
201
+ vocab_file = file_prefix + ".vocab"
202
+ inverted_merges = {idx: pair for pair, idx in self.merges.items()}
203
+ with open(vocab_file, "w", encoding="utf-8") as f:
204
+ for idx, token in self.vocab.items():
205
+ s = render_token(token)
206
+ if idx in inverted_merges:
207
+ idx0, idx1 = inverted_merges[idx]
208
+ s0 = render_token(self.vocab[idx0])
209
+ s1 = render_token(self.vocab[idx1])
210
+ f.write(f"[{s0}][{s1}] -> [{s}] {idx}
211
+ ")
212
+ else:
213
+ f.write(f"[{s}] {idx}
214
+ ")
215
+ def load(self, model_file):
216
+ merges = {}
217
+ special_tokens = {}
218
+ special_chars = {}
219
+
220
+ with open(model_file, 'r', encoding="utf-8") as f:
221
+ version = f.readline().strip()
222
+ self.pattern = f.readline().strip()
223
+ self.compiled_pattern = re.compile(self.pattern)
224
+
225
+ # Lecture des caractères spéciaux baoulé
226
+ num_special_chars = int(f.readline().strip())
227
+ for _ in range(num_special_chars):
228
+ char, char_idx = f.readline().strip().split()
229
+ special_chars[char] = int(char_idx)
230
+
231
+ # Lecture des tokens spéciaux
232
+ num_special = int(f.readline().strip())
233
+ for _ in range(num_special):
234
+ special, special_idx = f.readline().strip().split()
235
+ special_tokens[special] = int(special_idx)
236
+
237
+ # Création du vocabulaire de base
238
+ base_vocab = {}
239
+ # Ajouter les caractères ASCII
240
+ for i in range(256):
241
+ base_vocab[i] = bytes([i])
242
+ # Ajouter les caractères spéciaux
243
+ for char, idx in special_chars.items():
244
+ base_vocab[idx] = char.encode('utf-8')
245
+ # Ajouter les tokens spéciaux
246
+ for token, idx in special_tokens.items():
247
+ base_vocab[idx] = token.encode('utf-8')
248
+
249
+ # Lecture des fusions
250
+ for line in f:
251
+ try:
252
+ idx1, idx2 = map(int, line.strip().split())
253
+ if idx1 not in base_vocab or idx2 not in base_vocab:
254
+ print(f"Warning: skipping fusion for indices {idx1}, {idx2} - not found in vocabulary")
255
+ continue
256
+ next_idx = len(base_vocab)
257
+ merges[(idx1, idx2)] = next_idx
258
+ base_vocab[next_idx] = base_vocab[idx1] + base_vocab[idx2]
259
+ except Exception as e:
260
+ print(f"Error processing line: {line.strip()}")
261
+ print(f"Current vocabulary keys: {sorted(base_vocab.keys())}")
262
+ raise e
263
+
264
+ self.merges = merges
265
+ self.special_tokens = special_tokens
266
+ self.special_chars = special_chars
267
+ self.vocab = base_vocab
268
+
269
+ return self
270
+
271
+ def encode(self, text):
272
+ """
273
+ Encode le texte baoulé en liste d'identifiants entiers.
274
+ Gère les caractères spéciaux baoulé et les digraphes.
275
+ """
276
+ # Pattern pour identifier les tokens spéciaux
277
+ special_pattern = "(" + "|".join(re.escape(k) for k in self.special_tokens) + ")"
278
+ special_chunks = re.split(special_pattern, text)
279
+
280
+ ids = []
281
+
282
+ for part in special_chunks:
283
+ # Gestion des tokens spéciaux
284
+ if part in self.special_tokens:
285
+ ids.append(self.special_tokens[part])
286
+ elif part: # Ignorer les parties vides
287
+ # Découpage du texte en chunks selon le pattern
288
+ text_chunks = re.findall(self.compiled_pattern, part)
289
+
290
+ for chunk in text_chunks:
291
+ chunk_ids = []
292
+ i = 0
293
+
294
+ # Traitement caractère par caractère avec gestion des digraphes
295
+ while i < len(chunk):
296
+ # Vérification des digraphes baoulé (gb, kp, ny)
297
+ if i < len(chunk) - 1:
298
+ digraph = chunk[i:i+2]
299
+ if digraph.lower() in ['gb', 'kp', 'ny']:
300
+ chunk_ids.extend([ord(digraph[0]), ord(digraph[1])])
301
+ i += 2
302
+ continue
303
+
304
+ # Vérification des voyelles nasales
305
+ if i < len(chunk) - 1 and chunk[i+1] == 'n':
306
+ current_char = chunk[i]
307
+ if current_char in 'aɛiɔu':
308
+ # Traiter la voyelle nasale comme une unité
309
+ nasal_vowel = chunk[i:i+2]
310
+ chunk_ids.extend(list(nasal_vowel.encode('utf-8')))
311
+ i += 2
312
+ continue
313
+
314
+ # Gestion des caractères spéciaux baoulé
315
+ current_char = chunk[i]
316
+ if current_char in self.special_chars:
317
+ chunk_ids.append(self.special_chars[current_char])
318
+ else:
319
+ # Encodage UTF-8 standard pour les autres caractères
320
+ chunk_ids.extend(list(current_char.encode('utf-8')))
321
+ i += 1
322
+
323
+ # Application des fusions (byte-pair encoding)
324
+ while len(chunk_ids) >= 2:
325
+ stats = get_stats(chunk_ids)
326
+ pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
327
+
328
+ if pair not in self.merges:
329
+ break
330
+
331
+ idx = self.merges[pair]
332
+ chunk_ids = merge(chunk_ids, pair, idx)
333
+
334
+ ids.extend(chunk_ids)
335
+
336
+ return ids
337
+
338
+ def decode(self, ids):
339
+ """
340
+ Décode une liste d'identifiants en texte baoulé.
341
+ Gère la reconstruction des caractères spéciaux et des digraphes.
342
+ """
343
+ part_bytes = []
344
+ inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}
345
+ inverse_special_chars = {v: k for k, v in self.special_chars.items()}
346
+
347
+ i = 0
348
+ while i < len(ids):
349
+ current_id = ids[i]
350
+
351
+ # Gestion des tokens spéciaux
352
+ if current_id in inverse_special_tokens:
353
+ part_bytes.append(inverse_special_tokens[current_id].encode('utf-8'))
354
+ i += 1
355
+ continue
356
+
357
+ # Gestion des caractères spéciaux baoulé
358
+ if current_id in inverse_special_chars:
359
+ part_bytes.append(inverse_special_chars[current_id].encode('utf-8'))
360
+ i += 1
361
+ continue
362
+
363
+ # Gestion du vocabulaire standard
364
+ if current_id in self.vocab:
365
+ # Vérification des digraphes potentiels
366
+ if i < len(ids) - 1:
367
+ next_id = ids[i + 1]
368
+ current_bytes = self.vocab[current_id]
369
+ if next_id in self.vocab:
370
+ next_bytes = self.vocab[next_id]
371
+ combined = current_bytes + next_bytes
372
+ # Vérification si c'est un digraphe baoulé
373
+ try:
374
+ combined_str = combined.decode('utf-8')
375
+ if combined_str.lower() in ['gb', 'kp', 'ny']:
376
+ part_bytes.append(combined)
377
+ i += 2
378
+ continue
379
+ except UnicodeDecodeError:
380
+ pass
381
+
382
+ part_bytes.append(self.vocab[current_id])
383
+ i += 1
384
+ else:
385
+ raise ValueError(f"ID de token invalide: {current_id}")
386
+
387
+ # Reconstruction du texte final
388
+ text_bytes = b''.join(part_bytes)
389
+ text = text_bytes.decode('utf-8', errors='replace')
390
+
391
+ return text
392
+
393
+
394
+
395
+
396
+ # Chargement du dataset depuis HuggingFace
397
+ dataset = load_dataset("Adjoumani/translations_french_baoule_V1")
398
+
399
+ # Configuration
400
+ vocab_size = 512
401
+ output_dir = "./models"
402
+ os.makedirs(output_dir, exist_ok=True)
403
+
404
+ # Initialisation et entraînement
405
+ tokenizer = BaouleTokenizer()
406
+ start_time = time.time()
407
+ tokenizer.train(dataset, vocab_size)
408
+ end_time = time.time()
409
+
410
+ # Sauvegarde
411
+ tokenizer.save(f"{output_dir}/baoule_tokenizer")
412
+
413
+ print(f"Durée d'entraînement: {end_time-start_time:.2f} secondes")
414
+ print(f"Modèle sauvegardé dans: {output_dir}/baoule_tokenizer")