Colmand committed on
Commit
f60027b
1 Parent(s): 49e6bea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -1
app.py CHANGED
@@ -1,3 +1,257 @@
 
1
  import gradio as gr
 
 
 
2
 
3
- gr.load("models/meta-llama/Meta-Llama-3-8B").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Gradio chat UI for Meta-Llama-3-8B: imports, logging, and tunable constants."""
from typing import Iterator

import gradio as gr
from transformers.utils import logging

from model import get_input_token_length, run

# NOTE(review): the original had a stray
# `gr.load("models/meta-llama/Meta-Llama-3-8B").launch()` here, in the middle
# of the imports. `.launch()` blocks the process, so every definition below it
# (including the Blocks demo at the bottom of the file) would never execute.
# It duplicates the single-line app this commit replaces; removed as a merge
# leftover.

logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Empty system prompt by default; users can override it under Advanced options.
DEFAULT_SYSTEM_PROMPT = """"""
MAX_MAX_NEW_TOKENS = 2048       # hard ceiling enforced inside generate()
DEFAULT_MAX_NEW_TOKENS = 1024   # initial value of the "Max new tokens" slider
MAX_INPUT_TOKEN_LENGTH = 4000   # accumulated-prompt budget checked per submit

# Markdown shown above the chat widget / below the controls (empty for now).
DESCRIPTION = """"""

LICENSE = """"""

logger.info("Starting")
22
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Reset the textbox while stashing its previous contents.

    Returns ('', message): the empty string clears the visible textbox and the
    original message lands in the saved_input State for the next chain step.
    """
    cleared_textbox = ''
    return cleared_textbox, message
26
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Echo the pending user message into the chat history.

    Mutates *history* in place, adding (message, '') — the empty string is a
    placeholder for the not-yet-generated bot reply.
    """
    pending_turn = (message, '')
    history += [pending_turn]
    logger.info("display_input=%s", message)
    return history
33
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the most recent chat turn and hand back its user message.

    Returns the shortened history plus the removed user message so callers can
    re-submit it (Retry) or restore it to the textbox (Undo). When the history
    is already empty, the message comes back as ''.
    """
    if history:
        last_user_message = history.pop()[0]
    else:
        last_user_message = ''
    return history, last_user_message or ''
42
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Stream the model's reply, yielding the full updated history each step.

    Args:
        message: the user's latest message.
        history_with_input: chat history whose last entry is the just-displayed
            (message, '') placeholder appended by display_input.
        system_prompt / max_new_tokens / temperature / top_p / top_k: passed
            straight through to model.run.

    Yields:
        history + [(message, partial_response)] for every chunk produced.

    Raises:
        ValueError: if max_new_tokens exceeds MAX_MAX_NEW_TOKENS.
    """
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        # Was a bare `raise ValueError`; include context so failures are debuggable.
        raise ValueError(
            f'max_new_tokens must be <= {MAX_MAX_NEW_TOKENS}, got {max_new_tokens}')

    # Drop the (message, '') placeholder so the model sees only completed turns.
    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
    try:
        first_response = next(generator)
        yield history + [(message, first_response)]
    except StopIteration:
        # Model produced nothing; still surface the user turn with an empty reply.
        yield history + [(message, '')]
    for response in generator:
        yield history + [(message, response)]
66
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    """Run an example prompt to completion and return ('', final_history).

    Drains the generate() stream, keeping only the last yielded history so the
    example shows the finished reply; the leading '' clears the textbox.
    """
    # Initialize defensively: the original left the loop variable unbound
    # (UnboundLocalError) if the stream yielded nothing.
    final_history: list[tuple[str, str]] = []
    for final_history in generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50):
        pass
    return '', final_history
73
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Abort the submit chain when the accumulated prompt would be too long.

    Raising gr.Error surfaces a toast in the UI and prevents the chained
    .success() handler (generate) from running.
    """
    input_token_length = get_input_token_length(message, chat_history, system_prompt)
    if input_token_length <= MAX_INPUT_TOKEN_LENGTH:
        return
    logger.info("Inside IF condition")
    raise gr.Error(
        f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
85
# Build the chat UI. The event wiring below repeats one three-step pattern for
# both Enter-in-textbox and the Submit button: save+clear the textbox, echo the
# message into the chatbot, length-check, then (only on success) stream the
# generation. (Indentation reconstructed — the diff view lost it; verify
# against the deployed app.py.)
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')

    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type a message...',
                scale=10,
            )
            submit_button = gr.Button('Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    with gr.Row():
        retry_button = gr.Button('Retry', variant='secondary')
        undo_button = gr.Button('Undo', variant='secondary')
        clear_button = gr.Button('Clear', variant='secondary')

    # Holds the last submitted message so later steps of an event chain can
    # still read it after clear_and_save_textbox has emptied the textbox.
    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=6)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        )

    gr.Markdown(LICENSE)

    # Enter in the textbox: save the message, display it, validate the prompt
    # length, then stream the reply. check_input_token_length raising gr.Error
    # stops the chain before the .success() step fires.
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Submit button: identical chain to textbox.submit above.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Retry: pop the last turn, re-display its user message, regenerate.
    # NOTE(review): no length re-check here, unlike the submit chains.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Undo: pop the last turn and restore its user message to the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: wipe both the visible history and the saved message.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

# Cap the request queue and bind to all interfaces (required on Spaces).
demo.queue(max_size=20).launch(share=False, server_name="0.0.0.0")