import logging
from typing import Dict, List

from transformers import AutoTokenizer

from transformers_gad.utils import get_tokenizer_model_type, ints2bytes

log = logging.getLogger(__name__)

def get_mapping(tokenizer, unicode=False):
    """Pick the Mapping subclass that matches the tokenizer's class name."""
    log.debug(f"tokenizer type: {tokenizer.__class__.__name__}")
    log.debug(f"tokenizer model type: {get_tokenizer_model_type(tokenizer)}")
    name = tokenizer.__class__.__name__.lower()
    if not unicode:
        # Byte-level BPE tokenizers of the GPT-2 family.
        if any(
            key in name
            for key in ("gpt2", "bloom", "pretrainedtokenizer", "codegen", "gptneox")
        ):
            return BBPEMapping(tokenizer)
        # SentencePiece BPE tokenizers.
        elif "t5" in name:
            return BPEMapping(tokenizer)
        elif "llama" in name:
            return LlamaBPEMapping(tokenizer)
        # Unigram tokenizers.
        elif "xglm" in name:
            return UniGramMapping(tokenizer)
        else:
            raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__.__name__}")
    else:
        if "gpt2" in name:
            return UnicodeBBPEMapping(tokenizer)
        else:
            raise NotImplementedError(
                f"Unicode mapping for {tokenizer.__class__.__name__}"
            )
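# Hedged usage sketch (illustrative, not part of the library API): for a
# byte-level BPE tokenizer, concatenating the mapped bytes of an encoding
# should reconstruct the original text. The "gpt2" checkpoint name is an
# assumption made for this example.
def _example_get_mapping():  # pragma: no cover
    tok = AutoTokenizer.from_pretrained("gpt2")
    mapping = get_mapping(tok)
    token_ids = tok.encode("hello world")
    text = b"".join(mapping.map(tid) for tid in token_ids).decode("utf-8")
    assert text == "hello world"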


class Mapping:
    """Base class: maps a token id to the UTF-8 bytes of its surface form."""

    def __init__(self, tokenizer):
        self.eos_token_id = tokenizer.eos_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.tokenizer = tokenizer
        self.special = tokenizer.all_special_ids

    def __len__(self):
        return len(self.tokenizer.get_vocab())

    def _map(self, token_id: int) -> str:
        # Special tokens (BOS, EOS, PAD, ...) carry no surface text.
        if token_id in self.special:
            return ""
        # Accept both plain ints and 0-d tensors.
        if hasattr(token_id, "item"):
            token_id = token_id.item()
        raw_token = self.tokenizer.convert_ids_to_tokens(token_id)
        return raw_token

    def map(self, token_id: int, verbose=False) -> bytes:
        token = self._map(token_id)
        if verbose:
            log.debug(f"token_id: {token_id}, token: {token}")
        return bytes(token, "utf-8")


class BBPEMapping(Mapping):
    """Byte-level BPE (GPT-2 style): the marker "Ġ" spells a leading space."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _map(self, token_id: int) -> str:
        raw_token = super()._map(token_id)
        # GPT-2 vocabularies encode a leading space as "Ġ".
        if raw_token.startswith("Ġ"):
            raw_token = raw_token.replace("Ġ", " ")
        return raw_token
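# Illustrative sketch: on a GPT-2 vocabulary the token "Ġworld" should map to
# b" world" (leading-space marker replaced). Checkpoint name is an assumption.
def _example_bbpe_space_marker():  # pragma: no cover
    tok = AutoTokenizer.from_pretrained("gpt2")
    mapping = BBPEMapping(tok)
    token_id = tok.convert_tokens_to_ids("Ġworld")
    assert mapping.map(token_id) == b" world"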


class UnicodeBBPEMapping(Mapping):
    """Byte-level BPE mapping that keeps tokens in GPT-2's intermediate
    unicode alphabet; `map` converts them back to real bytes at the end."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.intermediate_encoding = UnicodeBBPEMapping.get_intermediate_encoding(
            self.tokenizer
        )

    def _map(self, token_id: int, verbose=False) -> str:
        # The raw token is returned untouched: characters such as "Ġ" are
        # resolved by the byte decoder in `map`, not by string replacement.
        raw_token = super()._map(token_id)
        return raw_token

    def map(self, token_id: int, verbose=False) -> bytes:
        raw_token = self._map(token_id, verbose)
        if verbose:
            log.debug(f"token_id: {token_id}, raw_token: {raw_token}")
        return self.intermediate_encoding.token2bytes(raw_token)

    @staticmethod
    def get_intermediate_encoding(tokenizer):
        if "gpt2" in tokenizer.__class__.__name__.lower():
            return ByteEncoding(tokenizer)
        else:
            return None
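# Sketch of the unicode path: raw tokens stay in GPT-2's intermediate unicode
# alphabet and are only turned back into real bytes at the end, which also
# round-trips non-ASCII text. Checkpoint name is an assumption.
def _example_unicode_bbpe():  # pragma: no cover
    tok = AutoTokenizer.from_pretrained("gpt2")
    mapping = UnicodeBBPEMapping(tok)
    token_ids = tok.encode("héllo")
    text = b"".join(mapping.map(tid) for tid in token_ids).decode("utf-8")
    assert text == "héllo"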


class BPEMapping(Mapping):
    """SentencePiece BPE: "▁" spells a leading space, which is dropped on the
    first token after BOS."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        # Track the previous token id so the space introduced by "▁" can be
        # stripped right after BOS.
        self.last_token_id = None

    def _map(self, token_id: int) -> str:
        raw_token = super()._map(token_id)
        at_bos = (
            self.last_token_id is not None
            and self.last_token_id == self.bos_token_id
        )
        self.last_token_id = token_id
        if raw_token.startswith("▁"):
            raw_token = raw_token.replace("▁", " ")
            if at_bos:
                # SentencePiece prepends a space to the first word; drop it.
                raw_token = raw_token[1:]
        return raw_token
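# Sketch of the SentencePiece convention: "▁" spells a leading space, so
# "▁the" maps to b" the" when not at BOS. The "t5-small" checkpoint name is
# an assumption made for this example.
def _example_bpe_space_marker():  # pragma: no cover
    tok = AutoTokenizer.from_pretrained("t5-small")
    mapping = BPEMapping(tok)
    token_id = tok.convert_tokens_to_ids("▁the")
    assert mapping.map(token_id) == b" the"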


class LlamaBPEMapping(BPEMapping):
    """Llama's SentencePiece BPE with byte-fallback tokens like "<0x0A>"."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def _map(self, token_id: int) -> str:
        raw_token = super()._map(token_id)
        # Byte-fallback tokens spell a single raw byte as "<0xHH>"; take both
        # hex digits between "<0x" and ">" (slicing from index 4 would
        # silently drop the first digit).
        if raw_token.startswith("<0x"):
            hex_value = raw_token[3:-1]
            raw_token = chr(int(hex_value, 16))
        return raw_token
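# Sketch of Llama's byte-fallback tokens: "<0x0A>" stands for the single raw
# byte 0x0A (newline), so it should map to b"\n". The gated checkpoint name
# below is an assumption made for illustration.
def _example_llama_byte_fallback():  # pragma: no cover
    tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    mapping = LlamaBPEMapping(tok)
    token_id = tok.convert_tokens_to_ids("<0x0A>")
    assert mapping.map(token_id) == b"\n"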


class WordPieceMapping(Mapping):
    """WordPiece tokenizers: map a token id by decoding it directly."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def map(self, token_id: int) -> bytes:
        if token_id in self.special:
            return bytes()
        return bytes(
            self.tokenizer.decode([token_id], clean_up_tokenization_spaces=False),
            "utf-8",
        )


class UniGramMapping(Mapping):
    """Unigram tokenizers: map a token id by decoding it directly."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def map(self, token_id: int) -> bytes:
        if token_id in self.special:
            return bytes()
        return bytes(
            self.tokenizer.decode([token_id], clean_up_tokenization_spaces=False),
            "utf-8",
        )
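# Sketch of the decode-based path shared by the WordPiece and Unigram
# mappings: special tokens contribute no bytes, ordinary tokens go through
# tokenizer.decode. The "facebook/xglm-564M" checkpoint name is an assumption.
def _example_unigram_mapping():  # pragma: no cover
    tok = AutoTokenizer.from_pretrained("facebook/xglm-564M")
    mapping = UniGramMapping(tok)
    # Special tokens map to empty bytes.
    assert mapping.map(tok.eos_token_id) == b""
    # Ordinary tokens are decoded to their surface string.
    print(mapping.map(tok.convert_tokens_to_ids("▁the")))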


class XGLMUniGramMapping(UniGramMapping):
    """XGLM reuses its EOS token in the BOS position and defines no true
    EOS token."""

    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.bos_token_id = tokenizer.eos_token_id
        self.eos_token_id = None


class ByteEncoding:
    """Wraps GPT-2's byte-to-unicode tables to translate raw tokens into
    real bytes."""

    def __init__(self, tokenizer):
        # Only the slow tokenizer exposes byte_encoder/byte_decoder, so fall
        # back to it when a fast tokenizer is passed in.
        if tokenizer.is_fast:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer.name_or_path, use_fast=False
            )
        self.tokenizer = tokenizer
        self.byte2char: Dict[int, str] = tokenizer.byte_encoder
        self.char2byte: Dict[str, int] = tokenizer.byte_decoder
        # The same tables keyed by unicode code point instead of character.
        self.cdp2byte: Dict[int, int] = {ord(c): b for c, b in self.char2byte.items()}
        self.byte2cdp: Dict[int, int] = {v: k for k, v in self.cdp2byte.items()}

    def map(self, byte: int) -> int:
        assert 0 <= byte < 256, f"byte: {byte} is not in the range [0, 256)"
        return ord(self.byte2char[byte])

    def token_ids2bytes(self, token_ids: List[int]) -> bytes:
        tokens: List[str] = self.tokenizer.convert_ids_to_tokens(token_ids)
        # Special tokens carry no surface bytes; compare against the token
        # strings, not the ids, since `tokens` holds strings here.
        tokens = [
            "" if token in self.tokenizer.all_special_tokens else token
            for token in tokens
        ]
        # Flatten the per-token byte lists into a single byte string.
        byte_lists: List[List[int]] = [
            list(self.token2bytes(token)) for token in tokens
        ]
        return ints2bytes(sum(byte_lists, []))

    def token_id2bytes(self, token_id: int) -> bytes:
        token: str = self.tokenizer.convert_ids_to_tokens(token_id)
        return self.token2bytes(token)

    def token2bytes(self, token: str) -> bytes:
        # Each character of the raw token stands for exactly one byte.
        bytes_seq: List[int] = [self.char2byte[c] for c in token]
        return bytes(bytes_seq)
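# Sketch of the byte <-> unicode tables: GPT-2 maps the space byte 0x20 to
# the printable character "Ġ", and token2bytes inverts that per-character
# mapping. The "gpt2" checkpoint name is an assumption.
def _example_byte_encoding():  # pragma: no cover
    tok = AutoTokenizer.from_pretrained("gpt2")
    enc = ByteEncoding(tok)
    assert enc.map(0x20) == ord("Ġ")
    assert enc.token2bytes("Ġworld") == b" world"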