mgoin committed on
Commit
6fa11ce
·
verified ·
1 Parent(s): 05a721a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +263 -0
app.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import deepsparse
2
+ import gradio as gr
3
+ from typing import Tuple, List
4
+
5
# Report the host CPU's capabilities on startup (printed by DeepSparse itself).
deepsparse.cpu.print_hardware_capability()

# Model stub handed to the pipeline as ``model_path`` and shown in the page header.
MODEL_ID = "hf:neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat-quant-ds"

# Markdown rendered at the top of the demo page.
DESCRIPTION = f"""
# LLM Chat on CPU with DeepSparse
The model stub for this example is: {MODEL_ID}

#### Accelerated Inference on CPUs
The Llama 2 model runs purely on CPU courtesy of [sparse software execution by DeepSparse](https://github.com/neuralmagic/deepsparse).
DeepSparse provides accelerated inference by taking advantage of the model's weight sparsity to deliver tokens fast!
"""

# Upper bound and initial value for the "Max new tokens" slider; the upper
# bound is also used as the pipeline's sequence length below.
MAX_MAX_NEW_TOKENS = 1024
DEFAULT_MAX_NEW_TOKENS = 200

# Engine options gathered in one place, then expanded into the factory call.
_engine_kwargs = {
    "task": "text-generation",
    "model_path": MODEL_ID,
    "sequence_length": MAX_MAX_NEW_TOKENS,
    "prompt_sequence_length": 16,
    "num_cores": 8,
}

# Build the text-generation pipeline once at import time; every request reuses it.
pipe = deepsparse.Pipeline.create(**_engine_kwargs)
29
+
30
+
31
def clear_and_save_textbox(message: str) -> Tuple[str, str]:
    """Blank the input box while keeping hold of what was typed.

    Args:
        message: Text currently in the textbox.

    Returns:
        A pair ``(new_textbox_value, saved_message)`` where the first item is
        always the empty string and the second is ``message`` unchanged.
    """
    cleared_textbox = ""
    return cleared_textbox, message
33
+
34
+
35
def display_input(
    message: str, history: List[Tuple[str, str]]
) -> List[Tuple[str, str]]:
    """Add the user's message to the chat history as a new, unanswered turn.

    Args:
        message: The user's message.
        history: Chat turns so far, each a ``(user, assistant)`` pair. The
            list is modified in place.

    Returns:
        The same ``history`` object, now ending with ``(message, "")``.
    """
    pending_turn = (message, "")  # empty assistant reply, to be filled in later
    history.append(pending_turn)
    return history
40
+
41
+
42
def delete_prev_fn(history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
    """Drop the most recent chat turn and hand back its user message.

    Args:
        history: Chat turns so far, each a ``(user, assistant)`` pair. The
            list is modified in place.

    Returns:
        ``(history, message)``: the same list with its last turn removed, and
        that turn's user message. The message is ``""`` when the history was
        already empty or the stored user message was falsy (e.g. ``None``).
    """
    try:
        last_turn = history.pop()
    except IndexError:
        # Nothing to remove: leave the history untouched.
        return history, ""

    user_message, _ = last_turn
    if not user_message:
        user_message = ""
    return history, user_message
48
+
49
+
50
# Build the chat UI. Components and event listeners are created inside the
# Blocks context so they are registered on ``demo``.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    # Conversation view with the message box and submit button beneath it.
    with gr.Group():
        chatbot = gr.Chatbot(label="Chatbot")
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder="Type a message...",
                scale=10,
            )
            submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)

    with gr.Row():
        retry_button = gr.Button("🔄 Retry", variant="secondary")
        undo_button = gr.Button("↩️ Undo", variant="secondary")
        clear_button = gr.Button("🗑️ Clear", variant="secondary")

    # Holds the last submitted message after the textbox has been cleared.
    saved_input = gr.State()

    gr.Examples(
        examples=["Write a story about sparse neurons."],
        inputs=[textbox],
    )

    # Sampling controls; their values are passed to ``generate`` on every request.
    max_new_tokens = gr.Slider(
        label="Max new tokens",
        value=DEFAULT_MAX_NEW_TOKENS,
        minimum=0,
        maximum=MAX_MAX_NEW_TOKENS,
        step=1,
        interactive=True,
        info="The maximum numbers of new tokens",
    )
    temperature = gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    )
    top_p = gr.Slider(
        label="Top-p (nucleus) sampling",
        value=0.40,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    )
    top_k = gr.Slider(
        label="Top-k sampling",
        value=20,
        minimum=1,
        maximum=100,
        step=1,
        interactive=True,
        info="Sample from the top_k most likely tokens",
    )
    reptition_penalty = gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )

    # Generation inference
    def generate(
        message,
        history,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        reptition_penalty: float,
    ):
        """Stream the model's reply to ``message`` into the last chat turn.

        Only ``message`` is sent to the model; earlier turns in ``history``
        are not included in the prompt.

        Args:
            message: The user's message to answer.
            history: Chat turns so far. The last turn's second element is
                extended in place with each streamed piece of text.
                NOTE(review): this requires the turns to be mutable pairs
                (lists); ``display_input`` appends a tuple, so this presumably
                relies on Gradio handing the value back as lists — confirm.
            max_new_tokens: Value of the "Max new tokens" slider.
            temperature: Value of the "Temperature" slider.
            top_p: Value of the "Top-p" slider.
            top_k: Value of the "Top-k" slider.
            reptition_penalty: Value of the "Repetition penalty" slider
                (parameter name kept as-is for compatibility).

        Yields:
            ``history`` after each streamed piece of text has been appended.
        """
        generation_config = {
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            # Spelled "repetition_penalty" so it matches the option name the
            # generation config defines; the previous misspelled key would not
            # be applied as the penalty.
            "repetition_penalty": reptition_penalty,
        }

        # Single-turn prompt: just the current user message.
        conversation = [{"role": "user", "content": message}]

        formatted_conversation = pipe.tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

        inference = pipe(
            sequences=formatted_conversation,
            generation_config=generation_config,
            streaming=True,
        )

        for token in inference:
            history[-1][1] += token.generations[0].text
            yield history

        # Runs once the stream is exhausted; presumably prints timing statistics.
        print(pipe.timer_manager)

    # Inputs shared by every listener that triggers generation.
    generation_inputs = [
        saved_input,
        chatbot,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
        reptition_penalty,
    ]

    # Hooking up all the buttons
    # Enter in the textbox: clear the box, show the message, then generate.
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).success(
        generate,
        inputs=generation_inputs,
        outputs=[chatbot],
        api_name=False,
    )

    # Submit button: same chain as pressing Enter.
    submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).success(
        generate,
        inputs=generation_inputs,
        outputs=[chatbot],
        api_name=False,
    )

    # Retry: remove the last turn, re-add its user message, and generate again.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        generate,
        inputs=generation_inputs,
        outputs=[chatbot],
        api_name=False,
    )

    # Undo: remove the last turn and put its user message back in the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: empty the conversation and the saved message.
    clear_button.click(
        fn=lambda: ([], ""),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

# Enable request queuing (needed for the streaming generator) and start the app.
demo.queue().launch(share=True)