Ashishkr commited on
Commit
d397ed2
β€’
1 Parent(s): dcc5de7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +256 -1
app.py CHANGED
@@ -1,3 +1,258 @@
 
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- gr.Interface.load("models/Ashishkr/llama2_medical_consultation").launch()
 
1
+ from typing import Iterator
2
+
3
  import gradio as gr
4
+ import torch
5
+
6
+ from model import get_input_token_length, run
7
+
8
+ DEFAULT_SYSTEM_PROMPT = """\
9
+ instruction: "If you are a doctor, please answer the medical questions based on the patient's description." \n
10
+
11
+
12
+ """
13
+ MAX_MAX_NEW_TOKENS = 2048
14
+ DEFAULT_MAX_NEW_TOKENS = 1024
15
+ MAX_INPUT_TOKEN_LENGTH = 4000
16
+
17
+
18
+ if not torch.cuda.is_available():
19
+ DESCRIPTION += '\n<p>Running on CPU πŸ₯Ά This demo does not work on CPU.</p>'
20
+
21
+
22
+ def clear_and_save_textbox(message: str) -> tuple[str, str]:
23
+ return '', message
24
+
25
+
26
+ def display_input(message: str,
27
+ history: list[tuple[str, str]]) -> list[tuple[str, str]]:
28
+ history.append((message, ''))
29
+ return history
30
+
31
+
32
+ def delete_prev_fn(
33
+ history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
34
+ try:
35
+ message, _ = history.pop()
36
+ except IndexError:
37
+ message = ''
38
+ return history, message or ''
39
+
40
+
41
+ def generate(
42
+ message: str,
43
+ history_with_input: list[tuple[str, str]],
44
+ system_prompt: str,
45
+ max_new_tokens: int,
46
+ temperature: float,
47
+ top_p: float,
48
+ top_k: int,
49
+ ) -> Iterator[list[tuple[str, str]]]:
50
+ if max_new_tokens > MAX_MAX_NEW_TOKENS:
51
+ raise ValueError
52
+
53
+ history = history_with_input[:-1]
54
+ generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
55
+ try:
56
+ first_response = next(generator)
57
+ yield history + [(message, first_response)]
58
+ except StopIteration:
59
+ yield history + [(message, '')]
60
+ for response in generator:
61
+ yield history + [(message, response)]
62
+
63
+
64
+ def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
65
+ generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
66
+ for x in generator:
67
+ pass
68
+ return '', x
69
+
70
+
71
+ def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
72
+ input_token_length = get_input_token_length(message, chat_history, system_prompt)
73
+ if input_token_length > MAX_INPUT_TOKEN_LENGTH:
74
+ raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
75
+
76
+
77
+ with gr.Blocks(css='style.css') as demo:
78
+ gr.Markdown(DESCRIPTION)
79
+ gr.DuplicateButton(value='Duplicate Space for private use',
80
+ elem_id='duplicate-button')
81
+
82
+ with gr.Group():
83
+ chatbot = gr.Chatbot(label='Chatbot')
84
+ with gr.Row():
85
+ textbox = gr.Textbox(
86
+ container=False,
87
+ show_label=False,
88
+ placeholder='Type a message...',
89
+ scale=10,
90
+ )
91
+ submit_button = gr.Button('Submit',
92
+ variant='primary',
93
+ scale=1,
94
+ min_width=0)
95
+ with gr.Row():
96
+ retry_button = gr.Button('πŸ”„ Retry', variant='secondary')
97
+ undo_button = gr.Button('↩️ Undo', variant='secondary')
98
+ clear_button = gr.Button('πŸ—‘οΈ Clear', variant='secondary')
99
+
100
+ saved_input = gr.State()
101
+
102
+ with gr.Accordion(label='Advanced options', open=False):
103
+ system_prompt = gr.Textbox(label='System prompt',
104
+ value=DEFAULT_SYSTEM_PROMPT,
105
+ lines=6)
106
+ max_new_tokens = gr.Slider(
107
+ label='Max new tokens',
108
+ minimum=1,
109
+ maximum=MAX_MAX_NEW_TOKENS,
110
+ step=1,
111
+ value=DEFAULT_MAX_NEW_TOKENS,
112
+ )
113
+ temperature = gr.Slider(
114
+ label='Temperature',
115
+ minimum=0.1,
116
+ maximum=4.0,
117
+ step=0.1,
118
+ value=1.0,
119
+ )
120
+ top_p = gr.Slider(
121
+ label='Top-p (nucleus sampling)',
122
+ minimum=0.05,
123
+ maximum=1.0,
124
+ step=0.05,
125
+ value=0.95,
126
+ )
127
+ top_k = gr.Slider(
128
+ label='Top-k',
129
+ minimum=1,
130
+ maximum=1000,
131
+ step=1,
132
+ value=50,
133
+ )
134
+
135
+ gr.Examples(
136
+ examples=['I have high fever and sharp pain in jaw'
137
+ ],
138
+ inputs=textbox,
139
+ outputs=[textbox, chatbot],
140
+ fn=process_example,
141
+ cache_examples=True,
142
+ )
143
+
144
+ gr.Markdown(LICENSE)
145
+
146
+ textbox.submit(
147
+ fn=clear_and_save_textbox,
148
+ inputs=textbox,
149
+ outputs=[textbox, saved_input],
150
+ api_name=False,
151
+ queue=False,
152
+ ).then(
153
+ fn=display_input,
154
+ inputs=[saved_input, chatbot],
155
+ outputs=chatbot,
156
+ api_name=False,
157
+ queue=False,
158
+ ).then(
159
+ fn=check_input_token_length,
160
+ inputs=[saved_input, chatbot, system_prompt],
161
+ api_name=False,
162
+ queue=False,
163
+ ).success(
164
+ fn=generate,
165
+ inputs=[
166
+ saved_input,
167
+ chatbot,
168
+ system_prompt,
169
+ max_new_tokens,
170
+ temperature,
171
+ top_p,
172
+ top_k,
173
+ ],
174
+ outputs=chatbot,
175
+ api_name=False,
176
+ )
177
+
178
+ button_event_preprocess = submit_button.click(
179
+ fn=clear_and_save_textbox,
180
+ inputs=textbox,
181
+ outputs=[textbox, saved_input],
182
+ api_name=False,
183
+ queue=False,
184
+ ).then(
185
+ fn=display_input,
186
+ inputs=[saved_input, chatbot],
187
+ outputs=chatbot,
188
+ api_name=False,
189
+ queue=False,
190
+ ).then(
191
+ fn=check_input_token_length,
192
+ inputs=[saved_input, chatbot, system_prompt],
193
+ api_name=False,
194
+ queue=False,
195
+ ).success(
196
+ fn=generate,
197
+ inputs=[
198
+ saved_input,
199
+ chatbot,
200
+ system_prompt,
201
+ max_new_tokens,
202
+ temperature,
203
+ top_p,
204
+ top_k,
205
+ ],
206
+ outputs=chatbot,
207
+ api_name=False,
208
+ )
209
+
210
+ retry_button.click(
211
+ fn=delete_prev_fn,
212
+ inputs=chatbot,
213
+ outputs=[chatbot, saved_input],
214
+ api_name=False,
215
+ queue=False,
216
+ ).then(
217
+ fn=display_input,
218
+ inputs=[saved_input, chatbot],
219
+ outputs=chatbot,
220
+ api_name=False,
221
+ queue=False,
222
+ ).then(
223
+ fn=generate,
224
+ inputs=[
225
+ saved_input,
226
+ chatbot,
227
+ system_prompt,
228
+ max_new_tokens,
229
+ temperature,
230
+ top_p,
231
+ top_k,
232
+ ],
233
+ outputs=chatbot,
234
+ api_name=False,
235
+ )
236
+
237
+ undo_button.click(
238
+ fn=delete_prev_fn,
239
+ inputs=chatbot,
240
+ outputs=[chatbot, saved_input],
241
+ api_name=False,
242
+ queue=False,
243
+ ).then(
244
+ fn=lambda x: x,
245
+ inputs=[saved_input],
246
+ outputs=textbox,
247
+ api_name=False,
248
+ queue=False,
249
+ )
250
+
251
+ clear_button.click(
252
+ fn=lambda: ([], ''),
253
+ outputs=[chatbot, saved_input],
254
+ queue=False,
255
+ api_name=False,
256
+ )
257
 
258
+ demo.queue(max_size=20).launch()