# mvi-ai-engine — language/tokenizer.py
# Hybrid SentencePiece tokenizer with optional JSON vocab for size alignment.
import os
import json
import sentencepiece as spm
class HybridTokenizer:
    """SentencePiece-backed tokenizer with an optional JSON vocab.

    Encoding and decoding go through SentencePiece exclusively; the JSON
    vocab file, when provided, is consulted only by ``vocab_size`` as a
    fallback when no SentencePiece model is loaded.
    """

    PAD_TOKEN = "<pad>"
    UNK_TOKEN = "<unk>"
    BOS_TOKEN = "<bos>"
    EOS_TOKEN = "<eos>"

    def __init__(self, sp_model_path=None, vocab_path=None):
        """Load the SentencePiece model and/or vocab file if the paths exist.

        Args:
            sp_model_path: Path to a trained SentencePiece ``.model`` file
                (optional; skipped when missing).
            vocab_path: Path to a JSON token->id mapping (optional; used
                only for size alignment, never for encoding).
        """
        self.sp = spm.SentencePieceProcessor()
        self.has_sp = False
        if sp_model_path and os.path.exists(sp_model_path):
            self.sp.Load(sp_model_path)
            self.has_sp = True
            print(f"[TOKENIZER] Loaded SentencePiece model")

        # Load vocab ONLY for size alignment (not encoding).
        self.vocab = {}
        if vocab_path and os.path.exists(vocab_path):
            with open(vocab_path, "r", encoding="utf-8") as f:
                self.vocab = json.load(f)

        # Special-token ids come from the SP model when available;
        # otherwise fall back to the conventional 0..3 layout.
        self.pad_id = self.sp.pad_id() if self.has_sp else 0
        self.unk_id = self.sp.unk_id() if self.has_sp else 1
        self.bos_id = self.sp.bos_id() if self.has_sp else 2
        self.eos_id = self.sp.eos_id() if self.has_sp else 3

        # SentencePiece reports -1 for special tokens absent from the model:
        # remap pad/unk to safe valid ids, and disable bos/eos entirely.
        if self.pad_id < 0:
            self.pad_id = 0
        if self.unk_id < 0:
            self.unk_id = 0
        if self.bos_id < 0:
            self.bos_id = None
        if self.eos_id < 0:
            self.eos_id = None

    # ---------------------------
    # ENCODE (PURE SP)
    # ---------------------------
    def encode(self, text, max_len=512):
        """Encode ``text`` to a list of ids, adding bos/eos when defined.

        Args:
            text: Input string to tokenize.
            max_len: Truncation limit applied to the raw SP ids.

        Returns:
            List of sanitized token ids.

        Raises:
            RuntimeError: If no SentencePiece model is loaded.

        NOTE(review): truncation happens BEFORE bos/eos are appended, so the
        result can be up to ``max_len + 2`` ids long — confirm callers
        account for this before changing it.
        """
        if not self.has_sp:
            raise RuntimeError("SentencePiece model not loaded")
        ids = self.sp.encode(text, out_type=int)
        ids = ids[:max_len]
        if self.bos_id is not None:
            ids = [self.bos_id] + ids
        if self.eos_id is not None:
            ids = ids + [self.eos_id]
        return self._sanitize_ids(ids)

    def safe_encode(self, text, max_len=512):
        """Best-effort encode: on any failure, return a minimal fallback
        sequence (bos + unk + eos, skipping tokens that are disabled)."""
        try:
            return self.encode(text, max_len=max_len)
        except Exception as e:
            # Deliberate broad catch: callers rely on always getting ids back.
            print(f"[TOKENIZER ERROR] {e}")
            fallback = []
            if self.bos_id is not None:
                fallback.append(self.bos_id)
            fallback.append(self.unk_id if self.unk_id is not None else 0)
            if self.eos_id is not None:
                fallback.append(self.eos_id)
            return fallback

    # ---------------------------
    # PAD
    # ---------------------------
    def pad(self, ids, max_len):
        """Right-pad ``ids`` with ``pad_id`` to ``max_len``, or truncate."""
        if len(ids) < max_len:
            return ids + [self.pad_id] * (max_len - len(ids))
        return ids[:max_len]

    # ---------------------------
    # DECODE (PURE SP)
    # ---------------------------
    def decode(self, ids):
        """Decode ids to text, skipping pad/bos and stopping at eos.

        Raises:
            RuntimeError: If ids remain to decode but no SP model is loaded.
        """
        # Remove special tokens before handing off to SentencePiece.
        cleaned = []
        for i in ids:
            if i in {self.pad_id, self.bos_id}:
                continue
            if i == self.eos_id:
                break
            cleaned.append(i)
        if not cleaned:
            return ""
        if not self.has_sp:
            # Consistent with encode(): fail loudly instead of letting the
            # unloaded processor raise an opaque library error.
            raise RuntimeError("SentencePiece model not loaded")
        return self.sp.decode(cleaned)

    # ---------------------------
    # SAFETY: ID SANITIZATION
    # ---------------------------
    def _sanitize_ids(self, ids):
        """Clamp out-of-range or non-int ids to a valid unk fallback."""
        vocab_size = self.vocab_size
        # Fallback UNK (must itself be a valid id).
        unk = self.unk_id if (self.unk_id is not None and self.unk_id >= 0) else 0
        return [
            i if (isinstance(i, int) and 0 <= i < vocab_size) else unk
            for i in ids
        ]

    # ---------------------------
    # VOCAB SIZE (CRITICAL)
    # ---------------------------
    @property
    def vocab_size(self):
        """Model piece count when SP is loaded, else the JSON vocab size."""
        if self.has_sp:
            return self.sp.GetPieceSize()
        return len(self.vocab)