chatham84 committed
Commit f581acd
1 Parent(s): 8515885

Update app.py

Files changed (1)
  1. app.py +273 -109
app.py CHANGED
@@ -1,111 +1,275 @@
-import json
+from typing import Iterator
+
 import gradio as gr
-import os
-import requests
-
-hf_token = os.getenv('HF_TOKEN')
-api_url = os.getenv('API_URL')
-api_url_nostream = os.getenv('API_URL_NOSTREAM')
-headers = {
-    'Content-Type': 'application/json',
-}
-
-system_message = "\nYou are a helpful assistant who has a very narrow scope of knowledge: Medical Claims data. You have access to a medical claims database for Northern California. Do not answer questions you do not know. Respond exactly with '''I'm not trained in that area''' for any questions not related to claims data."
-title = "Vern SLM Bot"
-description = """
-Ask Vern Questions about Claims data..."""
-css = """.toast-wrap { display: none !important } """
-examples=[]
-
-
-def predict(message, chatbot):
-
-    input_prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "
-    for interaction in chatbot:
-        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s> [INST] "
-
-    input_prompt = input_prompt + str(message) + " [/INST] "
-
-    data = {
-        "inputs": input_prompt,
-        "parameters": {"max_new_tokens":256,
-                       "do_sample":True,
-                       "top_p":0.6,
-                       "temperature":0.9,}
-    }
-
-    response = requests.post(api_url, headers=headers, data=json.dumps(data), auth=('hf', hf_token), stream=True)
-
-    partial_message = ""
-    for line in response.iter_lines():
-        if line:  # filter out keep-alive new lines
-            # Decode from bytes to string
-            decoded_line = line.decode('utf-8')
-
-            # Remove 'data:' prefix
-            if decoded_line.startswith('data:'):
-                json_line = decoded_line[5:]  # Exclude the first 5 characters ('data:')
-            else:
-                gr.Warning(f"This line does not start with 'data:': {decoded_line}")
-                continue
-
-            # Load as JSON
-            try:
-                json_obj = json.loads(json_line)
-                if 'token' in json_obj:
-                    partial_message = partial_message + json_obj['token']['text']
-                    yield partial_message
-                elif 'error' in json_obj:
-                    yield json_obj['error'] + '. Please refresh and try again with an appropriate smaller input prompt.'
-                else:
-                    gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}")
-
-            except json.JSONDecodeError:
-                gr.Warning(f"This line is not valid JSON: {json_line}")
-                continue
-            except KeyError as e:
-                gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
-                continue
-
-
-def predict_batch(message, chatbot):
-
-    input_prompt = f"[INST]<<SYS>>\n{system_message}\n<</SYS>>\n\n "
-    for interaction in chatbot:
-        input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s> [INST] "
-
-    input_prompt = input_prompt + str(message) + " [/INST] "
-
-    data = {
-        "inputs": input_prompt,
-        "parameters": {"max_new_tokens":256}
-    }
-
-    response = requests.post(api_url_nostream, headers=headers, data=json.dumps(data), auth=('hf', hf_token))
-
-    if response.status_code == 200:  # check if the request was successful
-        try:
-            json_obj = response.json()
-            if 'generated_text' in json_obj and len(json_obj['generated_text']) > 0:
-                return json_obj['generated_text']
-            elif 'error' in json_obj:
-                return json_obj['error'] + ' Please refresh and try again with smaller input prompt'
-            else:
-                print(f"Unexpected response: {json_obj}")
-        except json.JSONDecodeError:
-            print(f"Failed to decode response as JSON: {response.text}")
-    else:
-        print(f"Request failed with status code {response.status_code}")
-
-
-# Gradio Demo
-with gr.Blocks() as demo:
-
-    with gr.Tab("Streaming"):
-        gr.ChatInterface(predict, title=title, description=description, css=css, examples=examples, cache_examples=True)
-
-    with gr.Tab("Batch"):
-        gr.ChatInterface(predict_batch, title=title, description=description, css=css, examples=examples, cache_examples=True)
-
-demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
+import torch
+
+from model import get_input_token_length, run
+
+DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful assistant who has a very narrow scope of knowledge: Medical Claims data. You have access to a medical claims database for Northern California. Do not answer questions you do not know. Respond exactly with '''I'm not trained in that area''' for any questions not related to claims data.\
+"""
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = 4000
+
+DESCRIPTION = """
+# Vern Bot
+
+Testing Vern Bot below
+
+"""
+
+LICENSE = """
+<p/>
+
+---
+As a derivative work of [Llama-2-13b-chat](https://huggingface.co/meta-llama/Llama-2-13b-chat) by Meta,
+this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-13b-chat/blob/main/USE_POLICY.md).
+"""
+
+if not torch.cuda.is_available():
+    DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
+
+
+def clear_and_save_textbox(message: str) -> tuple[str, str]:
+    return '', message
+
+
+def display_input(message: str,
+                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
+    history.append((message, ''))
+    return history
+
+
+def delete_prev_fn(
+        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
+    try:
+        message, _ = history.pop()
+    except IndexError:
+        message = ''
+    return history, message or ''
+
+
+def generate(
+    message: str,
+    history_with_input: list[tuple[str, str]],
+    system_prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    top_k: int,
+) -> Iterator[list[tuple[str, str]]]:
+    if max_new_tokens > MAX_MAX_NEW_TOKENS:
+        raise ValueError
+
+    history = history_with_input[:-1]
+    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
+    try:
+        first_response = next(generator)
+        yield history + [(message, first_response)]
+    except StopIteration:
+        yield history + [(message, '')]
+    for response in generator:
+        yield history + [(message, response)]
+
+
+def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
+    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
+    for x in generator:
+        pass
+    return '', x
+
+
+def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
+    input_token_length = get_input_token_length(message, chat_history, system_prompt)
+    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
+        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
+
+
+with gr.Blocks(css='style.css') as demo:
+    gr.Markdown(DESCRIPTION)
+    gr.DuplicateButton(value='',
+                       elem_id='')
+
+    with gr.Group():
+        chatbot = gr.Chatbot(label='Chatbot')
+        with gr.Row():
+            textbox = gr.Textbox(
+                container=False,
+                show_label=False,
+                placeholder='Type a message...',
+                scale=10,
+            )
+            submit_button = gr.Button('Submit',
+                                      variant='primary',
+                                      scale=1,
+                                      min_width=0)
+    with gr.Row():
+        retry_button = gr.Button('🔄 Retry', variant='secondary')
+        undo_button = gr.Button('↩️ Undo', variant='secondary')
+        clear_button = gr.Button('🗑️ Clear', variant='secondary')
+
+    saved_input = gr.State()
+
+    with gr.Accordion(label='Advanced options', open=False):
+        system_prompt = gr.Textbox(label='System prompt',
+                                   value=DEFAULT_SYSTEM_PROMPT,
+                                   lines=6)
+        max_new_tokens = gr.Slider(
+            label='Max new tokens',
+            minimum=1,
+            maximum=MAX_MAX_NEW_TOKENS,
+            step=1,
+            value=DEFAULT_MAX_NEW_TOKENS,
+        )
+        temperature = gr.Slider(
+            label='Temperature',
+            minimum=0.1,
+            maximum=4.0,
+            step=0.1,
+            value=1.0,
+        )
+        top_p = gr.Slider(
+            label='Top-p (nucleus sampling)',
+            minimum=0.05,
+            maximum=1.0,
+            step=0.05,
+            value=0.95,
+        )
+        top_k = gr.Slider(
+            label='Top-k',
+            minimum=1,
+            maximum=1000,
+            step=1,
+            value=50,
+        )
+
+    gr.Examples(
+        examples=[
+            'Hello there! How are you doing?',
+            'Can you explain briefly to me what is the Python programming language?',
+            'Explain the plot of Cinderella in a sentence.',
+            'How many hours does it take a man to eat a Helicopter?',
+            "Write a 100-word article on 'Benefits of Open-Source in AI research'",
+        ],
+        inputs=textbox,
+        outputs=[textbox, chatbot],
+        fn=process_example,
+        cache_examples=True,
+    )
+
+    gr.Markdown(LICENSE)
+
+    textbox.submit(
+        fn=clear_and_save_textbox,
+        inputs=textbox,
+        outputs=[textbox, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=check_input_token_length,
+        inputs=[saved_input, chatbot, system_prompt],
+        api_name=False,
+        queue=False,
+    ).success(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    button_event_preprocess = submit_button.click(
+        fn=clear_and_save_textbox,
+        inputs=textbox,
+        outputs=[textbox, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=check_input_token_length,
+        inputs=[saved_input, chatbot, system_prompt],
+        api_name=False,
+        queue=False,
+    ).success(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    retry_button.click(
+        fn=delete_prev_fn,
+        inputs=chatbot,
+        outputs=[chatbot, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=display_input,
+        inputs=[saved_input, chatbot],
+        outputs=chatbot,
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=generate,
+        inputs=[
+            saved_input,
+            chatbot,
+            system_prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            top_k,
+        ],
+        outputs=chatbot,
+        api_name=False,
+    )
+
+    undo_button.click(
+        fn=delete_prev_fn,
+        inputs=chatbot,
+        outputs=[chatbot, saved_input],
+        api_name=False,
+        queue=False,
+    ).then(
+        fn=lambda x: x,
+        inputs=[saved_input],
+        outputs=textbox,
+        api_name=False,
+        queue=False,
+    )
+
+    clear_button.click(
+        fn=lambda: ([], ''),
+        outputs=[chatbot, saved_input],
+        queue=False,
+        api_name=False,
+    )
 
+demo.queue(max_size=20).launch()
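
For reference, the removed predict() spoke a text-generation-inference-style server-sent-events protocol: each streamed line arrives as data:{...}, and the JSON carries the next token under token.text. A minimal sketch of that parsing step, run against hardcoded sample lines instead of a live endpoint (the exact payload shape depends on the serving backend):

import json

# Representative lines as predict() received them from response.iter_lines();
# a real stream also interleaves blank keep-alive lines, which are skipped.
sample_lines = [
    b'data:{"token": {"text": "Hello"}}',
    b'data:{"token": {"text": " world"}}',
]

partial_message = ""
for raw in sample_lines:
    decoded = raw.decode('utf-8')
    if not decoded.startswith('data:'):
        continue  # skip keep-alives and anything that is not an SSE data line
    payload = json.loads(decoded[5:])  # drop the 'data:' prefix
    if 'token' in payload:
        partial_message += payload['token']['text']

print(partial_message)  # -> Hello world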
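
The new version imports get_input_token_length and run from a local model module that is not part of this commit. From the way generate() consumes it, run must be a generator that yields the accumulated response text after each decoding step. A hypothetical stub honoring that contract, useful for exercising the UI without a GPU (the names come from the import; everything else here is an assumption, not the actual module):

from typing import Iterator


def run(message: str, chat_history: list[tuple[str, str]], system_prompt: str,
        max_new_tokens: int, temperature: float, top_p: float, top_k: int) -> Iterator[str]:
    # Assumed contract: yield progressively longer partial responses,
    # matching how generate() calls next(generator) and then iterates
    # over the remaining partial outputs.
    response = ''
    for word in ('This ', 'is ', 'a ', 'stub ', 'response.'):
        response += word
        yield response


def get_input_token_length(message: str, chat_history: list[tuple[str, str]],
                           system_prompt: str) -> int:
    # Assumed contract: return the prompt length in tokens. A real
    # implementation would run the model's tokenizer over the full prompt;
    # this stand-in just counts whitespace-separated words.
    full_prompt = system_prompt + message + ''.join(u + a for u, a in chat_history)
    return len(full_prompt.split())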
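
Both submit paths chain check_input_token_length with .success(fn=generate, ...): raising gr.Error in the check surfaces a toast to the user and marks the event as failed, so the chained generation step never runs. A minimal sketch of that gating pattern in isolation (simplified wiring, not the app's exact layout):

import gradio as gr


def validate(message: str) -> None:
    # gr.Error aborts the event chain, so the .success(...) step below is skipped.
    if len(message) > 100:
        raise gr.Error('Input too long. Shorten your message and try again.')


def respond(message: str) -> str:
    return f'You said: {message}'


with gr.Blocks() as demo:
    textbox = gr.Textbox(label='Message')
    output = gr.Textbox(label='Reply')
    textbox.submit(fn=validate, inputs=textbox).success(
        fn=respond, inputs=textbox, outputs=output)

demo.queue().launch()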