xiangpeng.wxp committed on
Commit b31a023
1 parent: 1defe03

add demo scripts

Files changed (2)
  1. polylm_cli_demo.py +94 -0
  2. polylm_web_demo_gradio.py +170 -0
polylm_cli_demo.py ADDED
@@ -0,0 +1,94 @@
import argparse
import os
import platform
import warnings

import re
pattern = re.compile("[\n]+")

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download

from transformers.generation.utils import logger
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="DAMO-NLP-MT/polylm-multialpaca-13b",
                    choices=["DAMO-NLP-MT/polylm-multialpaca-13b"], type=str)
parser.add_argument("--multi_round", action="store_true",
                    help="Turn on multi-round interaction.")
parser.add_argument("--gpu", default="0", type=str)
args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
num_gpus = len(args.gpu.split(","))

if args.model_name in ["DAMO-NLP-MT/polylm-multialpaca-13b-int8", "DAMO-NLP-MT/polylm-multialpaca-13b-int4"] and num_gpus > 1:
    raise ValueError("Quantized models do not support model parallelism. Please run on a single GPU (e.g., --gpu 0).")

logger.setLevel("ERROR")
warnings.filterwarnings("ignore")

# Use a local path if one is given; otherwise download the model from the Hugging Face Hub.
model_path = args.model_name
if not os.path.exists(args.model_name):
    model_path = snapshot_download(args.model_name)

config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

if num_gpus > 1:
    print("Waiting for all devices to be ready, it may take a few minutes...")
    # Build an empty-weight skeleton, then shard the checkpoint across the visible GPUs.
    with init_empty_weights():
        raw_model = AutoModelForCausalLM.from_config(config)
    raw_model.tie_weights()
    model = load_checkpoint_and_dispatch(
        raw_model, model_path, device_map="auto", no_split_module_classes=["GPT2Block"]
    )
else:
    print("Loading model files, it may take a few minutes...")
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto").cuda()

def clear():
    os.system('cls' if platform.system() == 'Windows' else 'clear')

def main():
    # Welcome message (Chinese): "Welcome to the PolyLM multilingual AI assistant! Type a message
    # to chat. Type clear to clear the conversation history, or stop to end the conversation."
    print("欢迎使用 PolyLM 多语言人工智能助手!输入内容即可进行对话。输入 clear 以清空对话历史,输入 stop 以终止对话。")
    prompt = ""
    while True:
        query = input()
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            if args.multi_round:
                prompt = ""
            clear()
            continue

        # Collapse consecutive newlines and append the query in the multialpaca prompt format.
        text = query.strip()
        text = re.sub(pattern, "\n", text)
        if args.multi_round:
            prompt += f"{text}\n\n"
        else:
            prompt = f"{text}\n\n"

        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            outputs = model.generate(
                inputs.input_ids.cuda(),
                attention_mask=inputs.attention_mask.cuda(),
                max_length=1024,
                do_sample=True,
                top_p=0.8,
                temperature=0.7,
                repetition_penalty=1.02,
                num_return_sequences=1,
                eos_token_id=2,
                early_stopping=True)
        # Decode only the newly generated tokens (everything after the prompt).
        response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
        if args.multi_round:
            prompt += f"{response}\n"
        print(f">>> {response}")

if __name__ == "__main__":
    main()
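
For reference, a minimal single-turn sketch of the generation setup that polylm_cli_demo.py wraps in its interactive loop. The model name, prompt format ("instruction" followed by a blank line), and sampling parameters are taken from the script above; the example instruction string and the single-process loading path are illustrative, not part of this commit.

# Illustrative single-turn usage of the demo's generation setup (not part of this commit).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "DAMO-NLP-MT/polylm-multialpaca-13b"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# The multialpaca demos format a single turn as "<instruction>\n\n".
prompt = "Write a short greeting in French.\n\n"  # example instruction, chosen for illustration
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_length=1024,
        do_sample=True,
        top_p=0.8,
        temperature=0.7,
        repetition_penalty=1.02,
        eos_token_id=2,
    )
# Keep only the tokens generated after the prompt.
print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
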
polylm_web_demo_gradio.py ADDED
@@ -0,0 +1,170 @@
import argparse
import os
import warnings
import mdtex2html
import gradio as gr

import re
pattern = re.compile("[\n]+")

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download

from transformers.generation.utils import logger
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="DAMO-NLP-MT/polylm-multialpaca-13b",
                    choices=["DAMO-NLP-MT/polylm-multialpaca-13b"], type=str)
parser.add_argument("--gpu", default="0", type=str)
args = parser.parse_args()

os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
num_gpus = len(args.gpu.split(","))

if ('int8' in args.model_name or 'int4' in args.model_name) and num_gpus > 1:
    raise ValueError("Quantized models do not support model parallelism. Please run on a single GPU (e.g., --gpu 0).")

logger.setLevel("ERROR")
warnings.filterwarnings("ignore")

# Use a local path if one is given; otherwise download the model from the Hugging Face Hub.
model_path = args.model_name
if not os.path.exists(args.model_name):
    model_path = snapshot_download(args.model_name)

config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

if num_gpus > 1:
    print("Waiting for all devices to be ready, it may take a few minutes...")
    # Build an empty-weight skeleton, then shard the checkpoint across the visible GPUs.
    with init_empty_weights():
        raw_model = AutoModelForCausalLM.from_config(config)
    raw_model.tie_weights()
    model = load_checkpoint_and_dispatch(
        raw_model, model_path, device_map="auto", no_split_module_classes=["GPT2Block"]
    )
else:
    print("Loading model files, it may take a few minutes...")
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).cuda()


def postprocess(self, y):
    # Render both user messages and model responses as HTML via mdtex2html.
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """Copied from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


def predict(input, chatbot, max_length, top_p, temperature, history):
    query = input
    query = query.strip()
    query = re.sub(pattern, "\n", query)

    chatbot.append((query, ""))
    # Rebuild the multi-round prompt from the full conversation history.
    prompt = ""
    for i, (old_query, response) in enumerate(history):
        prompt += f"{old_query}\n\n" + f"{response}\n"
    prompt += f"{query}\n\n"
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids.cuda(),
            attention_mask=inputs.attention_mask.cuda(),
            max_length=max_length,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=1.02,
            num_return_sequences=1,
            eos_token_id=2,
            early_stopping=True)
    # Decode only the newly generated tokens (everything after the prompt).
    response = tokenizer.decode(
        outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

    chatbot[-1] = (query, parse_text(response))
    history = history + [(query, response)]
    print("==========================================================================")
    print(f"chatbot is {chatbot}")
    print(f"history is {history}")
    print("==========================================================================")
    return chatbot, history


def reset_user_input():
    return gr.update(value='')


def reset_state():
    return [], []


with gr.Blocks() as demo:
    # Page title (Chinese): "Welcome to the PolyLM multilingual AI assistant!"
    gr.HTML("""<h1 align="center">欢迎使用 PolyLM 多语言人工智能助手!</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(
                0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.8, step=0.01,
                              label="Top P", interactive=True)
            temperature = gr.Slider(
                0, 1, value=0.7, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])  # list of (message, bot_message) pairs

    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

demo.queue().launch(share=False, inbrowser=True)
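
Usage note (not documented in this commit, inferred from the argparse definitions and imports above): once the dependencies shown in the imports are installed (torch, transformers, accelerate, huggingface_hub, plus gradio and mdtex2html for the web demo), the scripts are presumably launched directly, e.g. python polylm_cli_demo.py --multi_round --gpu 0 for the terminal demo, or python polylm_web_demo_gradio.py --gpu 0 to start the Gradio interface locally.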