from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only needed for the (string) type annotation below, so the heavy
    # `transformers` import is not paid at runtime.
    from transformers import AutoTokenizer


class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
    """

    def put(self, value):
        """Function that is called by `.generate()` to push new tokens"""
        raise NotImplementedError()

    def end(self):
        """Function that is called by `.generate()` to signal the end of generation"""
        raise NotImplementedError()


class ByteStreamer(BaseStreamer):
    """
    Simple text streamer that prints newly decoded text to stdout as soon as it is available.

    Every call to `put` re-decodes the cached tokens and emits the part of the decoded text that has not been
    emitted yet. Trailing U+FFFD replacement characters (what a decoder typically produces for a multi-byte
    character whose bytes have not all arrived) are held back until they resolve or the stream ends.

    The API for the streamer classes is still under development and may change in the future.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer

        >>> # `ByteStreamer` is the class defined in this module.
        >>> tok = AutoTokenizer.from_pretrained("gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = ByteStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    # Character emitted by decoders using `errors="replace"` for an incomplete/invalid byte sequence.
    _REPLACEMENT_CHAR = "\ufffd"

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []  # token ids received since the last flush
        self.print_len = 0  # number of characters of the decoded cache already emitted
        self.next_tokens_are_prompt = True  # the first `put` after construction / `end` is the prompt

    def put(self, value):
        """
        Receives tokens, decodes them, and emits the not-yet-emitted part of the decoded text.

        Args:
            value: Token ids exposing `.shape` and `.tolist()` (e.g. a tensor). Either 1-D, or 2-D with a
                batch dimension of exactly 1.

        Raises:
            ValueError: If `value` has a batch dimension larger than 1.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("ByteStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # The first push of a generation is the prompt; drop it when requested.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token(s) to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        if text.endswith("\n"):
            # After the symbol for a new line, we flush the cache.
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            # Emit everything new, but hold back trailing replacement characters: they usually mark a
            # multi-byte character that is still incomplete and will decode differently once the
            # remaining bytes arrive.
            stable_text = text.rstrip(self._REPLACEMENT_CHAR)
            printable_text = stable_text[self.print_len :]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        if self.token_cache:
            # Flush whatever is left, including any text that `put` was holding back.
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        # The next `put` call belongs to a new generation and therefore starts with a prompt.
        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        # `end=None` makes `print` fall back to its default line terminator ("\n").
        print(text, flush=True, end="" if not stream_end else None)