Spaces:
Sleeping
Sleeping
alonsosilva
commited on
Commit
•
38c7b49
1
Parent(s):
dbf4a45
Use TextIteratorStreamer instead of custom Streamer
Browse files
app.py
CHANGED
@@ -1,149 +1,9 @@
|
|
1 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
2 |
|
3 |
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
|
4 |
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
|
5 |
-
|
6 |
-
class BaseStreamer:
|
7 |
-
"""
|
8 |
-
Base class from which `.generate()` streamers should inherit.
|
9 |
-
"""
|
10 |
-
|
11 |
-
def put(self, value):
|
12 |
-
"""Function that is called by `.generate()` to push new tokens"""
|
13 |
-
raise NotImplementedError()
|
14 |
-
|
15 |
-
def end(self):
|
16 |
-
"""Function that is called by `.generate()` to signal the end of generation"""
|
17 |
-
raise NotImplementedError()
|
18 |
-
|
19 |
-
class TextStreamer(BaseStreamer):
|
20 |
-
"""
|
21 |
-
Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.
|
22 |
-
|
23 |
-
<Tip warning={true}>
|
24 |
-
|
25 |
-
The API for the streamer classes is still under development and may change in the future.
|
26 |
-
|
27 |
-
</Tip>
|
28 |
-
|
29 |
-
Parameters:
|
30 |
-
tokenizer (`AutoTokenizer`):
|
31 |
-
The tokenized used to decode the tokens.
|
32 |
-
skip_prompt (`bool`, *optional*, defaults to `False`):
|
33 |
-
Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
|
34 |
-
decode_kwargs (`dict`, *optional*):
|
35 |
-
Additional keyword arguments to pass to the tokenizer's `decode` method.
|
36 |
-
|
37 |
-
Examples:
|
38 |
-
|
39 |
-
```python
|
40 |
-
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
|
41 |
-
|
42 |
-
>>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
|
43 |
-
>>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
|
44 |
-
>>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
|
45 |
-
>>> streamer = TextStreamer(tok)
|
46 |
-
|
47 |
-
>>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
|
48 |
-
>>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
|
49 |
-
An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
|
50 |
-
```
|
51 |
-
"""
|
52 |
-
|
53 |
-
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
|
54 |
-
self.tokenizer = tokenizer
|
55 |
-
self.skip_prompt = skip_prompt
|
56 |
-
self.decode_kwargs = decode_kwargs
|
57 |
-
|
58 |
-
# variables used in the streaming process
|
59 |
-
self.token_cache = []
|
60 |
-
self.print_len = 0
|
61 |
-
self.next_tokens_are_prompt = True
|
62 |
-
|
63 |
-
def put(self, value):
|
64 |
-
"""
|
65 |
-
Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
|
66 |
-
"""
|
67 |
-
if len(value.shape) > 1 and value.shape[0] > 1:
|
68 |
-
raise ValueError("TextStreamer only supports batch size 1")
|
69 |
-
elif len(value.shape) > 1:
|
70 |
-
value = value[0]
|
71 |
-
|
72 |
-
if self.skip_prompt and self.next_tokens_are_prompt:
|
73 |
-
self.next_tokens_are_prompt = False
|
74 |
-
return
|
75 |
-
|
76 |
-
# Add the new token to the cache and decodes the entire thing.
|
77 |
-
self.token_cache.extend(value.tolist())
|
78 |
-
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
|
79 |
-
|
80 |
-
# After the symbol for a new line, we flush the cache.
|
81 |
-
if text.endswith("\n"):
|
82 |
-
printable_text = text[self.print_len :]
|
83 |
-
self.token_cache = []
|
84 |
-
self.print_len = 0
|
85 |
-
# If the last token is a CJK character, we print the characters.
|
86 |
-
elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
|
87 |
-
printable_text = text[self.print_len :]
|
88 |
-
self.print_len += len(printable_text)
|
89 |
-
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
|
90 |
-
# which may change with the subsequent token -- there are probably smarter ways to do this!)
|
91 |
-
else:
|
92 |
-
printable_text = text[self.print_len : text.rfind(" ") + 1]
|
93 |
-
self.print_len += len(printable_text)
|
94 |
-
|
95 |
-
self.on_finalized_text(printable_text)
|
96 |
-
|
97 |
-
def end(self):
|
98 |
-
"""Flushes any remaining cache and prints a newline to stdout."""
|
99 |
-
# Flush the cache, if it exists
|
100 |
-
if len(self.token_cache) > 0:
|
101 |
-
text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
|
102 |
-
printable_text = text[self.print_len :]
|
103 |
-
self.token_cache = []
|
104 |
-
self.print_len = 0
|
105 |
-
else:
|
106 |
-
printable_text = ""
|
107 |
-
|
108 |
-
self.next_tokens_are_prompt = True
|
109 |
-
self.on_finalized_text(printable_text, stream_end=True)
|
110 |
-
|
111 |
-
def on_finalized_text(self, text: str, stream_end: bool = False):
|
112 |
-
"""Prints the new text to stdout. If the stream is ending, also prints a newline."""
|
113 |
-
# print(text, flush=True, end="" if not stream_end else None)
|
114 |
-
messages.value = [
|
115 |
-
*messages.value[:-1],
|
116 |
-
{
|
117 |
-
"role": "assistant",
|
118 |
-
"content": messages.value[-1]["content"] + text,
|
119 |
-
},
|
120 |
-
]
|
121 |
-
|
122 |
-
def _is_chinese_char(self, cp):
|
123 |
-
"""Checks whether CP is the codepoint of a CJK character."""
|
124 |
-
# This defines a "chinese character" as anything in the CJK Unicode block:
|
125 |
-
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
|
126 |
-
#
|
127 |
-
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
|
128 |
-
# despite its name. The modern Korean Hangul alphabet is a different block,
|
129 |
-
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
|
130 |
-
# space-separated words, so they are not treated specially and handled
|
131 |
-
# like the all of the other languages.
|
132 |
-
if (
|
133 |
-
(cp >= 0x4E00 and cp <= 0x9FFF)
|
134 |
-
or (cp >= 0x3400 and cp <= 0x4DBF) #
|
135 |
-
or (cp >= 0x20000 and cp <= 0x2A6DF) #
|
136 |
-
or (cp >= 0x2A700 and cp <= 0x2B73F) #
|
137 |
-
or (cp >= 0x2B740 and cp <= 0x2B81F) #
|
138 |
-
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
|
139 |
-
or (cp >= 0xF900 and cp <= 0xFAFF)
|
140 |
-
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
|
141 |
-
): #
|
142 |
-
return True
|
143 |
-
|
144 |
-
return False
|
145 |
-
|
146 |
-
streamer = TextStreamer(tokenizer, skip_prompt=True)
|
147 |
|
148 |
import re
|
149 |
import solara
|
@@ -176,7 +36,17 @@ def Page():
|
|
176 |
add_generation_prompt=True
|
177 |
)
|
178 |
inputs = tokenizer(text, return_tensors="pt")
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
def result():
|
181 |
if messages.value != []:
|
182 |
response(messages.value[-1]["content"])
|
|
|
1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
2 |
+
from threading import Thread
|
3 |
|
4 |
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
|
5 |
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
|
6 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
import re
|
9 |
import solara
|
|
|
36 |
add_generation_prompt=True
|
37 |
)
|
38 |
inputs = tokenizer(text, return_tensors="pt")
|
39 |
+
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=512)
|
40 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
41 |
+
thread.start()
|
42 |
+
for text in streamer:
|
43 |
+
messages.value = [
|
44 |
+
*messages.value[:-1],
|
45 |
+
{
|
46 |
+
"role": "assistant",
|
47 |
+
"content": messages.value[-1]["content"] + text,
|
48 |
+
},
|
49 |
+
]
|
50 |
def result():
|
51 |
if messages.value != []:
|
52 |
response(messages.value[-1]["content"])
|