| """ |
| ========================================== |
| WWHO Encoder |
| ========================================== |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from typing import Optional |
|
|
| from linguis_trie import LinguisTrie |
|
|
def _is_boundary_token(token: str, segmenter) -> bool:
    """True if the token contains no characters from a non-Latin script
    (i.e. it is whitespace, punctuation, or Latin text)."""
    if not segmenter:
        return True
    for ch in token:
        lang = segmenter._get_char_language(ch)
        if lang is not None and lang != "latin":
            return False
    return True
|
|
def segment_into_words(syllables: list[str], segmenter) -> list[list[str]]:
    """Group a flat syllable stream into words.

    Boundary tokens (whitespace/Latin/punctuation) become single-token words;
    runs of Indic syllables between boundaries are grouped together so BPE
    merges never cross word boundaries.
    """
    words: list[list[str]] = []
    current: list[str] = []

    for tok in syllables:
        if _is_boundary_token(tok, segmenter):
            if current:
                words.append(current)
                current = []
            words.append([tok])
        else:
            # A leading space attached to a syllable starts a new word.
            if tok and tok[0] in (' ', '\t', '\n', '\r') and current:
                words.append(current)
                current = []
            current.append(tok)

    if current:
        words.append(current)
    return words
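# Minimal illustration of segment_into_words with a stub segmenter (the stub
# is hypothetical; the real code uses router.CodeSwitchSegmenter):
#
#   class _FakeSegmenter:
#       def _get_char_language(self, ch):
#           return "devanagari" if "\u0900" <= ch <= "\u097F" else None
#
#   segment_into_words(["नम", "स्ते", " ", "जी"], _FakeSegmenter())
#   # -> [["नम", "स्ते"], [" "], ["जी"]]  (the space is its own boundary word)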
|
|
class SGPEEncoder:
    """Pure SGPE encoder backed by a trained vocab.json (vocab, merge
    table, and special tokens)."""
|
|
| def __init__(self, vocab_path: str): |
| with open(vocab_path, "r", encoding="utf-8") as f: |
| data = json.load(f) |
|
|
        self.vocab: dict[str, int] = data["vocab"]
        self.merges: list[tuple[str, str]] = [tuple(m) for m in data["merges"]]
        self.special_tokens: list[str] = data["special_tokens"]
        self.leading_space: bool = data.get("leading_space", False)
        # Id emitted for out-of-vocab tokens in encode().
        self.unk_id: int = self.vocab.get("[UNK]", 1)
| |
        # Build the per-script syllable DFAs plus the code-switch segmenter
        # that routes each span of text to its script.
        script_mode = data.get("script_mode", "mixed")

        from linguis_trie import load_dfa_map
        from router import CodeSwitchSegmenter

        self._dfa_map = load_dfa_map(script_mode)
        language_blocks = {lang: dfa.unicode_blocks for lang, dfa in self._dfa_map.items()}
        self._segmenter = CodeSwitchSegmenter(language_blocks)
|
|
        # Lower rank = earlier-learned merge = higher priority when encoding.
        self._merge_priority: dict[tuple[str, str], int] = {
            (a, b): rank for rank, (a, b) in enumerate(self.merges)
        }
|
|
    def encode(self, text: str) -> list[int]:
        """Tokenize, then map each token to its id (OOV tokens -> unk_id)."""
        tokens = self.tokenize(text)
        return [self.vocab.get(t, self.unk_id) for t in tokens]
|
|
    def _apply_merges_to_word(self, tokens: list[str]) -> list[str]:
        """Greedily apply BPE merges: repeatedly fuse the adjacent pair with
        the lowest merge rank until no pair is left in the merge table."""
        if len(tokens) <= 1:
            return tokens

        while True:
            # Find the adjacent pair with the best (lowest) merge rank.
            best_rank = len(self.merges)
            best_idx = -1
            for i in range(len(tokens) - 1):
                pair = (tokens[i], tokens[i + 1])
                rank = self._merge_priority.get(pair)
                if rank is not None and rank < best_rank:
                    best_rank = rank
                    best_idx = i
            if best_idx == -1:
                break
            merged = tokens[best_idx] + tokens[best_idx + 1]
            tokens = tokens[:best_idx] + [merged] + tokens[best_idx + 2:]

        return tokens
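    # Worked example (hypothetical merge table): with
    # merges = [("क", "ा"), ("का", "म")], the word ["क", "ा", "म"] merges as
    #   rank 0: ("क", "ा")  -> ["का", "म"]
    #   rank 1: ("का", "म") -> ["काम"]
    # and the loop exits once no adjacent pair remains in the table.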
|
|
    def tokenize(self, text: str) -> list[str]:
        """Segment by script, syllabify non-Latin segments with their DFA,
        then apply BPE merges within each word."""
| tokens: list[str] = [] |
| for seg in self._segmenter.segment(text): |
| if seg.language == "latin": |
| tokens.append(seg.text) |
| else: |
| dfa = self._dfa_map.get(seg.language) |
| if not dfa: |
| tokens.append(seg.text) |
| continue |
| syllables = dfa.tokenize(seg.text, leading_space=seg.has_leading_space) |
| words = segment_into_words(syllables, self._segmenter) |
| for word_toks in words: |
| if len(word_toks) == 1 and _is_boundary_token(word_toks[0], self._segmenter): |
| tokens.append(word_toks[0]) |
| continue |
| cleaned = [t if t in self.vocab else "[UNK]" for t in word_toks] |
| tokens.extend(self._apply_merges_to_word(cleaned)) |
| return tokens |
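    # End-to-end sketch (the syllable split is illustrative, not actual DFA
    # output): for "नमस्ते hi", the segmenter yields a Devanagari segment and
    # a Latin segment; the DFA might split "नमस्ते" into ["नम", "स्ते"], which
    # form one word and are merged greedily, while " hi" is appended verbatim
    # as a single Latin token.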
|
|
    def decode(self, ids: list[int]) -> str:
        """Map ids back to token strings; unknown ids decode to "" (lossy)."""
        id_to_token = {v: k for k, v in self.vocab.items()}
        return "".join(id_to_token.get(i, "") for i in ids)
|
|
|
|
|
|
class MetaVocab:
    """Unified id space over two vocabularies: tiktoken ids keep the range
    [0, tiktoken_size) and every SGPE id is shifted up by tiktoken_size,
    so the two ranges never collide."""
| def __init__(self, sgpe_vocab: dict[str, int], tiktoken_size: int): |
| self.tiktoken_size: int = tiktoken_size |
| self._sgpe_raw: dict[str, int] = sgpe_vocab |
| self._sgpe_offset: dict[str, int] = { |
| tok: idx + tiktoken_size for tok, idx in sgpe_vocab.items() |
| } |
| self._sgpe_reverse: dict[int, str] = { |
| v: k for k, v in self._sgpe_offset.items() |
| } |
|
|
| @property |
| def total_size(self) -> int: |
| return self.tiktoken_size + len(self._sgpe_raw) |
|
|
| def encode_sgpe_token(self, token: str, unk_id_raw: int) -> int: |
| return self._sgpe_offset.get(token, unk_id_raw + self.tiktoken_size) |
|
|
| def decode_id(self, uid: int) -> Optional[str]: |
| if uid < self.tiktoken_size: |
| return None |
| return self._sgpe_reverse.get(uid) |
|
|
| def is_tiktoken_id(self, uid: int) -> bool: |
| return uid < self.tiktoken_size |
|
|
| def sgpe_unk_id(self, raw_unk: int) -> int: |
| return raw_unk + self.tiktoken_size |
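    # Offset arithmetic at a glance (sizes are hypothetical): with
    # tiktoken_size = 200_000 and sgpe_vocab = {"[UNK]": 1, "क": 2},
    #   encode_sgpe_token("क", 1) -> 200_002
    #   encode_sgpe_token("ख", 1) -> 200_001  (falls back to offset [UNK])
    #   decode_id(200_002)        -> "क"
    #   decode_id(42)             -> None     (a tiktoken id, not ours)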
|
|
|
|
|
|
class WWHOMetaEncoder:
    """Unified meta-encoder: Latin segments go through tiktoken, supported
    Indic segments go through SGPE, and both emit ids from one combined
    MetaVocab id space."""
|
|
| def __init__(self, vocab_path: str, tiktoken_model: str = "o200k_base"): |
| |
| with open(vocab_path, "r", encoding="utf-8") as f: |
| data = json.load(f) |
|
|
| sgpe_vocab: dict[str, int] = data["vocab"] |
| self._merges: list[tuple[str, str]] = [tuple(m) for m in data["merges"]] |
| self._special_tokens: list[str] = data["special_tokens"] |
| self._leading_space: bool = data.get("leading_space", False) |
| self._raw_unk_id: int = sgpe_vocab.get("[UNK]", 1) |
|
|
| if " " not in sgpe_vocab: |
| next_id = max(sgpe_vocab.values()) + 1 |
| sgpe_vocab[" "] = next_id |
|
|
| try: |
| from router import _INDIC_PUNCT_CHARS |
| for ch in _INDIC_PUNCT_CHARS: |
| if ch not in sgpe_vocab: |
| next_id = max(sgpe_vocab.values()) + 1 |
| sgpe_vocab[ch] = next_id |
| except ImportError: |
| pass |
|
|
        # Lower rank = earlier-learned merge = higher priority when encoding.
        self._merge_priority: dict[tuple[str, str], int] = {
            (a, b): rank for rank, (a, b) in enumerate(self._merges)
        }
|
|
| |
        # tiktoken handles all Latin/fallback text; fail fast if unavailable.
        try:
            import tiktoken as _tiktoken
            self._tik = _tiktoken.get_encoding(tiktoken_model)
        except Exception as e:
            raise RuntimeError(
                f"tiktoken ({tiktoken_model!r}) unavailable: {e}"
            ) from e
|
|
| |
        # Combined id space: [0, n_vocab) = tiktoken, [n_vocab, ..) = SGPE.
        self._meta = MetaVocab(sgpe_vocab, self._tik.n_vocab)
        self._space_id: int = self._meta._sgpe_offset[" "]
|
|
| |
        # Per-script syllable DFAs, always loaded in mixed mode so every
        # supported script is available at encode time.
        from linguis_trie import load_dfa_map
        self._dfa_map: dict[str, LinguisTrie] = load_dfa_map("mixed")

        # Code-switch segmenter that routes each span of text to its script.
        from router import CodeSwitchSegmenter
        language_blocks = {lang: dfa.unicode_blocks for lang, dfa in self._dfa_map.items()}
        self._segmenter = CodeSwitchSegmenter(language_blocks)
|
|
|
|
| @property |
| def vocab_size(self) -> int: |
| return self._meta.total_size |
|
|
| @property |
| def tiktoken_size(self) -> int: |
| return self._meta.tiktoken_size |
|
|
| @property |
| def vocab(self) -> dict[str, int]: |
| return self._meta._sgpe_raw |
|
|
    def encode(self, text: str) -> list[int]:
        """Encode mixed-script text into unified ids: raw tiktoken ids for
        Latin/unsupported segments, offset SGPE ids for the rest."""
| ids: list[int] = [] |
| for seg in self._segmenter.segment(text): |
| if seg.language == "latin": |
| ids.extend(self._tik.encode(seg.text)) |
| else: |
| dfa = self._dfa_map.get(seg.language) |
| if not dfa: |
| ids.extend(self._tik.encode(seg.text)) |
| continue |
| syllables = dfa.tokenize(seg.text, leading_space=seg.has_leading_space) |
| words = segment_into_words(syllables, self._segmenter) |
| for word_toks in words: |
| if len(word_toks) == 1 and _is_boundary_token(word_toks[0], self._segmenter): |
| ids.extend(self._tik.encode(word_toks[0])) |
| continue |
| merged = self._apply_merges(word_toks) |
| for tok in merged: |
| ids.append(self._meta.encode_sgpe_token(tok, self._raw_unk_id)) |
| return ids |
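    # Routing sketch: for "Hello नमस्ते", "Hello" yields tiktoken ids (all
    # below tiktoken_size) while the Devanagari word yields SGPE ids offset
    # by tiktoken_size, so the two id ranges can be mixed in one stream
    # without collision.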
|
|
    def decode(self, ids: list[int]) -> str:
        """Decode unified ids, batching consecutive tiktoken ids into a
        single tiktoken.decode call."""
| tik_buf: list[int] = [] |
| result_parts: list[str] = [] |
|
|
| def _flush_tik(): |
| if tik_buf: |
| result_parts.append(self._tik.decode(tik_buf)) |
| tik_buf.clear() |
|
|
| for uid in ids: |
| if self._meta.is_tiktoken_id(uid): |
| tik_buf.append(uid) |
| else: |
| _flush_tik() |
| tok = self._meta.decode_id(uid) |
| result_parts.append(tok if tok is not None else "") |
|
|
| _flush_tik() |
| return "".join(result_parts) |
|
|
    def tokenize(self, text: str) -> list[str]:
        """Like encode, but return token strings. Note: decoding tiktoken ids
        one at a time can yield U+FFFD for tokens that split a UTF-8 sequence;
        this only affects these debug strings, not encode/decode."""
| tokens: list[str] = [] |
| for seg in self._segmenter.segment(text): |
| if seg.language == "latin": |
| ids = self._tik.encode(seg.text) |
| tokens.extend(self._tik.decode([i]) for i in ids) |
| else: |
| dfa = self._dfa_map.get(seg.language) |
| if not dfa: |
| ids = self._tik.encode(seg.text) |
| tokens.extend(self._tik.decode([i]) for i in ids) |
| continue |
| syllables = dfa.tokenize(seg.text, leading_space=seg.has_leading_space) |
| words = segment_into_words(syllables, self._segmenter) |
| for word_toks in words: |
| if len(word_toks) == 1 and _is_boundary_token(word_toks[0], self._segmenter): |
| ids = self._tik.encode(word_toks[0]) |
| tokens.extend(self._tik.decode([i]) for i in ids) |
| continue |
| tokens.extend(self._apply_merges(word_toks)) |
| return tokens |
|
|
    def _apply_merges(self, tokens: list[str]) -> list[str]:
        """Greedy BPE over one word, after mapping out-of-vocab syllables to
        [UNK]; mirrors SGPEEncoder._apply_merges_to_word."""
| if len(tokens) <= 1: |
| return tokens |
| sgpe = self._meta._sgpe_raw |
| cleaned = [t if t in sgpe else "[UNK]" for t in tokens] |
| while True: |
| best_rank = len(self._merges) |
| best_idx = -1 |
| for i in range(len(cleaned) - 1): |
| pair = (cleaned[i], cleaned[i + 1]) |
| rank = self._merge_priority.get(pair) |
| if rank is not None and rank < best_rank: |
| best_rank = rank |
| best_idx = i |
| if best_idx == -1: |
| break |
| merged = cleaned[best_idx] + cleaned[best_idx + 1] |
| cleaned = cleaned[:best_idx] + [merged] + cleaned[best_idx + 2:] |
| return cleaned |
|
|
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description="WWHO Encoder (Unified Meta-Vocabulary)") |
| parser.add_argument("--vocab", type=str, default="output/vocab.json", |
| help="Path to WWHO vocab.json") |
| parser.add_argument("--text", type=str, required=True, |
| help="Text to encode (supports mixed Latin + Indic)") |
| parser.add_argument("--mode", type=str, default="meta", |
| choices=["sgpe", "meta"], |
| help="'sgpe' = pure SGPE encoder; 'meta' = unified meta-encoder") |
| parser.add_argument("--tiktoken_model", type=str, default="o200k_base") |
| args = parser.parse_args() |
|
|
| if args.mode == "sgpe": |
| enc = SGPEEncoder(args.vocab) |
| tokens = enc.tokenize(args.text) |
| ids = enc.encode(args.text) |
| print(f"[SGPEEncoder]") |
| print(f" tokens : {tokens}") |
| print(f" ids : {ids}") |
| print(f" count : {len(tokens)}") |
| else: |
| enc = WWHOMetaEncoder(args.vocab, tiktoken_model=args.tiktoken_model) |
| tokens = enc.tokenize(args.text) |
| ids = enc.encode(args.text) |
| decoded = enc.decode(ids) |
| print(f"[WWHOMetaEncoder]") |
| print(f" vocab_size : {enc.vocab_size:,} " |
| f"(tiktoken={enc.tiktoken_size:,} + SGPE={enc.vocab_size - enc.tiktoken_size:,})") |
| print(f" tokens : {tokens}") |
| print(f" ids : {ids}") |
| print(f" count : {len(tokens)}") |
| print(f" decoded: {decoded!r}") |
| print(f" lossless: {decoded == args.text}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|