from typing import Optional
from collections import deque
from queue import Queue
import copy


class History:
    """Conversation history stored as a deque of tokenized message dicts.

    Each message is a dict with at least a ``"content"`` entry. When a message
    is stored, the output of ``tokenizer(content)`` is merged into it (the
    streamer below relies on this providing ``"input_ids"`` and
    ``"attention_mask"`` entries).
    """

    def __init__(self, tokenizer, history):
        """Build the history from a list of message dicts.

        Args:
            tokenizer: Callable applied to a message's content; its result is
                merged into the message with ``dict.update``.
            history: Iterable of message dicts, or an empty/falsy value to
                start with an empty history.
        """
        # A deque is used so messages can be added/removed at both ends.
        self.input_history = deque()
        self.tokenizer = tokenizer
        if history:
            self._transfer_from_list(history)

    def _transfer_from_list(self, history):
        """Tokenize every message of ``history`` and append it in order.

        Note: the message dicts are updated in place.
        """
        for message in history:
            content = message.get("content")
            # Re-tokenizing the text may not reproduce the exact token ids the
            # model generated.
            message.update(self.tokenizer(content))
            self.input_history.append(message)

    def _ensure_tokenized(self, message):
        """Tokenize ``message`` in place unless it already carries token data."""
        if "input_ids" not in message or "attention_mask" not in message:
            message.update(self.tokenizer(message.get("content")))
        return message

    def append(self, message):
        """Append ``message`` on the right, tokenizing it if needed."""
        self.input_history.append(self._ensure_tokenized(message))

    def append_left(self, message):
        """Insert ``message`` on the left, tokenizing it if needed."""
        self.input_history.appendleft(self._ensure_tokenized(message))

    def pop(self):
        """Remove and return the right-most (latest) message."""
        return self.input_history.pop()

    def pop_left(self):
        """Remove and return the left-most (oldest) message."""
        # deque's method is spelled ``popleft``; ``pop_left`` does not exist.
        return self.input_history.popleft()

    def update(self, message):
        """Replace the latest message with ``message``.

        Raises:
            IndexError: If the history is empty.
        """
        self.input_history.pop()
        self.append(message)

    def __len__(self):
        return len(self.input_history)

    def __str__(self):
        return str(self.input_history)

    def __copy__(self):
        """Shallow copy: new deque, same message dicts, same tokenizer."""
        new_instance = type(self)(self.tokenizer, [])
        new_instance.input_history = copy.copy(self.input_history)
        return new_instance

    def __deepcopy__(self, memodict=None):
        """Deep copy of the messages; the tokenizer object is shared."""
        if memodict is None:
            memodict = {}
        new_instance = type(self)(self.tokenizer, [])
        # Register before recursing so reference cycles resolve to the copy.
        memodict[id(self)] = new_instance
        new_instance.input_history = copy.deepcopy(self.input_history, memodict)
        return new_instance


class TelechatIterTextStreamer:
    """Iterator-style text streamer that also keeps a ``History`` up to date.

    Written with reference to the iterator text streamer in transformers.
    Iterating over the streamer yields ``(new_text, history)`` tuples taken
    from an internal queue until the stop signal is received.
    """

    # Number of withheld chunks after which text is emitted even when it still
    # contains the replacement character.
    _MAX_CACHE_TIME = 6

    def __init__(
        self,
        tokenizer,
        history: Optional[History] = None,
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs
    ):
        """
        Args:
            tokenizer: Object providing ``decode`` and ``eos_token_id``; it is
                also called on message content by ``History``.
            history: History to extend with the bot reply. When ``None`` an
                empty history is created.
            skip_prompt: If true, the first value passed to ``put`` is dropped.
            timeout: Timeout (seconds) for queue operations; ``None`` blocks.
            **decode_kwargs: Forwarded to ``tokenizer.decode``.
        """
        self.tokenizer = tokenizer
        # Compare against None explicitly: an empty History is falsy because
        # it defines __len__.
        self.history = history if history is not None else History(tokenizer, [])
        self.skip_prompt = skip_prompt
        self.timeout = timeout
        self.decode_kwargs = decode_kwargs

        self.text_queue = Queue()
        self.cache_time = 0
        self.text_until = ""
        self.token_until = []
        self.stop_signal = None
        self.next_tokens_are_prompt = True

        # Placeholder bot message; replaced on every emitted chunk.
        self.history.append({"role": "bot", "content": self.text_until})

    def put(self, value):
        """Receive new token ids and queue any newly printable text.

        Args:
            value: Array-like with ``shape`` and ``tolist``; either 1-D or
                2-D with a batch dimension of 1.

        Raises:
            ValueError: If the batch dimension is larger than 1.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # A chunk ending with the EOS token is ignored entirely.
        if value[-1] == self.tokenizer.eos_token_id:
            return

        # All tokens so far are decoded again on every call.
        self.token_until.extend(value.tolist())
        text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)

        if not (self._is_printable(text) or self.cache_time >= self._MAX_CACHE_TIME):
            # Incomplete character: hold the text back for now.
            # cache_time is only reset by clear_cache().
            self.cache_time += 1
            return

        output_text = text[len(self.text_until):]
        self.text_until = text
        self.on_finalized_text(output_text)

    def end(self):
        """Flush the remaining text, queue the stop signal and reset state."""
        text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)
        output_text = text[len(self.text_until):]
        self.text_until = text
        self.on_finalized_text(output_text, stream_end=True)
        self.clear_cache()

    def clear_cache(self):
        """Reset the streaming state.

        The reference to the history is dropped as well, so a new ``History``
        has to be assigned to ``self.history`` before the streamer is used
        again.
        """
        self.cache_time = 0
        self.token_until = []
        self.text_until = ""
        self.history = None
        self.next_tokens_are_prompt = True

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Update the bot message in the history and queue ``(text, history)``.

        When ``stream_end`` is true, a ``(stop_signal, history)`` tuple is
        queued afterwards.
        """
        self.history.update({
            "role": "bot",
            "content": self.text_until,
            "input_ids": self.token_until,
            "attention_mask": [1] * len(self.token_until),
        })
        self.text_queue.put((text, self.history), timeout=self.timeout)
        if stream_end:
            self.text_queue.put((self.stop_signal, self.history), timeout=self.timeout)

    @staticmethod
    def _is_printable(cp):
        """Return False if the text contains the replacement character."""
        return "�" not in cp

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(text, history)`` tuple from the queue.

        Raises:
            StopIteration: When the stop signal is dequeued.
        """
        value_now, history_until = self.text_queue.get(timeout=self.timeout)
        # The stop signal is None, so test identity rather than equality.
        if value_now is self.stop_signal:
            raise StopIteration()
        return value_now, history_until