File size: 2,524 Bytes
1c8f03b
ce448ae
1c8f03b
 
 
 
 
 
 
0df470a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c8f03b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0df470a
1c8f03b
 
 
 
 
 
0df470a
 
 
 
 
1c8f03b
0df470a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import logging
from typing import List

from transformers import PreTrainedTokenizerFast
from tokenizers.decoders import Decoder

logger = logging.getLogger(__name__)


# fmt: off
# https://huggingface.co/docs/transformers/main/en/chat_templating
# Jinja2 chat template returned by DartTokenizer.default_chat_template.
# Renders a flat tag prompt:
#   <|bos|><rating>...</rating><copyright>...</copyright>
#   <character>...</character><general><length token>...<|input_end|>
# Each section falls back to a default when the corresponding key is
# absent from `messages` or its value is None.
PROMPT_TEMPLATE = (
    "{{ '<|bos|>' }}" 
    
    # rating section — defaults to safe-for-work tags
    "{{ '<rating>' }}"
    "{% if 'rating' not in messages or messages['rating'] is none %}"
    "{{ 'rating:sfw, rating:general' }}"
    "{% else %}"
    "{{ messages['rating'] }}"
    "{% endif %}"
    "{{ '</rating>' }}"

    # copyright section — defaults to empty
    "{{ '<copyright>' }}"
    "{% if 'copyright' not in messages or messages['copyright'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['copyright'] }}"
    "{% endif %}"
    "{{ '</copyright>' }}"

    # character section — defaults to empty
    "{{ '<character>' }}"
    "{% if 'character' not in messages or messages['character'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['character'] }}"
    "{% endif %}"
    "{{ '</character>' }}"

    # general section: opened here; NOTE the closing </general> tag is
    # intentionally absent — generation continues after <|input_end|>
    "{{ '<general>' }}"
    # length token
    "{% if 'length' not in messages or messages['length'] is none %}"
    "{{ '<|long|>' }}"
    "{% else %}"
    "{{ messages['length'] }}"
    "{% endif %}"

    # general token
    "{% if 'general' not in messages or messages['general'] is none %}"
    "{{ '' }}"
    "{% else %}"
    "{{ messages['general'] }}"
    "{% endif %}"
    "{{ '<|input_end|>' }}"
).strip()
# fmt: on


class DartDecoder:
    """Custom ``tokenizers`` decoder that joins decoded tag tokens with
    ``", "`` while leaving special tokens — and the token immediately
    following one — without a separator prefix.
    """

    def __init__(self, special_tokens: List[str]):
        # Public list attribute kept for backward compatibility with any
        # external reader of `special_tokens`.
        self.special_tokens = list(special_tokens)
        # Frozenset for O(1) membership tests in decode_chain; the
        # original list scan was O(len(special_tokens)) per token.
        self._special_token_set = frozenset(special_tokens)

    def decode_chain(self, tokens: List[str]) -> List[str]:
        """Return ``tokens`` with ``", "`` prepended to every token that
        is not the first one and is not adjacent to a special token.

        The first token, every special token, and every token directly
        following a special token are passed through unchanged.
        """
        new_tokens: List[str] = []
        prev_was_special = False

        for index, token in enumerate(tokens):
            is_special = token in self._special_token_set

            if index == 0 or is_special or prev_was_special:
                new_tokens.append(token)
            else:
                new_tokens.append(f", {token}")

            prev_was_special = is_special

        return new_tokens


class DartTokenizer(PreTrainedTokenizerFast):
    """Dart tokenizer.

    Fast tokenizer whose decoder is replaced by :class:`DartDecoder`, so
    that decoded output joins ordinary tag tokens with ``", "`` while
    leaving added/special tokens unseparated.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Swap the backend tokenizer's decoder for the comma-joining
        # DartDecoder, treating every added-vocab token as "special".
        # Accesses the private `_tokenizer` of PreTrainedTokenizerFast;
        # `Decoder.custom` wraps the Python object for the Rust backend.
        self._tokenizer.decoder = Decoder.custom(  # type: ignore
            DartDecoder(list(self.get_added_vocab().keys()))
        )

    @property
    def default_chat_template(self):
        """
        Danbooru Tags Transformer uses special format prompt to generate danbooru tags.
        """
        # NOTE(review): `default_chat_template` is deprecated/removed in
        # recent transformers releases in favor of a `chat_template`
        # attribute or tokenizer_config entry — confirm the pinned
        # transformers version still consults this property.

        return PROMPT_TEMPLATE