arborvitae committed
Commit
22d90fb
1 Parent(s): b5c7d4d

Create app.py

Files changed (1)
  app.py +267 -0
app.py ADDED
@@ -0,0 +1,267 @@
from typing import Iterator

import gradio as gr
import torch

from model import get_input_token_length, run

DEFAULT_SYSTEM_PROMPT = """\
You are a software engineer reporting to a senior software engineer. Reply with highest quality, PhD level, detailed, logical, precise, clean answers.
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = 4000

DESCRIPTION = """
"""

LICENSE = """
<p/>
---
"""

if not torch.cuda.is_available():
    DESCRIPTION += '\n<p>Running on CPU.</p>'


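# Small helpers used by the event chains below to shuttle text between the
# textbox, the saved_input state, and the chatbot history.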
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    return '', message


def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    history.append((message, ''))
    return history


def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    try:
        message, _ = history.pop()
    except IndexError:
        message = ''
    return history, message or ''


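# Streams the model output: each yield returns the complete chat history so
# the gr.Chatbot re-renders with the partially generated response appended.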
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(f'max_new_tokens must not exceed {MAX_MAX_NEW_TOKENS}')

    # display_input has already appended (message, '') to the history;
    # drop it here so the streamed pairs are not duplicated.
    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
    try:
        first_response = next(generator)
        yield history + [(message, first_response)]
    except StopIteration:
        yield history + [(message, '')]
    for response in generator:
        yield history + [(message, response)]


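# Used by gr.Examples below: drains the generate() stream and returns only the
# final chat history, which is what gets cached for each example.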
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    generator = generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50)
    # generate() always yields at least once, so x is guaranteed to be bound.
    for x in generator:
        pass
    return '', x


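# Raising gr.Error shows the message as a toast in the UI and marks the event
# as failed, so the .success(...) generation step in the chains below is skipped.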
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    input_token_length = get_input_token_length(message, chat_history, system_prompt)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')


with gr.Blocks(css='style.css') as demo:
    gr.Markdown('# GalaxiCode.ai')
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type a message...',
                scale=10,
            )
            submit_button = gr.Button('Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the last submitted message so the Retry/Undo chains can reuse it.
    saved_input = gr.State()

    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=6)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        )

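    # With cache_examples=True, Gradio runs process_example on every example at
    # startup and caches the resulting chats.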
    gr.Examples(
        examples=[
            "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
            "// Returns every other value in the array as a new array.\nfunction everyOther(arr) {",
            "Poor English: She no went to the market. Corrected English:",
            "def alternating(list1, list2):\n    results = []\n    for i in range(min(len(list1), len(list2))):\n        results.append(list1[i])\n        results.append(list2[i])\n    if len(list1) > len(list2):\n        <FILL_HERE>\n    else:\n        results.extend(list2[i+1:])\n    return results",
            "def remove_non_ascii(s: str) -> str:\n    \"\"\" <FILL_ME>\nprint(remove_non_ascii('afkdj$$('))",
        ],
        inputs=textbox,
        outputs=[textbox, chatbot],
        fn=process_example,
        cache_examples=True,
    )

    gr.Markdown(LICENSE)

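    # Submitting the textbox and clicking Submit run the same four-step chain:
    # save & clear the input, echo it into the chat, check the token budget,
    # then stream the generation.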
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

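    # Retry drops the last exchange, re-echoes the saved message, and regenerates.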
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

demo.queue(max_size=20).launch()
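
app.py imports get_input_token_length and run from a companion model module that is not part of this commit. For reference, here is a minimal sketch of the interface the app assumes, inferred from the call sites above; the model ID, prompt format, and streaming mechanics are illustrative assumptions, not the contents of the actual model.py:

# model.py (hypothetical sketch): run() must yield the cumulative response text,
# since generate() in app.py replaces, rather than appends to, the streamed string.
from threading import Thread
from typing import Iterator

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = 'codellama/CodeLlama-13b-Instruct-hf'  # assumption, not from this commit

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID,
                                             torch_dtype=torch.float16,
                                             device_map='auto')


def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    # Assumption: a flat prompt; a real implementation would likely use the
    # model-specific chat template instead.
    texts = [f'{system_prompt}\n']
    for user, assistant in chat_history:
        texts.append(f'User: {user}\nAssistant: {assistant}\n')
    texts.append(f'User: {message}\nAssistant:')
    return ''.join(texts)


def get_input_token_length(message: str, chat_history: list[tuple[str, str]],
                           system_prompt: str) -> int:
    prompt = get_prompt(message, chat_history, system_prompt)
    return len(tokenizer(prompt)['input_ids'])


def run(message: str, chat_history: list[tuple[str, str]], system_prompt: str,
        max_new_tokens: int, temperature: float, top_p: float,
        top_k: int) -> Iterator[str]:
    prompt = get_prompt(message, chat_history, system_prompt)
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0,
                                    skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(inputs, streamer=streamer,
                           max_new_tokens=max_new_tokens, do_sample=True,
                           top_p=top_p, top_k=top_k, temperature=temperature)
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield ''.join(outputs)  # cumulative text, as app.py expects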