wmpscc committed
Commit 1ca7e27 · 1 Parent(s): 0a9d70b

Update app.py

Files changed (1):
  1. app.py +37 -16
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+
 os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
 
 import torch
@@ -7,28 +8,48 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 def init_model():
-    model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="cuda:0", torch_dtype=torch.float16, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", device_map="cuda:0",
+                                                 torch_dtype=torch.bfloat16, trust_remote_code=True)
     tokenizer = AutoTokenizer.from_pretrained("Linly-AI/Chinese-LLaMA-2-7B-hf", use_fast=False, trust_remote_code=True)
     return model, tokenizer
-
 
 
-def chat(prompt, top_k, temperature):
-    prompt = f"### Instruction:{prompt.strip()} ### Response:"
-    inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
-    generate_ids = model.generate(inputs.input_ids, do_sample=True, max_new_tokens=2048, top_k=int(top_k), top_p=0.84, temperature=float(temperature), repetition_penalty=1.15, eos_token_id=2, bos_token_id=1, pad_token_id=0)
-    response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    response = response.lstrip(prompt)
-    print('-log: ',prompt, response)
-    return response
+def process(message, history):
+    input_prompt = ""
+    for interaction in history:
+        input_prompt = f"{input_prompt} User: {str(interaction[0]).strip(' ')} Bot: {str(interaction[1]).strip(' ')}"
+    input_prompt = f"{input_prompt} ### Instruction:{message.strip()} ### Response:"
+    inputs = tokenizer(input_prompt, return_tensors="pt").to("cuda:0")
+    try:
+        generate_ids = model.generate(inputs.input_ids, max_new_tokens=2048, do_sample=True, top_k=30, top_p=0.84,
+                                      temperature=1.0, repetition_penalty=1.15, eos_token_id=2, bos_token_id=1,
+                                      pad_token_id=0)
+        response = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        print('log:', response)
+        response = response.split("### Response:")[-1]
+        return response
+    except:
+        return "Error: 会话超长,请重试!"
 
 
 if __name__ == '__main__':
+    examples = ["Python和JavaScript编程语言的主要区别是什么?", "影响消费者行为的主要因素是什么?", "请用pytorch实现一个带ReLU激活函数的全连接层的代码",
+                "请用C++编程语言实现“给你两个字符串haystack和needle,在haystack字符串中找出needle字符串的第一个匹配项的下标(下标从 0 开始)。如果needle不是haystack的一部分,则返回-1。",
+                "为什么有些人选择使用纸质地图或寻求方向,而不是依赖GPS设备或智能手机应用程序?",
+                "应对压力最有效的方法是什么?"]
     model, tokenizer = init_model()
-    demo = gr.Interface(
-        fn=chat,
-        inputs=["text", gr.Slider(1, 60, value=10, step=1), gr.Slider(0.1, 2.0, value=1.0, step=0.1)],
-        outputs="text",
+    demo = gr.ChatInterface(
+        process,
+        chatbot=gr.Chatbot(height=600),
+        textbox=gr.Textbox(placeholder="Input", container=False, scale=7),
+        title="Linly ChatFlow",
+        description="",
+        theme="soft",
+        examples=examples,
+        cache_examples=True,
+        retry_btn="Retry",
+        undo_btn="Delete Previous",
+        clear_btn="Clear",
     )
-    demo.launch()
-
+    demo.queue(concurrency_count=75).launch(share=True, server_name="0.0.0.0", server_port=7862, show_error=True,
+                                            debug=True)
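
Illustrative sketch (not part of the commit): the change swaps the single-turn gr.Interface for gr.ChatInterface, whose callback receives the new message together with the chat history as (user, bot) pairs, and the new process() flattens that history into one prompt string. The snippet below isolates just that prompt assembly; the names build_prompt, user_turn and bot_turn are hypothetical and used only for this illustration.

    def build_prompt(message, history):
        # Concatenate prior (user, bot) turns, then append the current message
        # in the "### Instruction: ... ### Response:" format used by the app.
        input_prompt = ""
        for user_turn, bot_turn in history:
            input_prompt = f"{input_prompt} User: {str(user_turn).strip(' ')} Bot: {str(bot_turn).strip(' ')}"
        return f"{input_prompt} ### Instruction:{message.strip()} ### Response:"

    history = [("你好", "你好!有什么可以帮你?")]
    print(build_prompt("今天天气怎么样?", history))
    # -> " User: 你好 Bot: 你好!有什么可以帮你? ### Instruction:今天天气怎么样? ### Response:"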