harsh-manvar committed on
Commit
7bb24c5
β€’
1 Parent(s): 1f2e734

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +270 -0
app.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import Iterator
import gradio as gr

# import torch
from transformers.utils import logging
# Project-local helpers: prompt token counting and the streaming generation loop.
from model import get_input_token_length, run

# Route this app's logs through transformers' logging facility at INFO level.
logging.set_verbosity_info()
logger = logging.get_logger("transformers")

# Empty by default; shown (and editable) in the "Advanced options" accordion.
DEFAULT_SYSTEM_PROMPT = """"""
# Hard upper bound on generated tokens; enforced both by the UI slider and generate().
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts whose token count exceeds this are rejected before generation starts.
MAX_INPUT_TOKEN_LENGTH = 4000

# Markdown shown at the top of the page (currently empty).
DESCRIPTION = """"""

# Markdown shown at the bottom of the page (currently empty).
LICENSE = """"""

logger.info("Starting")
21
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Empty the input textbox while stashing its contents.

    Returns a pair of (new textbox value, saved message) so the UI can
    clear the box and remember what was typed in a gr.State.
    """
    saved = message
    return '', saved
23
+
24
+
25
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Append the user's message, paired with an empty bot reply, to the chat history.

    Mutates *history* in place and returns it for the Gradio output binding.
    """
    entry = (message, '')
    history.append(entry)
    logger.info("display_input=%s", message)
    return history
30
+
31
+
32
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Drop the most recent exchange and hand back its user message.

    An empty history is returned untouched together with ''.
    A falsy popped message (e.g. None) is also normalized to ''.
    """
    if history:
        message = history.pop()[0]
    else:
        message = ''
    return history, message if message else ''
39
+
40
+
41
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Stream the model's reply, yielding the growing chat history.

    Args:
        message: The latest user message.
        history_with_input: Chat history INCLUDING the pending (message, '')
            entry appended by display_input; it is stripped before generation.
        system_prompt: System prompt forwarded to the model.
        max_new_tokens / temperature / top_p / top_k: Sampling parameters.

    Yields:
        The history extended with (message, partial_response), one item per
        streamed chunk from run().

    Raises:
        ValueError: If max_new_tokens exceeds MAX_MAX_NEW_TOKENS.
    """
    logger.info("message=%s", message)
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        # Fix: the original raised a bare ValueError with no diagnostic.
        raise ValueError(
            f'max_new_tokens must be <= {MAX_MAX_NEW_TOKENS}, got {max_new_tokens}')

    # Drop the placeholder (message, '') entry added by display_input.
    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
    try:
        # Yield the first chunk eagerly so the UI updates as soon as possible.
        first_response = next(generator)
        yield history + [(message, first_response)]
    except StopIteration:
        # The model produced no output at all; still show an empty reply.
        yield history + [(message, '')]
    for response in generator:
        yield history + [(message, response)]
63
+
64
+
65
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    """Run generate() to completion for a cached UI example.

    Returns ('', final_history): the empty string clears the textbox and the
    final history is the fully generated chat.
    """
    # Fix: the original read the loop variable after the loop, which raises
    # NameError when the generator yields nothing. Initialize a default and
    # keep only the last yielded history.
    history: list[tuple[str, str]] = []
    for history in generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50):
        pass
    return '', history
70
+
71
+
72
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Abort the event chain if the accumulated prompt is too long.

    Raises:
        gr.Error: When the tokenized prompt exceeds MAX_INPUT_TOKEN_LENGTH;
            Gradio surfaces this to the user and .success() handlers don't run.
    """
    logger.info("check_input_token_length=%s", message)
    input_token_length = get_input_token_length(message, chat_history, system_prompt)
    # Fix: the original passed extra args without %s placeholders
    # (logger.info("input_token_length", input_token_length)), which makes the
    # logging module emit a "not all arguments converted" formatting error.
    logger.info("input_token_length=%s", input_token_length)
    logger.info("MAX_INPUT_TOKEN_LENGTH=%s", MAX_INPUT_TOKEN_LENGTH)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        logger.info("Inside IF condition")
        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
    logger.info("End of check_input_token_length function")
81
+
82
+
83
# Build the Gradio UI: a chat window with submit/retry/undo/clear controls and
# an accordion of sampling hyper-parameters, then wire the event chains.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')

    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type a message...',
                scale=10,
            )
            submit_button = gr.Button('Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    with gr.Row():
        retry_button = gr.Button('πŸ”„ Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('πŸ—‘οΈ Clear', variant='secondary')

    # Holds the last submitted message so the retry/undo chains can replay it.
    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=6)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        )


    # gr.Examples(
    #     examples=[
    #         'Hello there! How are you doing?',
    #         'Can you explain briefly to me what is the Python programming language?',
    #         'Explain the plot of Cinderella in a sentence.',
    #         'How many hours does it take a man to eat a Helicopter?',
    #         "Write a 100-word article on 'Benefits of Open-Source in AI research'",
    #     ],
    #     inputs=textbox,
    #     outputs=[textbox, chatbot],
    #     fn=process_example,
    #     cache_examples=True,
    # )

    gr.Markdown(LICENSE)

    # Enter-key path: clear/save textbox -> echo message into the chat ->
    # validate prompt length -> generate (only runs if validation succeeded,
    # because generate is attached with .success()).
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Submit-button path: identical chain to textbox.submit above.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Retry: pop the last exchange, re-echo its message, regenerate.
    # NOTE(review): unlike the submit chains, this attaches generate with
    # .then() and skips check_input_token_length — confirm this is intended.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Undo: pop the last exchange and restore its message into the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: reset both the chat history and the saved message.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

# Queue requests (cap 20 pending) and start the server.
demo.queue(max_size=20).launch()