File size: 4,800 Bytes
9b05693
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
from typing import Optional
from collections import deque
from queue import Queue
import copy


class History:
    """Ordered chat history whose messages are annotated with tokenizer output.

    Each message is a dict with (at least) ``"role"`` and ``"content"`` keys.
    Whenever a message is added, ``tokenizer(content)`` is called and the
    returned mapping is merged into the message dict *in place*, so the
    caller's dict object is mutated and stored as-is. A ``deque`` backs the
    storage so messages can be added/removed at either end.
    """

    def __init__(self, tokenizer, history=None):
        """Create a history, optionally seeded from a list of message dicts.

        Args:
            tokenizer: callable mapping a content string to a mapping whose
                items are merged into each message (presumably a Hugging
                Face tokenizer producing ``input_ids`` etc. — verify with callers).
            history: iterable of message dicts to append in order; ``None`` or
                empty yields an empty history.
        """
        # use deque so both ends support cheap append/pop
        self.input_history = deque()
        self.tokenizer = tokenizer
        if history:
            self._transfer_from_list(history)

    def _tokenize(self, message):
        """Merge ``self.tokenizer(message["content"])`` into *message* and return it."""
        content = message.get("content")
        message.update(self.tokenizer(content))
        return message

    def _transfer_from_list(self, history):
        """Tokenize and append every message of *history*, preserving order."""
        for message in history:
            self.input_history.append(self._tokenize(message))

    def append(self, message):
        """Tokenize *message* (in place) and add it as the newest entry."""
        self.input_history.append(self._tokenize(message))

    def append_left(self, message):
        """Tokenize *message* (in place) and add it as the oldest entry."""
        self.input_history.appendleft(self._tokenize(message))

    def pop(self):
        """Remove and return the newest message.

        Raises:
            IndexError: if the history is empty.
        """
        return self.input_history.pop()

    def pop_left(self):
        """Remove and return the oldest message.

        Raises:
            IndexError: if the history is empty.
        """
        # Previously this called self.pop_left() and recursed forever.
        return self.input_history.popleft()

    def update(self, content: str):
        """Replace the newest message with one carrying *content* and the same role.

        The replacement is a fresh dict, re-tokenized.

        Raises:
            IndexError: if the history is empty.
        """
        role = self.input_history[-1]["role"]
        # Tokenize before replacing so a tokenizer error leaves the history intact.
        self.input_history[-1] = self._tokenize({"role": role, "content": content})

    def __len__(self):
        return len(self.input_history)

    def __str__(self):
        return str(self.input_history)

    def __copy__(self):
        """Shallow copy: new deque, same message dicts, same tokenizer."""
        new_instance = type(self)(self.tokenizer, [])
        new_instance.input_history = copy.copy(self.input_history)
        return new_instance

    def __deepcopy__(self, memodict=None):
        """Deep copy the messages; the tokenizer is shared, not copied."""
        if memodict is None:
            memodict = {}
        new_instance = type(self)(self.tokenizer, [])
        # Register early and forward the memo so shared/cyclic references are
        # copied once, consistently with the surrounding deepcopy call.
        memodict[id(self)] = new_instance
        new_instance.input_history = copy.deepcopy(self.input_history, memodict)
        return new_instance


class TelechatIterTextStreamer:
    """
    Iterable streamer that turns generated token ids into text chunks.

    Written with reference to ``transformers``' TextIteratorStreamer. A producer
    calls :meth:`put` with newly generated token ids and :meth:`end` when done;
    a consumer iterates the streamer and receives ``(text_chunk, history)``
    tuples. On every emitted chunk the reply accumulated so far is written into
    the newest message of ``history`` via ``history.update``. The same history
    object is yielded every time, so it reflects the latest state, not a snapshot.
    """

    def __init__(
            self, tokenizer, history: Optional["History"] = None, skip_prompt: bool = False,
            timeout: Optional[float] = None, **decode_kwargs
    ):
        """
        Args:
            tokenizer: must provide ``decode(ids, **decode_kwargs)`` and
                ``eos_token_id``; when *history* is omitted it is also used as
                the new ``History``'s tokenizer.
            history: history to which an empty ``{"role": "bot"}`` message is
                appended and then kept updated with the streamed reply. A new
                empty ``History`` is created when ``None``.
            skip_prompt: if True, the first :meth:`put` call is ignored
                (presumably the prompt ids — confirm with the producer).
            timeout: seconds passed to the queue's ``put``/``get``; ``None`` blocks.
            **decode_kwargs: forwarded to ``tokenizer.decode``.
        """
        self.tokenizer = tokenizer
        # The default of None previously crashed on the append below.
        self.history = history if history is not None else History(tokenizer, [])
        self.skip_prompt = skip_prompt
        self.timeout = timeout
        self.decode_kwargs = decode_kwargs

        self.text_queue = Queue()
        self.token_cache = []        # ids not yet emitted as text
        self.cache_time = 0          # put() calls since the cache was last flushed
        self.text_until = ""         # full reply text emitted so far
        self.stop_signal = None      # sentinel queued after the last chunk
        self.next_tokens_are_prompt = True

        self.history.append({"role": "bot", "content": self.text_until})

    def put(self, value):
        """
        Accept newly generated token ids and queue any text that decodes cleanly.

        Args:
            value: 1-D ids, or 2-D with a leading batch dimension of 1; must
                support ``.shape``, indexing and ``.tolist()`` (e.g. a tensor).

        Raises:
            ValueError: if the batch dimension is larger than 1.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Nothing to decode; also avoids an IndexError on value[-1].
        if len(value) == 0:
            return

        # A chunk whose last id is EOS is ignored entirely.
        if value[-1] == self.tokenizer.eos_token_id:
            return

        # there may be some smart way to decode.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
        self.cache_time += 1

        # A U+FFFD in the decoded text (typically a partial multi-byte
        # character) means keep buffering; after 6 attempts flush anyway so
        # output cannot stall indefinitely.
        if not (self._is_printable(text) or self.cache_time >= 6):
            return

        self.text_until += text
        self.token_cache = []
        self.cache_time = 0
        self.on_finalized_text(text)

    def end(self):
        """Flush any cached ids, queue the final chunk and the stop signal, then reset state.

        Nothing is printed; the final chunk is ``""`` when the cache is empty.
        """
        text = ""
        if self.token_cache:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            self.text_until += text
        self.on_finalized_text(text, stream_end=True)
        self.clear_cache()

    def clear_cache(self):
        """Reset decoding state and drop the reference to the history.

        After this, :meth:`put` would fail on ``self.history`` until a new
        history is assigned.
        """
        self.cache_time = 0
        self.token_cache = []
        self.text_until = ""
        self.history = None
        self.next_tokens_are_prompt = True

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Write the accumulated reply into the history and queue ``(text, history)``.

        When *stream_end* is True, ``(stop_signal, history)`` is queued afterwards.
        """
        self.history.update(self.text_until)
        self.text_queue.put((text, self.history), timeout=self.timeout)
        if stream_end:
            self.text_queue.put((self.stop_signal, self.history), timeout=self.timeout)

    @staticmethod
    def _is_printable(cp):
        """Return False if the decoded text contains the U+FFFD replacement character."""
        return "\ufffd" not in cp

    def __iter__(self):
        return self

    def __next__(self):
        """Block for the next ``(text, history)`` tuple; stop at the stop signal."""
        value_now, history_until = self.text_queue.get(timeout=self.timeout)
        if value_now is self.stop_signal:
            raise StopIteration()
        return value_now, history_until