Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import sentencepiece as spm | |
class HybridTokenizer:
    """Tokenizer that encodes/decodes exclusively through SentencePiece.

    An optional JSON vocab file is loaded ONLY to report a vocabulary size
    when no SentencePiece model is available; it is never used for encoding.
    """

    PAD_TOKEN = "<pad>"
    UNK_TOKEN = "<unk>"
    BOS_TOKEN = "<bos>"
    EOS_TOKEN = "<eos>"

    def __init__(self, sp_model_path=None, vocab_path=None):
        """Load the SP model and/or JSON vocab if the given paths exist."""
        self.sp = spm.SentencePieceProcessor()
        self.has_sp = False
        if sp_model_path and os.path.exists(sp_model_path):
            self.sp.Load(sp_model_path)
            self.has_sp = True
            print(f"[TOKENIZER] Loaded SentencePiece model")
        # Load vocab ONLY for size alignment (not encoding)
        self.vocab = {}
        if vocab_path and os.path.exists(vocab_path):
            with open(vocab_path, "r", encoding="utf-8") as f:
                self.vocab = json.load(f)
        # Special-token ids come from the SP model when loaded; otherwise
        # fall back to the conventional 0..3 layout.
        self.pad_id = self.sp.pad_id() if self.has_sp else 0
        self.unk_id = self.sp.unk_id() if self.has_sp else 1
        self.bos_id = self.sp.bos_id() if self.has_sp else 2
        self.eos_id = self.sp.eos_id() if self.has_sp else 3
        # SentencePiece reports -1 for special tokens absent from the model:
        # remap pad/unk to a valid id, disable bos/eos entirely.
        if self.pad_id < 0: self.pad_id = 0
        if self.unk_id < 0: self.unk_id = 0
        if self.bos_id < 0: self.bos_id = None
        if self.eos_id < 0: self.eos_id = None

    # ---------------------------
    # ENCODE (PURE SP)
    # ---------------------------
    def encode(self, text, max_len=512):
        """Encode *text* to a list of ids, adding BOS/EOS when defined.

        Raises:
            RuntimeError: if no SentencePiece model is loaded.

        BUGFIX: reserve room for BOS/EOS *before* truncating, so the result
        never exceeds ``max_len`` (the original truncated first and then
        appended specials, returning up to ``max_len + 2`` ids).
        """
        if not self.has_sp:
            raise RuntimeError("SentencePiece model not loaded")
        ids = self.sp.encode(text, out_type=int)
        n_special = (self.bos_id is not None) + (self.eos_id is not None)
        ids = ids[:max(0, max_len - n_special)]
        if self.bos_id is not None:
            ids = [self.bos_id] + ids
        if self.eos_id is not None:
            ids = ids + [self.eos_id]
        return self._sanitize_ids(ids)

    def safe_encode(self, text, max_len=512):
        """encode() that never raises: on failure, return [bos?, unk, eos?]."""
        try:
            return self.encode(text, max_len=max_len)
        except Exception as e:
            print(f"[TOKENIZER ERROR] {e}")
            fallback = []
            if self.bos_id is not None:
                fallback.append(self.bos_id)
            fallback.append(self.unk_id if self.unk_id is not None else 0)
            if self.eos_id is not None:
                fallback.append(self.eos_id)
            return fallback

    # ---------------------------
    # PAD
    # ---------------------------
    def pad(self, ids, max_len):
        """Right-pad with pad_id (or truncate) to exactly ``max_len`` ids."""
        if len(ids) < max_len:
            return ids + [self.pad_id] * (max_len - len(ids))
        return ids[:max_len]

    # ---------------------------
    # DECODE (PURE SP)
    # ---------------------------
    def decode(self, ids):
        """Decode ids to text, skipping pad/bos and stopping at the first eos."""
        cleaned = []
        for i in ids:
            if i in {self.pad_id, self.bos_id}:
                continue
            if i == self.eos_id:
                break
            cleaned.append(i)
        if not cleaned:
            return ""
        return self.sp.decode(cleaned)

    # ---------------------------
    # SAFETY: ID SANITIZATION
    # ---------------------------
    def _sanitize_ids(self, ids):
        """Replace out-of-range or non-int ids with UNK.

        BUGFIX: the original read ``self.vocab_size`` without calling it,
        binding the *method object*; every ``i < vocab_size`` comparison then
        raised TypeError, so encode() could never succeed.
        """
        vocab_size = self.vocab_size()
        # fallback UNK (must be a valid id)
        unk = self.unk_id if (self.unk_id is not None and self.unk_id >= 0) else 0
        return [
            i if (isinstance(i, int) and 0 <= i < vocab_size) else unk
            for i in ids
        ]

    # ---------------------------
    # VOCAB SIZE (CRITICAL)
    # ---------------------------
    def vocab_size(self):
        """SP piece count when a model is loaded, else the JSON vocab size."""
        if self.has_sp:
            return self.sp.GetPieceSize()
        return len(self.vocab)