Moses25 committed on
Commit
7fc6987
1 Parent(s): 3936d8f
Files changed (1)
  1. gradio_demo.py +613 -0
gradio_demo.py ADDED
@@ -0,0 +1,613 @@
import torch
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    StoppingCriteria,
    BitsAndBytesConfig
)
import gradio as gr
import argparse
import os
from queue import Queue
from threading import Thread
import traceback
import gc
import json
import requests
from typing import Iterable, List
import subprocess
import re

DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Help as much as you can."""

TEMPLATE_WITH_SYSTEM_PROMPT = (
    "[INST] <<SYS>>\n"
    "{system_prompt}\n"
    "<</SYS>>\n\n"
    "{instruction} [/INST]"
)

TEMPLATE_WITHOUT_SYSTEM_PROMPT = "[INST] {instruction} [/INST]"

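# For reference, a single-turn prompt rendered from the templates above looks like this
# (illustrative example; the user message "Hello!" is made up):
#
#   [INST] <<SYS>>
#   You are a helpful, respectful and honest assistant. Help as much as you can.
#   <</SYS>>
#
#   Hello! [/INST]
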
# Parse command-line arguments
parser = argparse.ArgumentParser()
parser.add_argument(
    '--base_model',
    default=None,
    type=str,
    required=True,
    help='Base model path')
parser.add_argument('--lora_model', default=None, type=str,
                    help="If None, perform inference on the base model")
parser.add_argument(
    '--tokenizer_path',
    default=None,
    type=str,
    help='If None, the lora model path or base model path will be used')
parser.add_argument(
    '--gpus',
    default="0",
    type=str,
    help='If None, cuda:0 will be used. For multi-GPU inference use --gpus=0,1,...')
parser.add_argument('--share', default=True, help='Share the Gradio demo via a public link')
parser.add_argument('--port', default=19324, type=int, help='Port of the Gradio demo')
parser.add_argument(
    '--max_memory',
    default=1024,
    type=int,
    help='Maximum number of input tokens (including the system prompt) to keep. If exceeded, earlier history will be discarded.')
parser.add_argument(
    '--load_in_8bit',
    action='store_true',
    default=False,
    help='Use the 8-bit quantized model')
parser.add_argument(
    '--load_in_4bit',
    action='store_true',
    default=False,
    help='Use the 4-bit quantized model')
parser.add_argument(
    '--only_cpu',
    action='store_true',
    help='Only use CPU for inference')
parser.add_argument(
    '--alpha',
    type=str,
    default="1.0",
    help="The scaling factor of the NTK method; a float or 'auto'.")
parser.add_argument(
    "--use_vllm",
    action='store_true',
    help="Use vLLM as the back-end LLM service.")
parser.add_argument(
    "--post_host",
    type=str,
    default="0.0.0.0",
    help="Host of the vLLM service.")
parser.add_argument(
    "--post_port",
    type=int,
    default=7777,
    help="Port of the vLLM service.")
args = parser.parse_args()
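
# Example launch (illustrative only; the model path is a placeholder for your local checkpoint):
#   python gradio_demo.py --base_model /path/to/Llama2-Moses-7b-chat --gpus 0 --load_in_4bit
# Add --use_vllm to route generation through a local vLLM API server instead of transformers.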

ENABLE_CFG_SAMPLING = True
try:
    from transformers.generation import UnbatchedClassifierFreeGuidanceLogitsProcessor
except ImportError:
    ENABLE_CFG_SAMPLING = False
    print("Install the latest transformers (commit d533465 or later) to enable CFG sampling.")
if args.use_vllm is True:
    print("CFG sampling is disabled when using vLLM.")
    ENABLE_CFG_SAMPLING = False

if args.only_cpu is True:
    args.gpus = ""
    if args.load_in_8bit or args.load_in_4bit:
        raise ValueError("Quantization is unavailable on CPU.")
if args.load_in_8bit and args.load_in_4bit:
    raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments.")
import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch
if not args.only_cpu:
    apply_attention_patch(use_memory_efficient_attention=True)
apply_ntk_scaling_patch(args.alpha)

# Set CUDA devices if available
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus


# The peft library can only be imported after setting the CUDA devices
from peft import PeftModel


# Set up the required components: model and tokenizer

def setup():
    global tokenizer, model, device, share, port, max_memory
    if args.use_vllm:
        # global share, port, max_memory
        max_memory = args.max_memory
        port = args.port
        share = args.share

        if args.lora_model is not None:
            raise ValueError("vLLM currently does not support LoRA; please merge the LoRA weights into the base model.")
        if args.load_in_8bit or args.load_in_4bit:
            raise ValueError("vLLM currently does not support quantization; please use fp16 (the default) or drop --use_vllm.")
        if args.only_cpu:
            raise ValueError("vLLM requires GPUs with compute capability of at least 7.0. To run on CPU only, drop --use_vllm.")

        if args.tokenizer_path is None:
            args.tokenizer_path = args.base_model
        tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)

        print("Starting the vLLM server.")
        cmd = f"python -m vllm.entrypoints.api_server \
            --model={args.base_model} \
            --tokenizer={args.tokenizer_path} \
            --tokenizer-mode=slow \
            --tensor-parallel-size={len(args.gpus.split(','))} \
            --host {args.post_host} \
            --port {args.post_port} \
            &"
        subprocess.check_call(cmd, shell=True)
    else:
        max_memory = args.max_memory
        port = args.port
        share = args.share
        load_type = torch.float16
        if torch.cuda.is_available():
            device = torch.device(0)
        else:
            device = torch.device('cpu')
        if args.tokenizer_path is None:
            args.tokenizer_path = args.base_model
        # if args.lora_model is None:
        #     args.tokenizer_path = args.base_model
        tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
        tokenizer.pad_token_id = 0
        # tokenizer.pad_token = "<>"
        base_model = LlamaForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=load_type,
            low_cpu_mem_usage=True,
            device_map='auto',
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=args.load_in_4bit,
                load_in_8bit=args.load_in_8bit,
                bnb_4bit_compute_dtype=load_type
            )
        )

        model_vocab_size = base_model.get_input_embeddings().weight.size(0)
        tokenizer_vocab_size = len(tokenizer)
        print(f"Vocab of the base model: {model_vocab_size}")
        print(f"Vocab of the tokenizer: {tokenizer_vocab_size}")
        if model_vocab_size != tokenizer_vocab_size:
            print("Resizing model embeddings to fit the tokenizer")
            base_model.resize_token_embeddings(tokenizer_vocab_size)
        if args.lora_model is not None:
            print("Loading the peft model")
            model = PeftModel.from_pretrained(
                base_model,
                args.lora_model,
                torch_dtype=load_type,
                device_map='auto',
            ).half()
        else:
            model = base_model

        if device == torch.device('cpu'):
            model.float()

        model.eval()


# Reset the user input
def reset_user_input():
    return gr.update(value='')


# Reset the state
def reset_state():
    return []


def generate_prompt(instruction, response="", with_system_prompt=True, system_prompt=DEFAULT_SYSTEM_PROMPT):
    if with_system_prompt is True:
        prompt = TEMPLATE_WITH_SYSTEM_PROMPT.format_map({'instruction': instruction, 'system_prompt': system_prompt})
    else:
        prompt = TEMPLATE_WITHOUT_SYSTEM_PROMPT.format_map({'instruction': instruction})
    if len(response) > 0:
        prompt += " " + response
    return prompt


# User interaction function for chat
def user(user_message, history):
    return gr.update(value="", interactive=False), history + \
        [[user_message, None]]


class Stream(StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).

    Adapted from: https://stackoverflow.com/a/9969000
    """
    def __init__(self, func, kwargs=None, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs or {}
        self.stop_now = False

        def _callback(val):
            if self.stop_now:
                raise ValueError
            self.q.put(val)

        def gentask():
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                pass
            except Exception:
                traceback.print_exc()

            clear_torch_cache()
            self.q.put(self.sentinel)
            if self.c_callback:
                self.c_callback(ret)

        self.thread = Thread(target=gentask)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        clear_torch_cache()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop_now = True
        clear_torch_cache()


def clear_torch_cache():
    gc.collect()
    if torch.cuda.device_count() > 0:
        torch.cuda.empty_cache()


def post_http_request(prompt: str,
                      api_url: str,
                      n: int = 1,
                      top_p: float = 0.9,
                      top_k: int = 40,
                      temperature: float = 0.2,
                      max_tokens: int = 1024,
                      presence_penalty: float = 1.0,
                      use_beam_search: bool = False,
                      stream: bool = False) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": n,
        "top_p": 1 if use_beam_search else top_p,
        "top_k": -1 if use_beam_search else top_k,
        "temperature": 0 if use_beam_search else temperature,
        "max_tokens": max_tokens,
        "use_beam_search": use_beam_search,
        "best_of": 5 if use_beam_search else n,
        "presence_penalty": presence_penalty,
        "stream": stream,
    }
    print(pload)

    response = requests.post(api_url, headers=headers, json=pload, stream=True)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    for chunk in response.iter_lines(chunk_size=8192,
                                     decode_unicode=False,
                                     delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


# Perform prediction based on the user input and history
@torch.no_grad()
def predict(
    history,
    system_prompt,
    negative_prompt,
    max_new_tokens=1024,
    top_p=0.89,
    temperature=0.85,
    top_k=40,
    do_sample=True,
    repetition_penalty=1.2,
    guidance_scale=1.0,
    presence_penalty=0.0,
):
    if len(system_prompt) == 0:
        system_prompt = DEFAULT_SYSTEM_PROMPT
    while True:
        print("len(history):", len(history))
        print("history: ", history)
        history[-1][1] = ""
        if len(history) == 1:
            input = history[0][0]
            prompt = generate_prompt(input, response="", with_system_prompt=True, system_prompt=system_prompt)
            print(f"prompt:{prompt}")
        else:
            input = history[0][0]
            response = history[0][1]
            prompt = generate_prompt(input, response=response, with_system_prompt=True, system_prompt=system_prompt) + '</s>'
            for hist in history[1:-1]:
                input = hist[0]
                response = hist[1]
                prompt = prompt + '<s>' + generate_prompt(input, response=response, with_system_prompt=False) + '</s>'
            input = history[-1][0]
            prompt = prompt + '<s>' + generate_prompt(input, response="", with_system_prompt=False)
            print(f"prompt1:{prompt}")
        input_length = len(tokenizer.encode(prompt, add_special_tokens=True))
        print(f"Input length: {input_length}")
        if input_length > max_memory and len(history) > 1:
            print(f"The input length ({input_length}) exceeds the max memory ({max_memory}). The earlier history will be discarded.")
            history = history[1:]
            print("history: ", history)
        else:
            break

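    # At this point `prompt` holds the whole conversation in Llama-2 chat format, roughly
    # (illustrative layout; q*/a* stand for user and assistant turns):
    #   [INST] <<SYS>> ... <</SYS>> q1 [/INST] a1</s><s>[INST] q2 [/INST] a2</s><s>[INST] q3 [/INST]
    # The while-loop above has already dropped the oldest turns whenever the tokenized
    # length exceeded max_memory.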
    if args.use_vllm:
        generate_params = {
            'max_tokens': max_new_tokens,
            'top_p': top_p,
            'temperature': temperature,
            'top_k': top_k,
            "use_beam_search": not do_sample,
            'presence_penalty': presence_penalty,
        }

        api_url = f"http://{args.post_host}:{args.post_port}/generate"

        response = post_http_request(prompt, api_url, **generate_params, stream=True)

        for h in get_streaming_response(response):
            for line in h:
                line = line.replace(prompt, '')
                history[-1][1] = line
                yield history

    else:
        negative_text = None
        if len(negative_prompt) != 0:
            negative_text = re.sub(r"<<SYS>>\n(.*)\n<</SYS>>", f"<<SYS>>\n{negative_prompt}\n<</SYS>>", prompt)
        inputs = tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        if negative_text is None:
            negative_prompt_ids = None
            negative_prompt_attention_mask = None
        else:
            negative_inputs = tokenizer(negative_text, return_tensors="pt")
            negative_prompt_ids = negative_inputs["input_ids"].to(device)
            negative_prompt_attention_mask = negative_inputs["attention_mask"].to(device)
        generate_params = {
            'input_ids': input_ids,
            'max_new_tokens': max_new_tokens,
            'top_p': top_p,
            'temperature': temperature,
            'top_k': top_k,
            'do_sample': do_sample,
            'repetition_penalty': repetition_penalty,
        }
        if ENABLE_CFG_SAMPLING is True:
            generate_params['guidance_scale'] = guidance_scale
            generate_params['negative_prompt_ids'] = negative_prompt_ids
            generate_params['negative_prompt_attention_mask'] = negative_prompt_attention_mask

        def generate_with_callback(callback=None, **kwargs):
            if 'stopping_criteria' in kwargs:
                kwargs['stopping_criteria'].append(Stream(callback_func=callback))
            else:
                kwargs['stopping_criteria'] = [Stream(callback_func=callback)]
            clear_torch_cache()
            with torch.no_grad():
                model.generate(**kwargs)

        def generate_with_streaming(**kwargs):
            return Iteratorize(generate_with_callback, kwargs, callback=None)

        with generate_with_streaming(**generate_params) as generator:
            for output in generator:
                next_token_ids = output[len(input_ids[0]):]
                if next_token_ids[0] in [tokenizer.eos_token_id, 0]:
                    break
                new_tokens = tokenizer.decode(
                    next_token_ids, skip_special_tokens=True)
                if isinstance(tokenizer, LlamaTokenizer) and len(next_token_ids) > 0:
                    if tokenizer.convert_ids_to_tokens(int(next_token_ids[0])).startswith('▁'):
                        new_tokens = ' ' + new_tokens

                history[-1][1] = new_tokens
                yield history
                if len(next_token_ids) >= max_new_tokens:
                    break


# Call the setup function to initialize the components
setup()


# Create the Gradio interface
with gr.Blocks(
        theme=gr.themes.Soft(),
        css=".disclaimer {font-variant-caps: all-small-caps;}") as demo:
    github_banner_path = 'https://raw.githubusercontent.com/moseshu/llama2-chat/main/llama2.jpg'
    gr.HTML(f'<p align="center"><a href="https://huggingface.co/Moses25/Llama2-Moses-7b-chat"><img src={github_banner_path} width="200" height="80"/>Llama2-Moses-7b</a></p>')
    chatbot = gr.Chatbot().style(height=300)
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=3):
                system_prompt_input = gr.Textbox(
                    show_label=True,
                    label="system prompt (edits only take effect before the conversation starts or after clearing the history; changes made mid-conversation are ignored)",
                    placeholder=DEFAULT_SYSTEM_PROMPT,
                    lines=1).style(
                    container=True)
                negative_prompt_input = gr.Textbox(
                    show_label=True,
                    label="negative prompt (edits only take effect before the conversation starts or after clearing the history; changes made mid-conversation are ignored)",
                    placeholder="optional",
                    lines=1,
                    visible=ENABLE_CFG_SAMPLING).style(
                    container=True)
            with gr.Column(scale=10):
                user_input = gr.Textbox(
                    show_label=True,
                    label="ChatBox",
                    text_align='right',
                    placeholder="Press Shift + Enter to send a message...",
                    lines=10).style(
                    container=True)
            with gr.Column(min_width=24, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_new_token = gr.Slider(
                0,
                4096,
                value=1024,
                step=1.0,
                label="Maximum New Token Length",
                interactive=True)
            top_p = gr.Slider(0, 1, value=0.9, step=0.01,
                              label="Top P", interactive=True)
            temperature = gr.Slider(
                0,
                1,
                value=0.7,
                step=0.01,
                label="Temperature",
                interactive=True)
            top_k = gr.Slider(1, 40, value=40, step=1,
                              label="Top K", interactive=True)
            do_sample = gr.Checkbox(
                value=True,
                label="Do Sample",
                info="use random sample strategy",
                interactive=True)
            repetition_penalty = gr.Slider(
                1.0,
                3.0,
                value=1.1,
                step=0.1,
                label="Repetition Penalty",
                interactive=True,
                visible=False if args.use_vllm else True)
            guidance_scale = gr.Slider(
                1.0,
                3.0,
                value=1.0,
                step=0.1,
                label="Guidance Scale",
                interactive=True,
                visible=ENABLE_CFG_SAMPLING)
            presence_penalty = gr.Slider(
                -2.0,
                2.0,
                value=1.0,
                step=0.1,
                label="Presence Penalty",
                interactive=True,
                visible=True if args.use_vllm else False)

    params = [user_input, chatbot]
    predict_params = [
        chatbot,
        system_prompt_input,
        negative_prompt_input,
        max_new_token,
        top_p,
        temperature,
        top_k,
        do_sample,
        repetition_penalty,
        guidance_scale,
        presence_penalty]
    with gr.Row():
        gr.Markdown(
            "Disclaimer: this model may produce outputs that are factually incorrect and should not be relied on for factual information. It was trained on a variety of public datasets as well as some Dewu product information. Despite extensive data cleaning, its outputs may still contain problems.",
            elem_classes=["disclaimer"],
        )
    submitBtn.click(
        user,
        params,
        params,
        queue=False).then(
        predict,
        predict_params,
        chatbot).then(
        lambda: gr.update(
            interactive=True),
        None,
        [user_input],
        queue=False)

    user_input.submit(
        user,
        params,
        params,
        queue=False).then(
        predict,
        predict_params,
        chatbot).then(
        lambda: gr.update(
            interactive=True),
        None,
        [user_input],
        queue=False)

    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot], show_progress=True)


# Launch the Gradio interface
demo.queue().launch(
    share=share,
    inbrowser=True,
    server_name='0.0.0.0',
    server_port=port)