# File size: 8,177 Bytes
# 1bd70cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

from enums import PromptType, t5_type


class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation on repeated stop words or once the sequence is too long.

    Each tensor in ``stops`` is paired by position with its text in ``stop_words``.
    On every call only the first sequence of the batch is inspected: its last
    ``len(stop)`` tokens are checked for the stop word.  Every call in which the
    stop word is found increments a per-stop counter, and generation stops as
    soon as a counter reaches its entry in ``encounters`` (``encounters`` is
    cycled when it is shorter than ``stops``).  The counters are kept on the
    instance and are never reset.
    """

    def __init__(self, stops=None, stop_words=None, encounters=None, device="cuda", model_max_length=None,
                 tokenizer=None):
        """
        :param stops: token-id tensors, one per stop word (moved to ``device``)
        :param stop_words: stop word strings, positionally aligned with ``stops``
        :param encounters: how many matches are required per stop before stopping;
               its length must divide ``len(stops)``; empty/None means 1 for every stop
        :param device: device the stop tensors are moved to
        :param model_max_length: if not None, stop once the sequence has this many tokens
        :param tokenizer: used to decode the tail of the sequence for text matching;
               if None, the tail is compared to the stop token ids directly
        :raises ValueError: if ``len(encounters)`` does not divide ``len(stops)``
        """
        super().__init__()
        # None (not []) as default: a list literal default is shared between all instances
        stops = [] if stops is None else list(stops)
        stop_words = [] if stop_words is None else list(stop_words)
        # no thresholds given -> stop on the first match (also avoids a modulo by zero below)
        encounters = list(encounters) if encounters else [1]
        if len(stops) % len(encounters) != 0:
            raise ValueError("Number of stops and encounters must match")
        self.encounters = encounters
        self.stops = [stop.to(device) for stop in stops]
        self.stop_words = stop_words
        self.num_stops = [0] * len(stops)
        self.model_max_length = model_max_length
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if generation should stop for the first sequence in ``input_ids``."""
        sequence = input_ids[0]
        for stopi, (stop, stop_word) in enumerate(zip(self.stops, self.stop_words)):
            current_block = sequence[-len(stop):]
            if current_block.shape[0] < len(stop):
                # not enough tokens yet to contain this stop sequence
                continue
            if self.tokenizer is not None:
                # compare as text: the same string can be encoded by many different token sequences
                matched = stop_word in self.tokenizer.decode(current_block)
            else:
                # no tokenizer available: fall back to an exact token-id comparison
                matched = len(stop) > 0 and bool(torch.all(stop.to(current_block.device) == current_block).item())
            if not matched:
                continue
            self.num_stops[stopi] += 1
            if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
                return True
        # critical limit: never run past the model's maximum length
        return self.model_max_length is not None and sequence.shape[0] >= self.model_max_length


