mwitiderrick commited on
Commit
be66a58
β€’
1 Parent(s): ae84c21

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -0
app.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import deepsparse
2
+ import gradio as gr
3
+ from typing import Tuple, List
4
+
5
+ deepsparse.cpu.print_hardware_capability()
6
+
7
+ MODEL_ID = "mgoin/TinyStories-1M-deepsparse"
8
+
9
+ DESCRIPTION = f"""
10
+ # MPT Sparse Finetuned on GSM8k with DeepSparse
11
+ ![NM Logo](https://files.slack.com/files-pri/T020WGRLR8A-F05TXD28BBK/neuralmagic-logo.png?pub_secret=54e8db19db)
12
+ Model ID: {MODEL_ID}
13
+ **πŸš€ Experience the power of LLM mathematical reasoning** through our MPT sparse finetuned on the [GSM8K dataset](https://huggingface.co/datasets/gsm8k).
14
+ GSM8K, short for Grade School Math 8K, is a collection of 8.5K high-quality linguistically diverse grade school math word problems, designed to challenge question-answering systems with multi-step reasoning.
15
+ Observe the model's performance in deciphering complex math questions, such as "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?" and offering detailed step-by-step solutions.
16
+ ## Accelerated Inferenced on CPUs
17
+ The MPT model runs purely on CPU courtesy of software acceleration by DeepSparse. DeepSparse offers accelerated inference by taking advantage of the MPT model's sparse sparsity
18
+ hence delivering results fast.
19
+ ![Speed](https://files.slack.com/files-pri/T020WGRLR8A-F0605DZ0B7G/image3.png?pub_secret=ab0f1d72b6)
20
+ """
21
+ from huggingface_hub import snapshot_download
22
+ MODEL_ID = snapshot_download(repo_id=MODEL_ID, use_auth_token="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")
23
+
24
+ MAX_MAX_NEW_TOKENS = 1024
25
+ DEFAULT_MAX_NEW_TOKENS = 200
26
+
27
+ # Setup the engine
28
+ pipe = deepsparse.Pipeline.create(
29
+ task="text-generation",
30
+ model_path=MODEL_ID,
31
+ sequence_length=MAX_MAX_NEW_TOKENS,
32
+ prompt_sequence_length=16,
33
+ )
34
+
35
+ return "", message
36
+
37
+
38
+ def display_input(
39
+ message: str, history: List[Tuple[str, str]]
40
+ ) -> List[Tuple[str, str]]:
41
+ history.append((message, ""))
42
+ return history
43
+
44
+
45
+ def delete_prev_fn(history: List[Tuple[str, str]]) -> Tuple[List[Tuple[str, str]], str]:
46
+ try:
47
+ message, _ = history.pop()
48
+ except IndexError:
49
+ message = ""
50
+ return history, message or ""
51
+
52
+ with gr.Blocks() as demo:
53
+ with gr.Row():
54
+ with gr.Column():
55
+ gr.Markdown(DESCRIPTION)
56
+ with gr.Column():
57
+ gr.Markdown("""### MPT Sparse Finetuned Demo""")
58
+
59
+ with gr.Group():
60
+ chatbot = gr.Chatbot(label="Chatbot")
61
+ with gr.Row():
62
+ textbox = gr.Textbox(container=False,placeholder="Type a message...",scale=10,)
63
+ submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)
64
+
65
+ with gr.Row():
66
+ retry_button = gr.Button("πŸ”„ Retry", variant="secondary")
67
+ undo_button = gr.Button("↩️ Undo", variant="secondary")
68
+ clear_button = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
69
+
70
+ saved_input = gr.State()
71
+
72
+ gr.Examples(examples=[
73
+ "James decides to run 3 sprints 3 times a week. He runs 60 meters each sprint. How many total meters does he run a week?",
74
+ "Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
75
+ "Gretchen has 110 coins. There are 30 more gold coins than silver coins. How many gold coins does Gretchen have?",],inputs=[textbox],)
76
+
77
+ max_new_tokens = gr.Slider(
78
+ label="Max new tokens",
79
+ value=DEFAULT_MAX_NEW_TOKENS,
80
+ minimum=0,
81
+ maximum=MAX_MAX_NEW_TOKENS,
82
+ step=1,
83
+ interactive=True,
84
+ info="The maximum numbers of new tokens",)
85
+ temperature = gr.Slider(
86
+ label="Temperature",
87
+ value=0.3,
88
+ minimum=0.05,
89
+ maximum=1.0,
90
+ step=0.05,
91
+ interactive=True,
92
+ info="Higher values produce more diverse outputs",
93
+ )
94
+ top_p = gr.Slider(
95
+ label="Top-p (nucleus) sampling",
96
+ value=0.40,
97
+ minimum=0.0,
98
+ maximum=1,
99
+ step=0.05,
100
+ interactive=True,
101
+ info="Higher values sample more low-probability tokens",
102
+ )
103
+ top_k = gr.Slider(
104
+ label="Top-k sampling",
105
+ value=20,
106
+ minimum=1,
107
+ maximum=100,
108
+ step=1,
109
+ interactive=True,
110
+ info="Sample from the top_k most likely tokens",
111
+ )
112
+ repetition_penalty = gr.Slider(
113
+ label="Repetition penalty",
114
+ value=1.2,
115
+ minimum=1.0,
116
+ maximum=2.0,
117
+ step=0.05,
118
+ interactive=True,
119
+ info="Penalize repeated tokens",
120
+ )
121
+
122
+ # Generation inference
123
+ def generate(
124
+ message,
125
+ history,
126
+ max_new_tokens: int,
127
+ temperature: float,
128
+ top_p: float,
129
+ top_k: int,
130
+ repetition_penalty: float,
131
+ ):
132
+ generation_config = { "max_new_tokens": max_new_tokens,"temperature": temperature,"top_p": top_p,"top_k": top_k,"repetition_penalty": repetition_penalty,}
133
+ inference = pipe(sequences=message, streaming=True, **generation_config
134
+ history[-1][1] += message
135
+ for token in inference:
136
+ history[-1][1] += token.generations[0].text
137
+ yield history
138
+ print(pipe.timer_manager)
139
+ textbox.submit(
140
+ fn=clear_and_save_textbox,
141
+ inputs=textbox,
142
+ outputs=[textbox, saved_input],
143
+ api_name=False,
144
+ queue=False,
145
+ ).then(
146
+ fn=display_input,
147
+ inputs=[saved_input, chatbot],
148
+ outputs=chatbot,
149
+ api_name=False,
150
+ queue=False,
151
+ ).success(
152
+ generate,
153
+ inputs=[
154
+ saved_input,
155
+ chatbot,
156
+ max_new_tokens,
157
+ temperature,
158
+ top_p,
159
+ top_k,
160
+ repetition_penalty,
161
+ ],
162
+ outputs=[chatbot],
163
+ api_name=False,
164
+ )
165
+
166
+ submit_button.click(
167
+ fn=clear_and_save_textbox,
168
+ inputs=textbox,
169
+ outputs=[textbox, saved_input],
170
+ api_name=False,
171
+ queue=False,
172
+ ).then(
173
+ fn=display_input,
174
+ inputs=[saved_input, chatbot],
175
+ outputs=chatbot,
176
+ api_name=False,
177
+ queue=False,
178
+ ).success(
179
+ generate,
180
+ inputs=[saved_input, chatbot, max_new_tokens, temperature],
181
+ outputs=[chatbot],
182
+ api_name=False,
183
+ )
184
+
185
+ retry_button.click(
186
+ fn=delete_prev_fn,
187
+ inputs=chatbot,
188
+ outputs=[chatbot, saved_input],
189
+ api_name=False,
190
+ queue=False,
191
+ ).then(
192
+ fn=display_input,
193
+ inputs=[saved_input, chatbot],
194
+ outputs=chatbot,
195
+ api_name=False,
196
+ queue=False,
197
+ ).then(
198
+ generate,
199
+ inputs=[saved_input, chatbot, max_new_tokens, temperature],
200
+ outputs=[chatbot],
201
+ api_name=False,
202
+ )
203
+ undo_button.click(
204
+ fn=delete_prev_fn,
205
+ inputs=chatbot,
206
+ outputs=[chatbot, saved_input],
207
+ api_name=False,
208
+ queue=False,
209
+ ).then(
210
+ fn=lambda x: x,
211
+ inputs=[saved_input],
212
+ outputs=textbox,
213
+ api_name=False,
214
+ queue=False,
215
+ )
216
+ clear_button.click(
217
+ fn=lambda: ([], ""),
218
+ outputs=[chatbot, saved_input],
219
+ queue=False,
220
+ api_name=False,
221
+ )
222
+
223
+
224
+
225
+
226
+ demo.queue().launch()
227
+