import logging
import json
import torch
from pathlib import Path
from unicodedata import category, normalize
from tokenizers import Tokenizer
from huggingface_hub import hf_hub_download

# Special tokens
SOT = "[START]"
EOT = "[STOP]"
UNK = "[UNK]"
SPACE = "[SPACE]"
SPECIAL_TOKENS = [SOT, EOT, UNK, SPACE, "[PAD]", "[SEP]", "[CLS]", "[MASK]"]

logger = logging.getLogger(__name__)

class EnTokenizer:
    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        voc = self.tokenizer.get_vocab()
        assert SOT in voc
        assert EOT in voc

    def text_to_tokens(self, text: str):
        text_tokens = self.encode(text)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens

    def encode(self, txt: str):
        """
        clean_text > (append `lang_id`) > replace SPACE > encode text using Tokenizer
        """
        txt = txt.replace(' ', SPACE)
        code = self.tokenizer.encode(txt)
        ids = code.ids
        return ids

    def decode(self, seq):
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        txt: str = self.tokenizer.decode(seq, skip_special_tokens=False)
        txt = txt.replace(' ', '')
        txt = txt.replace(SPACE, ' ')
        txt = txt.replace(EOT, '')
        txt = txt.replace(UNK, '')
        return txt
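
# Illustrative usage sketch (not part of the original module). The vocab file name is
# an assumption; pass whichever tokenizer JSON ships with the checkpoint:
#
#   tok = EnTokenizer("tokenizer.json")
#   ids = tok.encode("hello world")             # spaces become [SPACE] before encoding
#   batch = tok.text_to_tokens("hello world")   # IntTensor of shape (1, num_tokens)
#   print(tok.decode(ids))                      # [SPACE] is mapped back to a plain space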

# Model repository
REPO_ID = "ResembleAI/chatterbox"

# Global instances for optional dependencies
_kakasi = None
_dicta = None
_russian_stresser = None


def is_kanji(c: str) -> bool:
    """Check if character is kanji."""
    return 19968 <= ord(c) <= 40959


def is_katakana(c: str) -> bool:
    """Check if character is katakana."""
    return 12449 <= ord(c) <= 12538

def hiragana_normalize(text: str) -> str:
    """Japanese text normalization: converts kanji to hiragana; katakana remains the same."""
    global _kakasi
    try:
        if _kakasi is None:
            import pykakasi
            _kakasi = pykakasi.kakasi()
        result = _kakasi.convert(text)
        out = []
        for r in result:
            inp = r['orig']
            hira = r["hira"]
            # Any kanji in the phrase
            if any(is_kanji(c) for c in inp):
                if hira and hira[0] in ["は", "へ"]:  # Safety check for empty hira
                    hira = " " + hira
                out.append(hira)
            # All katakana (safety check for empty inp)
            elif inp and all(is_katakana(c) for c in inp):
                out.append(r['orig'])
            else:
                out.append(inp)
        normalized_text = "".join(out)
        # Decompose Japanese characters for tokenizer compatibility
        normalized_text = normalize('NFKD', normalized_text)
        return normalized_text
    except ImportError:
        logger.warning("pykakasi not available - Japanese text processing skipped")
        return text
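
# Rough example of the behaviour above (requires pykakasi; the reading shown is indicative):
# kanji segments are replaced by their hiragana reading and the result is NFKD-decomposed,
# while katakana and existing kana pass through unchanged, e.g.
#   hiragana_normalize("日本語のテキスト")  # -> "にほんごのテキスト" (NFKD-decomposed)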

def add_hebrew_diacritics(text: str) -> str:
    """Hebrew text normalization: adds diacritics to Hebrew text."""
    global _dicta
    try:
        if _dicta is None:
            from dicta_onnx import Dicta
            _dicta = Dicta()
        return _dicta.add_diacritics(text)
    except ImportError:
        logger.warning("dicta_onnx not available - Hebrew text processing skipped")
        return text
    except Exception as e:
        logger.warning(f"Hebrew diacritization failed: {e}")
        return text

def korean_normalize(text: str) -> str:
    """Korean text normalization: decompose syllables into Jamo for tokenization."""
    def decompose_hangul(char):
        """Decompose Korean syllable into Jamo components."""
        if not ('\uac00' <= char <= '\ud7af'):
            return char
        # Hangul decomposition formula
        base = ord(char) - 0xAC00
        initial = chr(0x1100 + base // (21 * 28))
        medial = chr(0x1161 + (base % (21 * 28)) // 28)
        final = chr(0x11A7 + base % 28) if base % 28 > 0 else ''
        return initial + medial + final

    # Decompose each syllable, then strip surrounding whitespace
    result = ''.join(decompose_hangul(char) for char in text)
    return result.strip()
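
# Worked example of the standard Unicode Hangul arithmetic used above: '한' is U+D55C,
# so base = 0xD55C - 0xAC00 = 10588, giving
#   initial = chr(0x1100 + 10588 // 588)        -> U+1112 (choseong hieuh)
#   medial  = chr(0x1161 + (10588 % 588) // 28) -> U+1161 (jungseong a)
#   final   = chr(0x11A7 + 10588 % 28)          -> U+11AB (jongseong nieun)
# hence korean_normalize("한") == "\u1112\u1161\u11ab".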

class ChineseCangjieConverter:
    """Converts Chinese characters to Cangjie codes for tokenization."""

    def __init__(self, model_dir=None):
        self.word2cj = {}
        self.cj2word = {}
        self.segmenter = None
        self._load_cangjie_mapping(model_dir)
        self._init_segmenter()

    def _load_cangjie_mapping(self, model_dir=None):
        """Load Cangjie mapping from HuggingFace model repository."""
        try:
            cangjie_file = hf_hub_download(
                repo_id=REPO_ID,
                filename="Cangjie5_TC.json",
                cache_dir=model_dir
            )
            with open(cangjie_file, "r", encoding="utf-8") as fp:
                data = json.load(fp)
            for entry in data:
                word, code = entry.split("\t")[:2]
                self.word2cj[word] = code
                if code not in self.cj2word:
                    self.cj2word[code] = [word]
                else:
                    self.cj2word[code].append(word)
        except Exception as e:
            logger.warning(f"Could not load Cangjie mapping: {e}")

    def _init_segmenter(self):
        """Initialize pkuseg segmenter."""
        try:
            from spacy_pkuseg import pkuseg
            self.segmenter = pkuseg()
        except ImportError:
            logger.warning("pkuseg not available - Chinese segmentation will be skipped")
            self.segmenter = None

    def _cangjie_encode(self, glyph: str):
        """Encode a single Chinese glyph to Cangjie code."""
        normed_glyph = glyph
        code = self.word2cj.get(normed_glyph, None)
        if code is None:  # e.g. Japanese hiragana
            return None
        # Disambiguate glyphs that share a Cangjie code by appending their index
        index = self.cj2word[code].index(normed_glyph)
        suffix = str(index) if index > 0 else ""
        return code + suffix

    def __call__(self, text):
        """Convert Chinese characters in text to Cangjie tokens."""
        output = []
        if self.segmenter is not None:
            segmented_words = self.segmenter.cut(text)
            full_text = " ".join(segmented_words)
        else:
            full_text = text
        for t in full_text:
            if category(t) == "Lo":
                cangjie = self._cangjie_encode(t)
                if cangjie is None:
                    output.append(t)
                    continue
                code = []
                for c in cangjie:
                    code.append(f"[cj_{c}]")
                code.append("[cj_.]")
                code = "".join(code)
                output.append(code)
            else:
                output.append(t)
        return "".join(output)

def add_russian_stress(text: str) -> str:
    """Russian text normalization: adds stress marks to Russian text."""
    global _russian_stresser
    try:
        if _russian_stresser is None:
            from russian_text_stresser.text_stresser import RussianTextStresser
            _russian_stresser = RussianTextStresser()
        return _russian_stresser.stress_text(text)
    except ImportError:
        logger.warning("russian_text_stresser not available - Russian stress labeling skipped")
        return text
    except Exception as e:
        logger.warning(f"Russian stress labeling failed: {e}")
        return text

class MTLTokenizer:
    def __init__(self, vocab_file_path):
        self.tokenizer: Tokenizer = Tokenizer.from_file(vocab_file_path)
        model_dir = Path(vocab_file_path).parent
        self.cangjie_converter = ChineseCangjieConverter(model_dir)
        self.check_vocabset_sot_eot()

    def check_vocabset_sot_eot(self):
        voc = self.tokenizer.get_vocab()
        assert SOT in voc
        assert EOT in voc

    def preprocess_text(self, raw_text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
        """
        Text preprocessor that handles lowercase conversion and NFKD normalization.
        """
        preprocessed_text = raw_text
        if lowercase:
            preprocessed_text = preprocessed_text.lower()
        if nfkd_normalize:
            preprocessed_text = normalize("NFKD", preprocessed_text)
        return preprocessed_text

    def text_to_tokens(self, text: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
        text_tokens = self.encode(text, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)
        text_tokens = torch.IntTensor(text_tokens).unsqueeze(0)
        return text_tokens

    def encode(self, txt: str, language_id: str = None, lowercase: bool = True, nfkd_normalize: bool = True):
        txt = self.preprocess_text(txt, language_id=language_id, lowercase=lowercase, nfkd_normalize=nfkd_normalize)

        # Language-specific text processing
        if language_id == 'zh':
            txt = self.cangjie_converter(txt)
        elif language_id == 'ja':
            txt = hiragana_normalize(txt)
        elif language_id == 'he':
            txt = add_hebrew_diacritics(txt)
        elif language_id == 'ko':
            txt = korean_normalize(txt)
        elif language_id == 'ru':
            txt = add_russian_stress(txt)

        # Prepend language token
        if language_id:
            txt = f"[{language_id.lower()}]{txt}"

        txt = txt.replace(' ', SPACE)
        return self.tokenizer.encode(txt).ids

    def decode(self, seq):
        if isinstance(seq, torch.Tensor):
            seq = seq.cpu().numpy()
        txt = self.tokenizer.decode(seq, skip_special_tokens=False)
        txt = txt.replace(' ', '').replace(SPACE, ' ').replace(EOT, '').replace(UNK, '')
        return txt
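
# Illustrative usage sketch (the file name and examples are assumptions; the real vocab
# file ships with the ResembleAI/chatterbox checkpoint):
#
#   mtl = MTLTokenizer("mtl_tokenizer.json")
#   ids = mtl.encode("Hello world", language_id="en")   # encoded as "[en]hello[SPACE]world"
#   ids = mtl.encode("한국어", language_id="ko")          # Jamo-decomposed before encoding
#   print(mtl.decode(ids))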