from typing import Dict, List
from transformers_gad.utils import get_tokenizer_model_type, ints2bytes
from transformers import AutoTokenizer
import logging
log = logging.getLogger(__name__)
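
# Helpers that map token ids from a Hugging Face tokenizer to the raw bytes
# they decode to. Different tokenizer families (byte-level BPE, sentencepiece
# BPE, WordPiece, Unigram) encode their vocabularies differently, so each
# family gets its own Mapping subclass below.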

def get_mapping(tokenizer, unicode=False):
    log.debug(f"tokenizer type: {tokenizer.__class__.__name__}")
    log.debug(f"tokenizer model type: {get_tokenizer_model_type(tokenizer)}")
    if not unicode:
        if (
            "gpt2" in tokenizer.__class__.__name__.lower()
            or "bloom" in tokenizer.__class__.__name__.lower()
            or "pretrainedtokenizer" in tokenizer.__class__.__name__.lower()
            or "codegen" in tokenizer.__class__.__name__.lower()
            or "gptneox" in tokenizer.__class__.__name__.lower()
        ):
            return BBPEMapping(tokenizer)
        elif "t5" in tokenizer.__class__.__name__.lower():
            return BPEMapping(tokenizer)
        elif "llama" in tokenizer.__class__.__name__.lower():
            return LlamaBPEMapping(tokenizer)
        elif "xglm" in tokenizer.__class__.__name__.lower():
            return UniGramMapping(tokenizer)
        else:
            raise ValueError(f"Unknown tokenizer type: {tokenizer.__class__.__name__}")
    else:
        if "gpt2" in tokenizer.__class__.__name__.lower():
            return UnicodeBBPEMapping(tokenizer)
        else:
            raise NotImplementedError(
                f"Unicode mapping for {tokenizer.__class__.__name__}"
            )
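
# A minimal usage sketch (the "gpt2" checkpoint name is only for illustration):
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     mapping = get_mapping(tokenizer, unicode=True)
#     mapping.map(token_id)  # -> the bytes that token decodes to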

class Mapping:
    def __init__(self, tokenizer):
        self.eos_token_id = tokenizer.eos_token_id
        self.bos_token_id = tokenizer.bos_token_id
        self.tokenizer = tokenizer
        self.special = tokenizer.all_special_ids

    def __len__(self):
        return len(self.tokenizer.get_vocab())

    def _map(self, token_id: int) -> str:
        # special tokens (e.g. BOS) map to the empty string
        if token_id in self.special:
            return ""
        # if token_id is a tensor, convert it to a plain int
        if hasattr(token_id, "item"):
            token_id = token_id.item()
        raw_token = self.tokenizer.convert_ids_to_tokens(token_id)
        return raw_token

    def map(self, token_id: int, verbose=False) -> bytes:
        token = self._map(token_id)
        if verbose:
            log.debug(f"token_id: {token_id}, token: {token}")
        return bytes(token, "utf-8")
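
# The subclasses below specialize the raw-token-to-text step for particular
# tokenizer families. In byte-level BPE (GPT-2 style) vocabularies, a leading
# "Ġ" in the raw token marks a preceding space, e.g. "Ġworld" -> " world".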

class BBPEMapping(Mapping):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _map(self, token_id: int) -> str:
        raw_token = super()._map(token_id)
        if raw_token.startswith("Ġ"):
            raw_token = raw_token.replace("Ġ", " ")
        return raw_token

class UnicodeBBPEMapping(Mapping):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.intermediate_encoding = UnicodeBBPEMapping.get_intermediate_encoding(
            self.tokenizer
        )

    def _map(self, token_id: int, verbose=False) -> str:
        raw_token = super()._map(token_id)
        # if raw_token.startswith("Ġ"):
        #     raw_token = raw_token.replace("Ġ", " ")
        return raw_token

    def map(self, token_id: int, verbose=False) -> bytes:
        raw_token = self._map(token_id, verbose)
        if verbose:
            log.debug(f"token_id: {token_id}, raw_token: {raw_token}")
        return self.intermediate_encoding.token2bytes(raw_token)

    @staticmethod
    def get_intermediate_encoding(tokenizer):
        if "gpt2" in tokenizer.__class__.__name__.lower():
            return ByteEncoding(tokenizer)
        else:
            return None
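
# UnicodeBBPEMapping keeps the raw vocabulary characters as-is and relies on
# ByteEncoding (defined below) to translate them back into raw bytes: GPT-2's
# byte-level scheme stores every raw byte as a printable unicode character
# (e.g. the space byte 0x20 is stored as "Ġ"), so for example
# token2bytes("Ġhi") yields b" hi".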

class BPEMapping(Mapping):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        self.last_token_id = None

    def _map(self, token_id: int) -> str:
        raw_token = super()._map(token_id)
        # sentencepiece BPE marks a leading space with "▁"; that space must be
        # dropped for the first token after BOS, so track the previous token id
        at_bos = False
        if self.last_token_id is not None and self.last_token_id == self.bos_token_id:
            at_bos = True
        self.last_token_id = token_id
        if raw_token.startswith("▁"):
            raw_token = raw_token.replace("▁", " ")
            if at_bos:
                # remove the space at the beginning of the sentence
                raw_token = raw_token[1:]
        return raw_token
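
# e.g. for a T5-style sentencepiece vocab, "▁world" -> " world"; if it is the
# first token after BOS, the leading space is dropped, giving "world".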

class LlamaBPEMapping(BPEMapping):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def _map(self, token_id: int) -> str:
        raw_token = super()._map(token_id)
        # Llama reserves 256 byte-fallback tokens near the start of its vocab,
        # rendered as strings like "<0x00>"; convert them back to the raw character
        if raw_token.startswith("<0x"):
            hex_value = raw_token[3:-1]
            raw_token = chr(int(hex_value, 16))
        return raw_token
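
# e.g. the byte-fallback token "<0x0A>" maps to "\n".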

class WordPieceMapping(Mapping):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def map(self, token_id: int) -> bytes:
        if token_id in self.special:
            return bytes()
        return bytes(
            self.tokenizer.decode([token_id], clean_up_tokenization_spaces=False),
            "utf-8",
        )

class UniGramMapping(Mapping):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)

    def map(self, token_id: int) -> bytes:
        if token_id in self.special:
            return bytes()
        return bytes(
            self.tokenizer.decode([token_id], clean_up_tokenization_spaces=False),
            "utf-8",
        )

class XGLMUniGramMapping(Mapping):
    def __init__(self, tokenizer):
        super().__init__(tokenizer)
        # XGLM marks the start of a sequence with its EOS token, so this
        # mapping treats that id as BOS and sets no EOS
        self.bos_token_id = tokenizer.eos_token_id
        self.eos_token_id = None

class ByteEncoding:
    def __init__(self, tokenizer):
        # only the slow tokenizer exposes byte_encoder/byte_decoder, so if the
        # tokenizer is fast, reload it as a slow tokenizer
        if tokenizer.is_fast:
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer.name_or_path, use_fast=False
            )
        self.tokenizer = tokenizer
        self.byte2char: Dict[int, str] = tokenizer.byte_encoder
        self.char2byte: Dict[str, int] = tokenizer.byte_decoder
        # code point to byte, and back
        self.cdp2byte: Dict[int, int] = {ord(c): b for c, b in self.char2byte.items()}
        self.byte2cdp: Dict[int, int] = {v: k for k, v in self.cdp2byte.items()}

    def map(self, byte: int) -> int:
        assert 0 <= byte < 256, f"byte: {byte} is not in the range [0, 256)"
        return ord(self.byte2char[byte])
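
    # e.g. with GPT-2's byte encoder, map(0x20) returns ord("Ġ") == 0x120.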

    def token_ids2bytes(self, token_ids: List[int]) -> bytes:
        tokens: List[str] = self.tokenizer.convert_ids_to_tokens(token_ids)
        # special tokens such as BOS should map to the empty string, not "<s>"
        # TODO: this may cause issues because it means special tokens like BOS
        # can appear at any position
        tokens = [
            "" if token in self.tokenizer.all_special_tokens else token
            for token in tokens
        ]
        byte_seqs: List[bytes] = [self.token2bytes(token) for token in tokens]
        # flatten the per-token byte sequences and join them into one bytes object
        return ints2bytes([b for seq in byte_seqs for b in seq])

    def token_id2bytes(self, token_id: int) -> bytes:
        token: str = self.tokenizer.convert_ids_to_tokens(token_id)
        return self.token2bytes(token)

    def token2bytes(self, token: str) -> bytes:
        # map each vocabulary character back to its raw byte value
        bytes_seq: List[int] = [self.char2byte[c] for c in token]
        return bytes(bytes_seq)
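
# A minimal end-to-end sketch (not part of the library; the "gpt2" checkpoint
# name is used purely for illustration):
#
#     from transformers import AutoTokenizer
#
#     tok = AutoTokenizer.from_pretrained("gpt2")
#     mapping = get_mapping(tok, unicode=True)
#     ids = tok.encode("hello world")
#     decoded = b"".join(mapping.map(i) for i in ids)
#     # decoded == b"hello world"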