dart-v1-sft / tokenization_dart.py
p1atdev's picture
Upload 5 files
1c8f03b verified
raw history blame
No virus
3.78 kB
import logging
import os
import json
from typing import Optional, Dict, List, Tuple, Union
from pydantic.dataclasses import dataclass
import numpy as np
from numpy.typing import NDArray
from transformers import PreTrainedTokenizerFast
from tokenizers.decoders import Decoder
logger = logging.getLogger(__name__)

# Extra files this tokenizer needs besides the regular tokenizer files. The key
# matches the ``category_config`` argument of ``DartTokenizer.__init__``, which
# is opened as a file path there.
VOCAB_FILES_NAMES = dict(category_config="category_config.json")

# Download locations of the extra files for known pretrained checkpoints.
# NOTE(review): the remote file is named ``tag_category.json`` while the local
# file name above is ``category_config.json`` — confirm this is intentional.
PRETRAINED_VOCAB_FILES_MAP = dict(
    category_config={
        "p1atdev/dart-tokenizer-v1": "https://huggingface.co/p1atdev/dart-tokenizer-v1/resolve/main/tag_category.json",
    },
)
@dataclass
class Category:
    """One category entry of the tag-category config.

    Declared with ``pydantic.dataclasses.dataclass``, so field values are
    type-validated when an instance is constructed.
    """

    # Category name.
    name: str
    # Token ids stored for this category. Presumably the tokens that open and
    # close the category's section of a prompt — TODO confirm against the
    # config file / model usage.
    bos_token_id: int
    eos_token_id: int
@dataclass
class TagCategoryConfig:
    """Parsed contents of the tag-category JSON config.

    Built by ``load_tag_category_config`` from the file's top-level object.
    """

    # Category id (stored as a string key, e.g. "0") -> category entry.
    categories: Dict[str, Category]
    # Category id (stored as a string key) -> ids of the tokens in that category.
    category_to_token_ids: Dict[str, List[int]]
def load_tag_category_config(config_json: str):
    """Read a tag-category JSON file and build a ``TagCategoryConfig`` from it.

    Args:
        config_json: Path to the JSON file. Its top-level object is passed as
            keyword arguments to ``TagCategoryConfig``.

    Returns:
        The constructed ``TagCategoryConfig``.
    """
    with open(config_json, "rb") as fp:
        raw_config = json.load(fp)
    return TagCategoryConfig(**raw_config)
class DartDecoder:
    """Custom token decoder that re-inserts ", " separators between tags.

    Consecutive non-special tokens are separated by ", "; no separator is
    added before or after a special token, nor before the first token.
    """

    def __init__(self, special_tokens: List[str]):
        """
        Args:
            special_tokens: Tokens next to which no ", " separator is inserted.
        """
        self.special_tokens = list(special_tokens)

    def decode_chain(self, tokens: List[str]) -> List[str]:
        """Prefix each token with ", " unless it or its predecessor is special.

        Args:
            tokens: Token strings, in order.

        Returns:
            A new list with the same length and order as ``tokens``.
        """
        # Build the lookup set once per call: O(1) membership instead of a
        # linear scan of the list for every token. It is built here rather than
        # in __init__ so later changes to ``self.special_tokens`` still apply.
        special = set(self.special_tokens)

        new_tokens: List[str] = []
        prev_is_special = False
        for i, token in enumerate(tokens):
            is_special = token in special
            # The first token, special tokens and tokens that directly follow a
            # special token are emitted as-is; every other token gets ", ".
            if i == 0 or is_special or prev_is_special:
                new_tokens.append(token)
            else:
                new_tokens.append(f", {token}")
            prev_is_special = is_special
        return new_tokens
class DartTokenizer(PreTrainedTokenizerFast):
    """Dart tokenizer.

    A ``PreTrainedTokenizerFast`` that also loads a tag-category config and
    installs ``DartDecoder`` so decoded tags are joined with ", ".
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP

    def __init__(self, category_config, **kwargs):
        """Build the fast tokenizer and the token-id -> category-id lookup.

        Args:
            category_config: Path to the tag-category JSON file.
            **kwargs: Forwarded unchanged to ``PreTrainedTokenizerFast``.
        """
        super().__init__(**kwargs)

        # Added-vocab tokens are treated as special by the decoder, so no ", "
        # separator is inserted next to them.
        self._tokenizer.decoder = Decoder.custom(  # type: ignore
            DartDecoder(list(self.get_added_vocab().keys()))
        )

        self.category_config = load_tag_category_config(category_config)

        # Lookup table indexed by token id; the value is the token's category
        # id. Tokens not listed in any category keep 0.
        # NOTE(review): uint8 assumes every category id fits in 0..255, and the
        # table is sized by ``self.vocab_size``; if that excludes added tokens,
        # their ids would be out of range here — confirm.
        self._id_to_category_map = np.zeros(self.vocab_size, dtype=np.uint8)
        category_map = self.category_config.category_to_token_ids
        for category_id, token_ids in category_map.items():
            self._id_to_category_map[token_ids] = int(category_id)

    def create_vocab_mask(self, value: int = 1) -> NDArray[np.uint8]:
        """Create a ``uint8`` array of length ``vocab_size`` filled with ``value``."""
        return np.full(self.vocab_size, value).astype("uint8")

    def get_token_ids_in_category(self, category_id: Union[int, str]) -> List[int]:
        """Get the token ids listed for the given category id.

        Raises:
            KeyError: If the category id is not in the config.
        """
        return self.category_config.category_to_token_ids[str(category_id)]

    def get_category(self, category_id: Union[int, str]):
        """Get the ``Category`` entry for the given category id.

        Raises:
            KeyError: If the category id is not in the config.
        """
        return self.category_config.categories[str(category_id)]

    def convert_ids_to_category_ids(self, token_ids: Union[int, List[int]]):
        """Get the category ids of the given token ids.

        Returns a numpy ``uint8`` scalar for a single id, or an array for a
        list of ids.
        """
        return self._id_to_category_map[token_ids]

    def get_banned_tokens_mask(
        self, tokens: Union[str, List[str], int, List[int]]
    ) -> NDArray[np.uint8]:
        """Return a vocab-sized ``uint8`` mask with the given tokens set to 0.

        All other positions are 1.

        Args:
            tokens: A token string, a token id, or a list/tuple/array mixing
                both. Strings are converted with ``convert_tokens_to_ids``.
                Depending on the tokenizer, an unknown string may map to the
                unknown-token id.

        Returns:
            ``uint8`` array of length ``vocab_size``.

        Raises:
            TypeError: If an entry is neither a string nor an integer id.
            ValueError: If a string token cannot be converted to an id.
        """
        # Wrap a single token first, so that a lone string still goes through
        # the string -> id conversion below. Previously ``"tag"`` became
        # ``["tag"]`` and was never converted, so the int check always failed.
        if isinstance(tokens, (str, int, np.integer)):
            tokens = [tokens]

        token_ids: List[int] = []
        for token in tokens:
            if isinstance(token, str):
                token_id = self.convert_tokens_to_ids(token)
                if token_id is None:
                    raise ValueError(f"Cannot convert token {token!r} to an id")
            else:
                token_id = token
            # Explicit check rather than ``assert`` so it also runs under -O.
            if not isinstance(token_id, (int, np.integer)):
                raise TypeError(
                    "Expected a token string or integer id, "
                    f"got {type(token_id).__name__}"
                )
            token_ids.append(int(token_id))

        mask = self.create_vocab_mask(value=1)
        mask[token_ids] = 0
        return mask