IELTS8 commited on
Commit
516297e
β€’
1 Parent(s): 52efd6b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +417 -0
app.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import logging
4
+ import sys
5
+ import torch
6
+ import gradio as gr
7
+ from huggingface_hub import Repository
8
+ from text_generation import Client
9
+ from app_modules.utils import convert_to_markdown
10
+ # from dialogues import DialogueTemplate
11
+ from share_btn import (community_icon_html, loading_icon_html, share_btn_css,
12
+ share_js)
13
+
14
# --- Inference endpoint configuration -----------------------------------------
# SECURITY FIX: the original file hard-coded a Hugging Face API token here.
# A token committed to source control must be treated as leaked and revoked.
# Read it from the environment instead (API_TOKEN, falling back to HF_TOKEN).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
API_TOKEN = os.environ.get("API_TOKEN", HF_TOKEN)
# Endpoint for the hosted model; overridable via the API_URL env var
# (the original assigned the env lookup and then immediately clobbered it
# with the literal, making the env var dead — fixed to use it as default).
API_URL = os.environ.get(
    "API_URL",
    "https://api-inference.huggingface.co/models/timdettmers/guanaco-33b-merged",
)

# Streaming text-generation client for the endpoint above.
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {API_TOKEN}"},
)

# Placeholder for an optional huggingface_hub.Repository; never populated here.
repo = None

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Seed prompts shown in the UI's Examples widget.
examples = [
    "Describe the advantages and disadvantages of Incremental Sheet Forming.",
    "Describe the applications of Incremental Sheet Forming.",
    "Describe the process parameters included in Incremental Sheet Forming in dot points."
]
38
+
39
+
40
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Assemble the full prompt string sent to the model.

    Concatenates `preprompt`, every past (user, assistant) turn from `chatbot`
    (each prefixed with its role marker if not already present), then the new
    `inputs` followed by `sep` and the assistant marker awaiting completion.
    """

    def _format_turn(user_text, model_text):
        # Ensure both halves of the turn carry their role markers.
        if not user_text.startswith(user_name):
            user_text = user_name + user_text
        if not model_text.startswith(sep + assistant_name):
            model_text = sep + assistant_name + model_text
        return user_text + model_text.rstrip() + sep

    transcript = "".join(_format_turn(u, m) for u, m in chatbot)

    current = inputs if inputs.startswith(user_name) else user_name + inputs

    return preprompt + transcript + current + sep + assistant_name.rstrip()
58
+
59
+
60
def has_no_history(chatbot, history):
    """Return True when both the chatbot transcript and the raw history are empty."""
    return not (chatbot or history)
62
+
63
+
64
# System preamble prepended to every prompt sent to the model.
header = "A chat between a curious human and an artificial intelligence assistant about Incremental Sheet Forming (ISF). " \
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
# Guanaco-style turn template; `response` is left empty for the turn being generated.
prompt_template = "### Human: {query}\n### Assistant:{response}"
67
+
68
+
69
def generate(
    user_message,
    chatbot,
    history,
    temperature,
    top_p,
    top_k,
    max_new_tokens,
    repetition_penalty,
):
    """Stream a model response for `user_message` (Gradio callback generator).

    Yields (chat_pairs, history, last_user_message, "") after each streamed
    token so the UI updates live. `history` is a flat list alternating
    [user, assistant, user, assistant, ...]; `chatbot` holds the visible
    (user, assistant) transcript pairs.
    """

    def _as_chat_pairs(flat_history):
        # Fold the flat history into (user, assistant) markdown tuples for gr.Chatbot.
        return [
            (convert_to_markdown(flat_history[i].strip()),
             convert_to_markdown(flat_history[i + 1].strip()))
            for i in range(0, len(flat_history) - 1, 2)
        ]

    # FIX: the original only printed on empty input and then called the model
    # anyway, appending "" to history. Re-emit the current state and stop.
    if not user_message:
        logger.debug("Empty input")
        chat = _as_chat_pairs(history)
        yield chat, history, user_message, ""
        return

    history.append(user_message)

    # Rebuild role-tagged messages from the visible transcript.
    past_messages = []
    for user_data, model_data in chatbot:
        past_messages.append({"role": "user", "content": user_data})
        past_messages.append({"role": "assistant", "content": model_data.rstrip()})

    # Build the prompt: header, one template block per past turn, then the new
    # query with an empty response slot for the model to complete.
    if len(past_messages) < 1:
        prompt = header + prompt_template.format(query=user_message, response="")
    else:
        prompt = header
        for i in range(0, len(past_messages), 2):
            intermediate_prompt = prompt_template.format(
                query=past_messages[i]["content"],
                response=past_messages[i + 1]["content"],
            )
            logger.debug("intermediate: %s", intermediate_prompt)
            prompt = prompt + '\n' + intermediate_prompt
        prompt = prompt + prompt_template.format(query=user_message, response="")

    # Clamp temperature away from zero: sampling with t == 0 is rejected upstream.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=999,
        seed=42,
    )

    stream = client.generate_stream(
        prompt,
        **generate_kwargs,
    )

    output = ""
    # FIX: initialize `chat` so the final return cannot NameError when the
    # stream produces no tokens at all.
    chat = _as_chat_pairs(history)
    for idx, response in enumerate(stream):
        if response.token.text == '':
            break
        if response.token.special:
            continue
        output += response.token.text
        # First token opens the assistant slot (leading space matches the
        # "### Assistant:" template); later tokens overwrite it in place.
        if idx == 0:
            history.append(" " + output)
        else:
            history[-1] = output

        chat = _as_chat_pairs(history)
        yield chat, history, user_message, ""

    return chat, history, user_message, ""
144
+
145
+
146
def clear_chat():
    """Reset both the visible chatbot transcript and the stored history."""
    fresh_chat, fresh_history = [], []
    return fresh_chat, fresh_history
148
+
149
+
150
def save(
    history,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    max_new_tokens=512,
    repetition_penalty=1.2,
    max_memory=1024,
):
    """Append the conversation history plus its generation settings to history.jsonl.

    `history` may be None (treated as empty). One JSON object per line.
    """
    history = [] if history is None else history
    data_point = {
        'history': history,
        'generation_parameter': {
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "max_new_tokens": max_new_tokens,
            "repetition_penalty": repetition_penalty,
            "max_memory": max_memory,
        },
    }
    print(data_point)
    file_name = "history.jsonl"
    # FIX: explicit utf-8 — ensure_ascii=False emits raw non-ASCII characters,
    # which can fail under a non-UTF-8 default locale encoding. The original
    # also looped over a one-element list; write the record directly.
    with open(file_name, 'a', encoding='utf-8') as f:
        f.write(json.dumps(data_point, ensure_ascii=False) + '\n')
173
+
174
+
175
def process_example(args):
    """Run `generate` to completion for one example prompt; return final [chat, history].

    FIX: the original called generate(args) — missing 7 required arguments — and
    unpacked its 4-tuple yields into [x, y], so it raised on every invocation.
    Drive the generator with empty chat state and the UI's default sampling
    parameters, keeping only the last yielded chat/history pair.
    """
    chat, history = [], []
    # generate yields (chat, history, last_user_message, cleared_textbox).
    for chat, history, *_ in generate(args, [], [], 0.7, 0.9, 20, 512, 1.2):
        pass
    return [chat, history]
179
+
180
+
181
# Page heading rendered at the top of the app via gr.HTML.
title = """<h1 align="center">ISF Alpaca 💬</h1>"""
# Extra CSS rules; NOTE(review): not passed to gr.Blocks below, so currently
# unused — confirm whether it should be wired into the `css=` argument.
custom_css = """
#banner-image {
display: block;
margin-left: auto;
margin-right: auto;
}
#chat-message {
font-size: 14px;
min-height: 300px;
}
"""
193
+
194
# ---- UI layout and event wiring ----------------------------------------------
with gr.Blocks(analytics_enabled=False,
               theme=gr.themes.Soft(),
               css=".disclaimer {font-variant-caps: all-small-caps;}") as demo:
    gr.HTML(title)
    # status_display = gr.Markdown("Success", elem_id="status_display")
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
🏭 The fine-tuned model primarily emphasizes **Knowledge Augmentation** in the Manufacturing domain,
with **Incremental Sheet Forming (ISF)** serving as a use case.
"""
            )
    history = gr.components.State()

    with gr.Row(scale=1).style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row(scale=1):
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height=476)
            with gr.Row(scale=1):
                with gr.Column(scale=12):
                    user_message = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submit_btn = gr.Button("Send")
                with gr.Column(min_width=70, scale=1):
                    stop_btn = gr.Button("Stop")
            with gr.Row():
                gr.Examples(
                    examples=examples,
                    inputs=[user_message],
                    cache_examples=False,
                    outputs=[chatbot, history],
                )
            with gr.Row(scale=1):
                clear_history = gr.Button(
                    "🧹 New Conversation",
                )
                reset_btn = gr.Button("🔄 Reset Parameter")
                save_btn = gr.Button("📥 Save Chat")
        with gr.Column():
            input_component_column = gr.Column(min_width=50, scale=1)
            with input_component_column:
                with gr.Tab(label="Parameter Setting"):
                    gr.Markdown("# Parameters")
                    temperature = gr.components.Slider(minimum=0, maximum=1, value=0.7, label="Temperature")
                    top_p = gr.components.Slider(minimum=0, maximum=1, value=0.9, label="Top p")
                    top_k = gr.components.Slider(minimum=0, maximum=100, step=1, value=20, label="Top k")
                    max_new_tokens = gr.components.Slider(minimum=1, maximum=2048, step=1, value=512,
                                                          label="Max New Tokens")
                    repetition_penalty = gr.components.Slider(minimum=0.1, maximum=10.0, step=0.1, value=1.2,
                                                              label="Repetition Penalty")
                    max_memory = gr.components.Slider(minimum=0, maximum=2048, step=1, value=2048, label="Max Memory")

    # NOTE(review): this rebinds `history`, orphaning the gr.components.State()
    # created above; all wiring below uses this instance.
    history = gr.State([])
    last_user_message = gr.State("")

    user_message.submit(
        generate,
        inputs=[
            user_message,
            chatbot,
            history,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
            repetition_penalty,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    submit_event = submit_btn.click(
        generate,
        inputs=[
            user_message,
            chatbot,
            history,
            temperature,
            top_p,
            top_k,
            max_new_tokens,
            repetition_penalty,
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    stop_btn.click(
        lambda: (
            submit_btn.update(visible=True),
            stop_btn.update(visible=True),
        ),
        inputs=None,
        outputs=[submit_btn, stop_btn],
        cancels=[submit_event],
        queue=False,
    )

    clear_history.click(clear_chat, outputs=[chatbot, history])

    # FIX: the original passed 8 inputs ([user_message, chatbot, history,
    # temperature, top_p, top_k, max_new_tokens, repetition_penalty]) into
    # save(history, temperature, top_p, top_k, max_new_tokens,
    # repetition_penalty, max_memory) — one argument too many (TypeError at
    # click time) and with `user_message` wrongly bound to `history`.
    # Pass exactly the arguments save() accepts, in its declared order.
    save_btn.click(
        save,
        inputs=[history, temperature, top_p, top_k, max_new_tokens, repetition_penalty, max_memory],
        outputs=None,
    )

    input_components_except_states = [user_message, chatbot, history, temperature, top_p, top_k, max_new_tokens,
                                      repetition_penalty]

    # Client-side reset: the _js payload is built once at definition time from
    # each component's cleared_value (None when absent).
    reset_btn.click(
        None,
        [],
        (input_components_except_states + [input_component_column]),  # type: ignore
        _js=f"""() => {json.dumps([getattr(component, "cleared_value", None) for component in input_components_except_states]
            + ([gr.Column.update(visible=True)])
            + ([])
            )}
        """,
    )

demo.queue(concurrency_count=16).launch(debug=True, share=True)
324
+
325
+ # with gr.Row():
326
+ # with gr.Box():
327
+ # output = gr.Markdown()
328
+ # chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")
329
+ #
330
+ # with gr.Row():
331
+ # with gr.Column(scale=3):
332
+ # user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
333
+ # with gr.Row():
334
+ # send_button = gr.Button("Send", elem_id="send-btn", visible=True)
335
+ #
336
+ # clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)
337
+ #
338
+ # with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
339
+ # temperature = gr.Slider(
340
+ # label="Temperature",
341
+ # value=0.7,
342
+ # minimum=0.0,
343
+ # maximum=1.0,
344
+ # step=0.1,
345
+ # interactive=True,
346
+ # info="Higher values produce more diverse outputs",
347
+ # )
348
+ # top_p = gr.Slider(
349
+ # label="Top-p (nucleus sampling)",
350
+ # value=0.9,
351
+ # minimum=0.0,
352
+ # maximum=1,
353
+ # step=0.05,
354
+ # interactive=True,
355
+ # info="Higher values sample more low-probability tokens",
356
+ # )
357
+ # max_new_tokens = gr.Slider(
358
+ # label="Max new tokens",
359
+ # value=1024,
360
+ # minimum=0,
361
+ # maximum=2048,
362
+ # step=4,
363
+ # interactive=True,
364
+ # info="The maximum numbers of new tokens",
365
+ # )
366
+ # repetition_penalty = gr.Slider(
367
+ # label="Repetition Penalty",
368
+ # value=1.2,
369
+ # minimum=0.0,
370
+ # maximum=10,
371
+ # step=0.1,
372
+ # interactive=True,
373
+ # info="The parameter for repetition penalty. 1.0 means no penalty.",
374
+ # )
375
+ # with gr.Row():
376
+ # gr.Examples(
377
+ # examples=examples,
378
+ # inputs=[user_message],
379
+ # cache_examples=False,
380
+ # fn=process_example,
381
+ # outputs=[output],
382
+ # )
383
+ #
384
+ # history = gr.State([])
385
+ # last_user_message = gr.State("")
386
+ #
387
+ # user_message.submit(
388
+ # generate,
389
+ # inputs=[
390
+ # user_message,
391
+ # chatbot,
392
+ # history,
393
+ # temperature,
394
+ # top_p,
395
+ # max_new_tokens,
396
+ # repetition_penalty,
397
+ # ],
398
+ # outputs=[chatbot, history, last_user_message, user_message],
399
+ # )
400
+ #
401
+ # send_button.click(
402
+ # generate,
403
+ # inputs=[
404
+ # user_message,
405
+ # chatbot,
406
+ # history,
407
+ # temperature,
408
+ # top_p,
409
+ # max_new_tokens,
410
+ # repetition_penalty,
411
+ # ],
412
+ # outputs=[chatbot, history, last_user_message, user_message],
413
+ # )
414
+ #
415
+ # clear_chat_button.click(clear_chat, outputs=[chatbot, history])
416
+
# FIX: removed a second, duplicate `demo.queue(concurrency_count=16).launch(...)`
# call — the app is already queued and launched immediately after the
# gr.Blocks definition above; launching twice would block and then error.