ubermenchh committed
Commit
72bf0de
1 Parent(s): eee08ae

Update app.py

Files changed (1)
app.py  +234 -59
app.py CHANGED
@@ -1,69 +1,244 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from threading import Thread
 import gradio as gr
-import torch
-
-MAX_INPUT_TOKEN_LENGTH = 4096
 
 model_id = 'HuggingFaceH4/zephyr-7b-beta'
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map='auto')
-tokenizer = AutoTokenizer.from_pretrained(model_id)
-tokenizer.use_default_system_prompt = False
-
-def generate(input, chat_history=[], system_prompt=False, max_new_tokens=512, temperature=0.5, top_p=0.95, top_k=50, repetition_penalty=1.2):
-    conversation = []
-    if system_prompt:
-        conversation.append({
-            'role': 'system',
-            'content': system_prompt
-        })
-    for user, assistant in chat_history:
-        conversation.extend({
-            'role': 'user',
-            'content': user
-        },
-        {
-            'role': 'assistant',
-            'content': assistant
-        })
-    conversation.append({
-        'role': 'user',
-        'content': input
-    })
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors='pt')
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-        {'input_ids': input_ids},
-        streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty
     )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield ''.join(outputs)
-
-chat_interface = gr.ChatInterface(
-    fn=generate,
-    examples=[
-        'What is GPT?',
-        'What is Life?',
-        'Who is Alan Turing'
-    ]
-)
 
-chat_interface.queue(max_size=20).launch()

+import os
+from typing import Iterator
+from text_generation import Client
 import gradio as gr
 
 model_id = 'HuggingFaceH4/zephyr-7b-beta'
+
+API_URL = "https://api-inference.huggingface.co/models/" + model_id
+HF_TOKEN = os.environ.get('HF_READ_TOKEN', None)
+
+client = Client(
+    API_URL,
+    headers={'Authorization': f"Bearer {HF_TOKEN}"}
+)
+EOS_STRING = "</s>"
+EOT_STRING = "<EOT>"
+
+def get_prompt(message, chat_history, system_prompt):
+    texts = [f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"]
+
+    do_strip = False
+    for user_input, response in chat_history:
+        user_input = user_input.strip() if do_strip else user_input
+        do_strip = True
+        texts.append(f"{user_input} [/INST] {response.strip()} </s><s>[INST] ")
+    message = message.strip() if do_strip else message
+    texts.append(f"{message} [/INST]")
+    return ''.join(texts)
+
+def run(message, chat_history, system_prompt, max_new_tokens=1024, temperature=0.1, top_p=0.9, top_k=50):
+    prompt = get_prompt(message, chat_history, system_prompt)
+
     generate_kwargs = dict(
         max_new_tokens=max_new_tokens,
         do_sample=True,
         top_p=top_p,
         top_k=top_k,
+        temperature=temperature
     )
+    stream = client.generate_stream(prompt, **generate_kwargs)
+    output = ''
+    for response in stream:
+        if any([end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]]):
+            return output
+        else:
+            output += response.token.text
+            yield output
+    return output
+
+
+DEFAULT_SYSTEM_PROMPT = """
+You are Zephyr. You are an AI assistant, you are moderately polite and give only true information.
+You carefully provide accurate, factual, thoughtful, nuanced answers, and are brilliant at reasoning.
+If you think there might not be a correct answer, you say so. Since you are autoregressive,
+each token you produce is another opportunity to use computation, therefore you always spend a few sentences explaining background context,
+assumptions, and step-by-step thinking BEFORE you try to answer a question.
+"""
+MAX_MAX_NEW_TOKENS = 4096
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = 4096
+
+DESCRIPTION = """
+# Zephyr-7b ChatBot
+"""
+
+def clear_and_save_textbox(message): return '', message
+
+def display_input(message, history=[]):
+    history.append((message, ''))
+    return history
+
+def delete_prev_fn(history=[]):
+    try:
+        message, _ = history.pop()
+    except IndexError:
+        message = ''
+    return history, message or ''
+
+def generate(message, history_with_input, system_prompt, max_new_tokens, temperature, top_p, top_k):
+    if max_new_tokens > MAX_MAX_NEW_TOKENS:
+        raise ValueError
+
+    history = history_with_input[:-1]
+    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
+    try:
+        first_response = next(generator)
+        yield history + [(message, first_response)]
+    except StopIteration:
+        yield history + [(message, '')]
+    for response in generator:
+        yield history + [(message, response)]
+
+def process_example(message):
+    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
+    for x in generator:
+        pass
+    return '', x
+
+def check_input_token_length(message, chat_history, system_prompt):
+    input_token_length = len(message) + len(chat_history)
+    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+        raise gr.Error(f"The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.")
+
+with gr.Blocks() as demo:
+    gr.Markdown(DESCRIPTION)
+
+    with gr.Group():
+        # Chat display referenced by the event handlers below.
+        chatbot = gr.Chatbot(label='Chatbot')
+        with gr.Row():
+            textbox = gr.Textbox(
+                container=False,
+                show_label=False,
+                placeholder='Hi, Zephyr',
+                scale=10
+            )
+            submit_button = gr.Button('Submit', variant='primary', scale=1, min_width=0)
+
+    with gr.Row():
+        retry_button = gr.Button('Retry', variant='secondary')
+        undo_button = gr.Button('Undo', variant='secondary')
+        clear_button = gr.Button('Clear', variant='secondary')
+
+    saved_input = gr.State()
+
+    with gr.Accordion(label='Advanced options', open=False):
+        system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=5, interactive=False)
+        max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+        temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
+        top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+        top_k = gr.Slider(label='Top-K', minimum=1, maximum=1000, step=1, value=10)
+
+    textbox.submit(
+        fn=clear_and_save_textbox,
+        inputs=textbox,
+        outputs=[textbox, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=check_input_token_length,
+        inputs=[saved_input, chatbot, system_prompt],
+        api_name=False,
+        queue=False,
+    ).success(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    button_event_preprocess = submit_button.click(
+        fn=clear_and_save_textbox,
+        inputs=textbox,
+        outputs=[textbox, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=check_input_token_length,
+        inputs=[saved_input, chatbot, system_prompt],
+        api_name=False,
+        queue=False,
+    ).success(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    retry_button.click(
+        fn=delete_prev_fn,
+        inputs=chatbot,
+        outputs=[chatbot, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    undo_button.click(
+        fn=delete_prev_fn,
+        inputs=chatbot,
+        outputs=[chatbot, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=lambda x: x,
+        inputs=[saved_input],
+        outputs=textbox,
+        api_name=False,
+        queue=False,
+    )
+
+    clear_button.click(
+        fn=lambda: ([], ''),
+        outputs=[chatbot, saved_input],
+        queue=False,
+        api_name=False,
+    )
+
+
+demo.queue(max_size=32).launch(show_api=False)
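
For quick verification outside Gradio, a minimal standalone sketch of the streaming path this commit introduces is shown below. It is not part of the commit: the model id, Inference API endpoint, stop markers, and prompt format are taken from the diff above, while the HF_READ_TOKEN lookup and the sample question are illustrative assumptions.

# Minimal sketch (assumptions: HF_READ_TOKEN is set; the sample prompt is illustrative).
import os
from text_generation import Client

model_id = 'HuggingFaceH4/zephyr-7b-beta'
API_URL = "https://api-inference.huggingface.co/models/" + model_id
client = Client(API_URL, headers={'Authorization': f"Bearer {os.environ.get('HF_READ_TOKEN')}"})

# Llama/Zephyr-style instruction prompt, matching what get_prompt() builds for an empty chat history.
prompt = "<s>[INST] <<SYS>>\nYou are Zephyr, a helpful assistant.\n<</SYS>>\n\nWhat is GPT? [/INST]"

output = ''
for response in client.generate_stream(prompt, max_new_tokens=256, do_sample=True,
                                        temperature=0.1, top_p=0.9, top_k=50):
    token_text = response.token.text
    if "</s>" in token_text or "<EOT>" in token_text:  # same stop markers run() checks for
        break
    output += token_text
print(output)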