# dart-v1-base / tokenization_dart.py
import logging
import json
from typing import Dict, List

from pydantic.dataclasses import dataclass
from transformers import PreTrainedTokenizerFast
from tokenizers.decoders import Decoder

logger = logging.getLogger(__name__)
# fmt: off
# https://huggingface.co/docs/transformers/main/en/chat_templating
PROMPT_TEMPLATE = (
    "{{ '<|bos|>' }}"

    "{{ '<rating>' }}"
    "{% if 'rating' not in messages or messages['rating'] is none %}"
    "{{ 'rating:sfw, rating:general' }}"
    "{% else %}"
    "{{ messages['rating'] }}"
    "{% endif %}"
    "{{ '</rating>' }}"

    "{{ '<copyright>' }}"
    "{% if 'copyright' not in messages or messages['copyright'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['copyright'] }}"
    "{% endif %}"
    "{{ '</copyright>' }}"

    "{{ '<character>' }}"
    "{% if 'character' not in messages or messages['character'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['character'] }}"
    "{% endif %}"
    "{{ '</character>' }}"

    "{{ '<general>' }}"
    "{% if 'general' not in messages or messages['general'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['general'] }}"
    "{% endif %}"
).strip()
# fmt: on
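# A sketch of what the template above renders (the field values here are
# illustrative, not taken from the original file):
#   <|bos|><rating>rating:sfw, rating:general</rating><copyright></copyright><character></character><general>1girl, solo
# Missing or None fields fall back to the defaults shown in the template, and the
# <general> section has no closing tag, so the model can continue generating
# general tags from that point.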
@dataclass
class Category:
    name: str
    bos_token_id: int
    eos_token_id: int


@dataclass
class TagCategoryConfig:
    categories: Dict[str, Category]
    category_to_token_ids: Dict[str, List[int]]


def load_tag_category_config(config_json: str) -> TagCategoryConfig:
    with open(config_json, "rb") as file:
        config: TagCategoryConfig = TagCategoryConfig(**json.loads(file.read()))

    return config
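# Expected shape of the JSON file read by `load_tag_category_config` (an assumption
# inferred from the TagCategoryConfig/Category fields, not from an actual config file):
#   {
#     "categories": {"0": {"name": "general", "bos_token_id": 0, "eos_token_id": 1}, ...},
#     "category_to_token_ids": {"0": [10, 11, 12], ...}
#   }
# pydantic's dataclass validation coerces the nested category dicts into Category instances.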
class DartDecoder:
    def __init__(self, special_tokens: List[str]):
        self.special_tokens = list(special_tokens)

    def decode_chain(self, tokens: List[str]) -> List[str]:
        new_tokens = []
        is_specials = []

        for i, token in enumerate(tokens):
            is_specials.append(token in self.special_tokens)

            if i == 0:
                new_tokens.append(token)
                continue

            # this token or previous token is special
            if is_specials[i] or is_specials[i - 1]:
                new_tokens.append(token)
                continue

            new_tokens.append(f", {token}")

        return new_tokens
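# Illustrative behaviour of DartDecoder.decode_chain (hypothetical input, not from
# the original file):
#   DartDecoder(["<general>"]).decode_chain(["<general>", "1girl", "solo"])
#   -> ["<general>", "1girl", ", solo"]
# Plain tag tokens are joined with ", ", while tokens adjacent to a special token
# are left without a separator, so decoded output reads as a comma-separated tag list.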
class DartTokenizer(PreTrainedTokenizerFast):
    """Dart tokenizer"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self._tokenizer.decoder = Decoder.custom(  # type: ignore
            DartDecoder(list(self.get_added_vocab().keys()))
        )

    @property
    def default_chat_template(self):
        """
        Danbooru Tags Transformer uses a special prompt format to generate danbooru tags.
        """
        return PROMPT_TEMPLATE
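# Usage sketch (assumes this file ships as custom tokenizer code in a Hub repo such
# as `p1atdev/dart-v1-base`; the repo id and field values are illustrative):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("p1atdev/dart-v1-base", trust_remote_code=True)
#   prompt = tokenizer.apply_chat_template(
#       {
#           "rating": "rating:sfw, rating:general",
#           "copyright": None,
#           "character": None,
#           "general": "1girl, solo",
#       },
#       tokenize=False,
#   )
#   # prompt == "<|bos|><rating>rating:sfw, rating:general</rating>..."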