def get_stopping(prompt_type, prompt_dict, tokenizer, device, base_model,
                 human='<human>:', bot="<bot>:", model_max_length=None,
                 prompter=None,
                 stop=None):
    """Build the stopping criteria for text generation with the given prompt type.

    Stop words are chosen from ``prompt_type`` (human/bot style or
    "### Human:/### Assistant:" style prompts), otherwise from
    ``prompter.terminate_response``; any ``stop`` words passed by the caller are
    added on top.  Every stop word is tokenized and stripped of pad/unk/eos/bos
    tokens the tokenizer put around it, then handed to ``StoppingCriteriaSub``
    together with its text and its encounter threshold.

    :param prompt_type: prompt type given as enum value, its string form, or its name
    :param prompt_dict: currently unused
    :param tokenizer: tokenizer used to encode the stop words (and later decode output)
    :param device: device the stop token tensors are moved to
    :param base_model: model name, only used to detect T5-type models
    :param human: human turn marker for human/bot prompt types
    :param bot: bot turn marker for human/bot prompt types
    :param model_max_length: passed to the stopper as hard length limit
    :param prompter: optional object whose ``terminate_response`` lists stop words
    :param stop: extra stop word(s), e.g. for LangChain agents; a single string is accepted
    :return: ``StoppingCriteriaList`` holding one ``StoppingCriteriaSub``, or an empty
             list when there is nothing to stop on (``model_max_length`` is then not enforced)
    """

    def aliases(*prompt_types):
        # every spelling of a prompt type a caller may pass: value, str(value), name
        return [alias for pt in prompt_types for alias in (pt.value, str(pt.value), pt.name)]

    def strip_front(ids, token_id):
        return ids[1:] if len(ids) > 1 and ids[0] == token_id else ids

    def strip_back(ids, token_id):
        return ids[:-1] if len(ids) > 1 and ids[-1] == token_id else ids

    stop_words = []
    encounters = []
    # FIXME: prompt_dict unused currently
    user_human_assistant_types = aliases(PromptType.instruct_vicuna, PromptType.guanaco, PromptType.one_shot,
                                         PromptType.instruct_vicuna2, PromptType.instruct_vicuna3,
                                         PromptType.instruct_with_end)
    human_bot_types = aliases(PromptType.human_bot, PromptType.human_bot_orig)
    # NOTE(review): an earlier '### End' fallback for "other" known prompt types could never run,
    # because every known type (including instruct_with_end) is in one of the two lists above.
    if prompt_type in human_bot_types:
        # stopping only starts once output is beyond prompt
        # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
        stop_words = [human, bot, '\n' + human, '\n' + bot]
        encounters = [1, 2]
    elif prompt_type in user_human_assistant_types:
        # even below is not enough, generic strings and many ways to encode
        stop_words = [
            '### Human:',
            '\n### Human:',
            '\n### Human:\n',
            '###  Human:  ',
            '###  Human:',
            '### Assistant:',
            '\n### Assistant:',
            '\n### Assistant:\n',
            '###  Assistant:  ',
            '###  Assistant:',
        ]
        if prompt_type in aliases(PromptType.instruct_vicuna2):
            stop_words = [x.upper() for x in stop_words]
        if prompt_type in aliases(PromptType.instruct_vicuna3):
            stop_words = [x.replace('Human', 'User') for x in stop_words]
        # NOTE(review): cycled below, so thresholds alternate 1, 2, 1, 2, ... over the list
        # above rather than being 1 for Human and 2 for Assistant -- confirm this is intended
        encounters = [1, 2]
    elif prompter and prompter.terminate_response:
        # copy, so adding extra stop words below never modifies the prompter's own list
        stop_words = list(prompter.terminate_response)
        encounters = [1] * len(stop_words)

    # one threshold per stop word: expand the short cyclic pattern before anything is appended,
    # otherwise the extra stop words would break the pairing
    if encounters:
        encounters = [encounters[i % len(encounters)] for i in range(len(stop_words))]
    handle_newlines = [True] * len(stop_words)

    # add other stop words too if passed, e.g. for LangChain agents
    if stop:
        extra = [stop] if isinstance(stop, str) else list(stop)
        stop_words.extend(extra)
        encounters.extend([1] * len(extra))
        handle_newlines.extend([False] * len(extra))

    # use hidden variables to avoid annoying property logger bug
    has_pad = bool(tokenizer._pad_token)
    has_unk = bool(tokenizer._unk_token)
    has_eos = bool(tokenizer._eos_token)
    has_bos = bool(tokenizer._bos_token)
    is_t5 = bool(base_model and t5_type(base_model))

    # get stop tokens; words, ids and thresholds are filtered together so they stay aligned
    kept_words = []
    kept_encounters = []
    stop_words_ids = []
    for stop_word, encounter, handle_newline in zip(stop_words, encounters, handle_newlines):
        ids = tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze()
        if len(ids.shape) == 0:
            # handle single token case: squeeze() left a 0-dim tensor
            ids = ids.reshape(1)
        if ids.shape[0] == 0:
            continue
        # avoid padding / special tokens around the actual stop tokens
        if has_pad:
            ids = strip_front(ids, tokenizer.pad_token_id)
        if has_unk:
            ids = strip_front(ids, tokenizer.unk_token_id)
            ids = strip_back(ids, tokenizer.unk_token_id)
        if has_eos:
            ids = strip_back(ids, tokenizer.eos_token_id)
        if has_bos:
            ids = strip_front(ids, tokenizer.bos_token_id)
            ids = strip_back(ids, tokenizer.bos_token_id)
        if is_t5 and len(ids) > 2:
            # T5 encoder converts internal double space to space+new line, so fix.
            # Only sequences with interior tokens are touched: for a single token the
            # first and last slice are the same token and it would get duplicated.
            start = ids[0:1]
            end = ids[-1:]
            middle = [tokenizer.vocab[' '] if x == tokenizer.vocab['\n'] else x for x in ids[1:-1]]
            ids = torch.tensor(list(start) + list(middle) + list(end), device=ids.device)
        # handle fake \n added, but never strip a stop sequence down to nothing
        if handle_newline and stop_word.startswith('\n') and len(ids) > 1:
            ids = ids[1:]
        kept_words.append(stop_word)
        kept_encounters.append(encounter)
        stop_words_ids.append(ids)

    if not stop_words_ids:
        # nothing to stop on
        return StoppingCriteriaList()
    # build stopper
    return StoppingCriteriaList(
        [StoppingCriteriaSub(stops=stop_words_ids,
                             stop_words=kept_words,
                             encounters=kept_encounters, device=device,
                             model_max_length=model_max_length, tokenizer=tokenizer)])