Upload baoule_tokenizer.py with huggingface_hub
baoule_tokenizer.py +414 -0
baoule_tokenizer.py
ADDED
@@ -0,0 +1,414 @@
import unicodedata
import regex as re
from datasets import load_dataset
import time
import os

def get_stats(ids, stats=None):
    """
    Computes the frequency of consecutive id pairs.
    Keeps the same logic as the original version, since this function is
    independent of any language-specific details.
    """
    stats = {} if stats is None else stats
    for pair in zip(ids, ids[1:]):
        stats[pair] = stats.get(pair, 0) + 1
    return stats

def merge(ids, pair, idx):
    """
    Merges the identified id pairs.
    Keeps the same logic as the original version, since this function only
    handles the fusion of already identified tokens.
    """
    newids = []
    i = 0
    while i < len(ids):
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
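
# Illustrative note (added for this write-up, not part of the original file):
# get_stats counts adjacent id pairs and merge collapses every occurrence of a
# chosen pair into a single new id, for example:
#
#     get_stats([1, 2, 2, 3, 2, 3])            -> {(1, 2): 1, (2, 2): 1, (2, 3): 2, (3, 2): 1}
#     merge([1, 2, 2, 3, 2, 3], (2, 3), 260)   -> [1, 2, 260, 260]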

def replace_control_characters(s: str) -> str:
    """
    Replaces control characters, with particular attention to the Baoulé
    special characters.
    """
    chars = []
    for ch in s:
        # Special handling of Baoulé characters
        if ch in ['ɛ', 'ɔ', 'ŋ', 'ɲ']:  # Baoulé special characters
            chars.append(ch)
        # Standard handling of control characters
        elif unicodedata.category(ch)[0] != "C":
            chars.append(ch)
        else:
            chars.append(f"\\u{ord(ch):04x}")  # escape control characters
    return "".join(chars)

def render_token(t: bytes) -> str:
    """
    Decodes tokens while handling the Baoulé special characters.
    """
    try:
        # Standard decoding attempt
        s = t.decode('utf-8', errors='replace')
        # Handling of Baoulé special characters
        s = replace_control_characters(s)
        return s
    except UnicodeDecodeError:
        # On failure, return the replacement character
        return '�'

class BaouleTokenizer:
    def __init__(self):
        # Initialization of the required attributes
        self.special_chars = {
            'ɛ': 256,
            'ɔ': 257,
            'ŋ': 258,
            'ɲ': 259
        }

        self.pattern = r"(?i:'n|gb|kp|ny|[ɛɔ]n)|(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
        self.compiled_pattern = re.compile(self.pattern)

        self.special_tokens = {
            #'<|baoule|>': 1101,
            #'<|french|>': 1102,
            #'<|end|>': 1103,
            #'<|unknown|>': 1104,
            #'<|pad|>': 1105,
            # or
            '<|begin_of_text|>': 1101,
            '<|end_of_text|>': 1102,
            '<|start_header_id|>': 1103,
            '<|end_header_id|>': 1104,
            '<|eot_id|>': 1105
        }

        self.merges = {}
        self.vocab = self._build_vocab()
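
    # Note on the id layout (explanatory comment added for this write-up, not in
    # the original file): ids 0-255 are raw byte values, 256-259 are the Baoulé
    # special characters above, the BPE merges learned by train() are numbered
    # from 260 + len(special_tokens) = 265 upwards, and the special tokens
    # occupy 1101-1105.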

    def train(self, dataset, vocab_size):
        assert vocab_size >= 260  # 256 byte values + at least the 4 Baoulé special characters

        # Extract the Baoulé texts from the HuggingFace dataset
        text_chunks = []
        for item in dataset['train']:
            chunks = re.findall(self.compiled_pattern, item['baoule'])
            text_chunks.extend(chunks)

        # Conversion to ids, with special handling of the Baoulé characters
        ids = []
        for chunk in text_chunks:
            chunk_ids = []
            i = 0
            while i < len(chunk):
                # Check for digraphs
                if i < len(chunk) - 1:
                    digraph = chunk[i:i+2]
                    if digraph in ['gb', 'kp', 'ny']:
                        chunk_ids.append(ord(digraph[0]))
                        chunk_ids.append(ord(digraph[1]))
                        i += 2
                        continue

                # Handling of the Baoulé special characters
                char = chunk[i]
                if char in self.special_chars:
                    chunk_ids.append(self.special_chars[char])
                else:
                    # Standard UTF-8 encoding for the other characters
                    chunk_ids.extend(list(char.encode("utf-8")))
                i += 1
            ids.append(chunk_ids)

        # Compute the merges
        num_merges = vocab_size - (260 + len(self.special_tokens))
        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        vocab.update({idx: char.encode('utf-8') for char, idx in self.special_chars.items()})

        for i in range(num_merges):
            stats = {}
            for chunk_ids in ids:
                get_stats(chunk_ids, stats)
            if not stats:
                break
            pair = max(stats, key=stats.get)
            idx = 260 + len(self.special_tokens) + i
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

        self.merges = merges
        self.vocab = vocab
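
    # Worked example (added note): with the script's vocab_size = 512 and the 5
    # special tokens above, train() performs at most
    # num_merges = 512 - (260 + 5) = 247 merges, producing ids 265..511.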

    def _build_vocab(self):
        # Base vocabulary covering the byte values and the Baoulé special characters
        vocab = {idx: bytes([idx]) for idx in range(256)}
        vocab.update({idx: char.encode('utf-8') for char, idx in self.special_chars.items()})

        # Add the merges
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]

        # Add the special tokens
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")

        return vocab

    def save(self, file_prefix):
        # Save the model
        model_file = file_prefix + ".model"
        # utf-8 is required because the file contains the Baoulé special characters
        with open(model_file, 'w', encoding="utf-8") as f:
            f.write("baoule tokenizer v1.0\n")
            f.write(f"{self.pattern}\n")

            # Save the Baoulé special characters
            f.write(f"{len(self.special_chars)}\n")
            for char, idx in self.special_chars.items():
                f.write(f"{char} {idx}\n")

            # Save the special tokens
            f.write(f"{len(self.special_tokens)}\n")
            for token, idx in self.special_tokens.items():
                f.write(f"{token} {idx}\n")

            # Save the merges
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")

        # Save the vocabulary (human-readable, not used by load())
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                s = render_token(token)
                if idx in inverted_merges:
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    f.write(f"[{s}] {idx}\n")
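
    # For reference (added note, not in the original file): the .model file written
    # above is plain text, and load() below reads it back line by line in this order:
    #   line 1: "baoule tokenizer v1.0"   (version header)
    #   line 2: the split regex pattern
    #   then:   the number of Baoulé special characters, followed by one "char idx" line each
    #   then:   the number of special tokens, followed by one "token idx" line each
    #   rest:   one "idx1 idx2" line per learned merge, in merge order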
    def load(self, model_file):
        merges = {}
        special_tokens = {}
        special_chars = {}

        with open(model_file, 'r', encoding="utf-8") as f:
            version = f.readline().strip()
            self.pattern = f.readline().strip()
            self.compiled_pattern = re.compile(self.pattern)

            # Read the Baoulé special characters
            num_special_chars = int(f.readline().strip())
            for _ in range(num_special_chars):
                char, char_idx = f.readline().strip().split()
                special_chars[char] = int(char_idx)

            # Read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)

            # Build the base vocabulary
            base_vocab = {}
            # Add the 256 byte values
            for i in range(256):
                base_vocab[i] = bytes([i])
            # Add the special characters
            for char, idx in special_chars.items():
                base_vocab[idx] = char.encode('utf-8')
            # Add the special tokens
            for token, idx in special_tokens.items():
                base_vocab[idx] = token.encode('utf-8')

            # Read the merges
            for line in f:
                try:
                    idx1, idx2 = map(int, line.strip().split())
                    if idx1 not in base_vocab or idx2 not in base_vocab:
                        print(f"Warning: skipping merge for indices {idx1}, {idx2} - not found in vocabulary")
                        continue
                    next_idx = len(base_vocab)
                    merges[(idx1, idx2)] = next_idx
                    base_vocab[next_idx] = base_vocab[idx1] + base_vocab[idx2]
                except Exception as e:
                    print(f"Error processing line: {line.strip()}")
                    print(f"Current vocabulary keys: {sorted(base_vocab.keys())}")
                    raise e

        self.merges = merges
        self.special_tokens = special_tokens
        self.special_chars = special_chars
        self.vocab = base_vocab

        return self

    def encode(self, text):
        """
        Encodes Baoulé text into a list of integer ids.
        Handles the Baoulé special characters and digraphs.
        """
        # Pattern used to isolate the special tokens
        special_pattern = "(" + "|".join(re.escape(k) for k in self.special_tokens) + ")"
        special_chunks = re.split(special_pattern, text)

        ids = []

        for part in special_chunks:
            # Handling of special tokens
            if part in self.special_tokens:
                ids.append(self.special_tokens[part])
            elif part:  # skip empty parts
                # Split the text into chunks according to the pattern
                text_chunks = re.findall(self.compiled_pattern, part)

                for chunk in text_chunks:
                    chunk_ids = []
                    i = 0

                    # Character-by-character processing with digraph handling
                    while i < len(chunk):
                        # Check for the Baoulé digraphs (gb, kp, ny)
                        if i < len(chunk) - 1:
                            digraph = chunk[i:i+2]
                            if digraph.lower() in ['gb', 'kp', 'ny']:
                                chunk_ids.extend([ord(digraph[0]), ord(digraph[1])])
                                i += 2
                                continue

                        # Check for nasal vowels
                        if i < len(chunk) - 1 and chunk[i+1] == 'n':
                            current_char = chunk[i]
                            if current_char in 'aɛiɔu':
                                # Treat the nasal vowel as a single unit
                                nasal_vowel = chunk[i:i+2]
                                chunk_ids.extend(list(nasal_vowel.encode('utf-8')))
                                i += 2
                                continue

                        # Handling of the Baoulé special characters
                        current_char = chunk[i]
                        if current_char in self.special_chars:
                            chunk_ids.append(self.special_chars[current_char])
                        else:
                            # Standard UTF-8 encoding for the other characters
                            chunk_ids.extend(list(current_char.encode('utf-8')))
                        i += 1

                    # Apply the merges (byte-pair encoding)
                    while len(chunk_ids) >= 2:
                        stats = get_stats(chunk_ids)
                        pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))

                        if pair not in self.merges:
                            break

                        idx = self.merges[pair]
                        chunk_ids = merge(chunk_ids, pair, idx)

                    ids.extend(chunk_ids)

        return ids

    def decode(self, ids):
        """
        Decodes a list of ids back into Baoulé text.
        Handles the reconstruction of the special characters and digraphs.
        """
        part_bytes = []
        inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}
        inverse_special_chars = {v: k for k, v in self.special_chars.items()}

        i = 0
        while i < len(ids):
            current_id = ids[i]

            # Handling of special tokens
            if current_id in inverse_special_tokens:
                part_bytes.append(inverse_special_tokens[current_id].encode('utf-8'))
                i += 1
                continue

            # Handling of the Baoulé special characters
            if current_id in inverse_special_chars:
                part_bytes.append(inverse_special_chars[current_id].encode('utf-8'))
                i += 1
                continue

            # Handling of the standard vocabulary
            if current_id in self.vocab:
                # Check for potential digraphs
                if i < len(ids) - 1:
                    next_id = ids[i + 1]
                    current_bytes = self.vocab[current_id]
                    if next_id in self.vocab:
                        next_bytes = self.vocab[next_id]
                        combined = current_bytes + next_bytes
                        # Check whether this is a Baoulé digraph
                        try:
                            combined_str = combined.decode('utf-8')
                            if combined_str.lower() in ['gb', 'kp', 'ny']:
                                part_bytes.append(combined)
                                i += 2
                                continue
                        except UnicodeDecodeError:
                            pass

                part_bytes.append(self.vocab[current_id])
                i += 1
            else:
                raise ValueError(f"Invalid token id: {current_id}")

        # Reconstruct the final text
        text_bytes = b''.join(part_bytes)
        text = text_bytes.decode('utf-8', errors='replace')

        return text

# Load the dataset from HuggingFace
dataset = load_dataset("Adjoumani/translations_french_baoule_V1")

# Configuration
vocab_size = 512
output_dir = "./models"
os.makedirs(output_dir, exist_ok=True)

# Initialization and training
tokenizer = BaouleTokenizer()
start_time = time.time()
tokenizer.train(dataset, vocab_size)
end_time = time.time()

# Save
tokenizer.save(f"{output_dir}/baoule_tokenizer")

print(f"Training time: {end_time-start_time:.2f} seconds")
print(f"Model saved to: {output_dir}/baoule_tokenizer")
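
# Illustrative usage sketch (added for this write-up, not part of the original
# upload): reload the saved model and run a quick encode/decode check. The sample
# string below is an invented example containing Baoulé special characters and
# digraphs, not verified Baoulé text.
reloaded = BaouleTokenizer().load(f"{output_dir}/baoule_tokenizer.model")
sample = "kpa gbɛ ŋgɔ"
sample_ids = reloaded.encode(sample)
print(f"Encoded ids: {sample_ids}")
print(f"Decoded text: {reloaded.decode(sample_ids)}")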