Hanze Dong committed
Commit 5516bfe · 1 Parent(s): 15adf4e
app.py CHANGED
@@ -1,8 +1,224 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
+"""A simple shell chatbot implemented with lmflow APIs.
+"""
+import logging
+import json
+import sys
+import warnings
 import gradio as gr
+from dataclasses import dataclass, field
+from transformers import HfArgumentParser
+from typing import Optional
 
-def greet(name):
-    return "Hello " + name + "!!"
+from lmflow.datasets.dataset import Dataset
+from lmflow.pipeline.auto_pipeline import AutoPipeline
+from lmflow.models.auto_model import AutoModel
+from lmflow.args import ModelArguments, DatasetArguments, AutoArguments
 
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
+# Maximum number of alternating question/answer boxes rendered in the UI.
+MAX_BOXES = 20
 
+logging.disable(logging.ERROR)
+warnings.filterwarnings("ignore")
+
+title = """
+<h1 align="center">LMFlow-CHAT</h1>
+<link rel="stylesheet" href="/path/to/styles/default.min.css">
+<script src="/path/to/highlight.min.js"></script>
+<script>hljs.highlightAll();</script>
+
+<img src="https://optimalscale.github.io/LMFlow/_static/logo.png" alt="LMFlow" style="width: 30%; min-width: 60px; display: block; margin: auto; background-color: transparent;">
+
+<p>LMFlow is an extensible, convenient, and efficient toolbox for finetuning large machine learning models, designed to be user-friendly, speedy and reliable, and accessible to the entire community.</p>
+
+<p>We have thoroughly tested this toolkit and are pleased to make it available on <a class="reference external" href="https://github.com/OptimalScale/LMFlow">GitHub</a>.</p>
+"""
+css = """
+#user {
+    float: right;
+    position: relative;
+    right: 5px;
+    width: auto;
+    min-height: 32px;
+    max-width: 60%;
+    line-height: 32px;
+    padding: 2px 8px;
+    font-size: 14px;
+    background: #9DC284;
+    border-radius: 5px;
+    margin: 10px 0px;
+}
+
+#chatbot {
+    float: left;
+    position: relative;
+    right: 5px;
+    width: auto;
+    min-height: 32px;
+    max-width: 60%;
+    line-height: 32px;
+    padding: 2px 8px;
+    font-size: 14px;
+    background: #7BA7D7;
+    border-radius: 5px;
+    margin: 10px 0px;
+}
+"""
+
+
+@dataclass
+class ChatbotArguments:
+    prompt_structure: Optional[str] = field(
+        default="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: {input_text}###Assistant:",
+        metadata={
+            "help": "prompt structure given user's input text"
+        },
+    )
+    end_string: Optional[str] = field(
+        default="#",
+        metadata={
+            "help": "end string mark of the chatbot's output"
+        },
+    )
+    max_new_tokens: Optional[int] = field(
+        default=1000,
+        metadata={
+            "help": "maximum number of generated tokens"
+        },
+    )
+    temperature: Optional[float] = field(
+        default=0.7,
+        metadata={
+            "help": "the higher this value, the more random the model output"
+        },
+    )
+
+
+def main():
+    pipeline_name = "inferencer"
+    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)
+
+    parser = HfArgumentParser((
+        ModelArguments,
+        PipelineArguments,
+        ChatbotArguments,
+    ))
+    model_args, pipeline_args, chatbot_args = (
+        parser.parse_args_into_dataclasses()
+    )
+    model_args.model_name_or_path = "pinkmanlove/llama-7b-hf"
+    model_args.lora_model_path = "./robin-7b"
+
+    with open("configs/ds_config_chatbot.json", "r") as f:
+        ds_config = json.load(f)
+
+    model = AutoModel.get_model(
+        model_args,
+        tune_strategy='none',
+        ds_config=ds_config,
+        device=pipeline_args.device,
+    )
+
+    # We don't need input data; we will read interactively from stdin.
+    data_args = DatasetArguments(dataset_path=None)
+    dataset = Dataset(data_args)
+
+    inferencer = AutoPipeline.get_pipeline(
+        pipeline_name=pipeline_name,
+        model_args=model_args,
+        data_args=data_args,
+        pipeline_args=pipeline_args,
+    )
+
+    # Chats
+    model_name = model_args.model_name_or_path
+    if model_args.lora_model_path is not None:
+        model_name += f" + {model_args.lora_model_path}"
+
+    # context = (
+    #     "You are a helpful assistant who follows the given instructions"
+    #     " unconditionally."
+    # )
+
+    end_string = chatbot_args.end_string
+    prompt_structure = chatbot_args.prompt_structure
+
+    # Number of tokens generated per streaming step.
+    token_per_step = 4
+
+    def chat_stream(context, query: str, history=None, **kwargs):
+        if history is None:
+            history = []
+
+        print_index = 0
+        context += prompt_structure.format(input_text=query)
+        # Keep only the most recent characters that fit the model's window.
+        context = context[-model.get_max_length():]
+        input_dataset = dataset.from_dict({
+            "type": "text_only",
+            "instances": [{"text": context}]
+        })
+        for response, flag_break in inferencer.stream_inference(
+            context=context,
+            model=model,
+            max_new_tokens=chatbot_args.max_new_tokens,
+            token_per_step=token_per_step,
+            temperature=chatbot_args.temperature,
+            end_string=end_string,
+            input_dataset=input_dataset,
+        ):
+            # Yield only the newly generated suffix at each step.
+            delta = response[print_index:]
+            seq = response
+            print_index = len(response)
+
+            yield delta, history + [(query, seq)]
+            if flag_break:
+                context += response + "\n"
+                break
+
+    def predict(input, history=None):
+        # Reset the shared conversation context for each new prediction.
+        global context
+        context = ""
+
+        if history is None:
+            history = []
+        for response, history in chat_stream(context, input, history):
+            updates = []
+            for query, response in history:
+                updates.append(gr.update(visible=True, value="" + query))
+                updates.append(gr.update(visible=True, value="" + response))
+            if len(updates) < MAX_BOXES:
+                updates = updates + [gr.Textbox.update(visible=False)] * (MAX_BOXES - len(updates))
+            yield [history] + updates
+
+    with gr.Blocks(css=css) as demo:
+        gr.HTML(title)
+        state = gr.State([])
+        text_boxes = []
+        for i in range(MAX_BOXES):
+            if i % 2 == 0:
+                text_boxes.append(gr.Markdown(visible=False, label="Q:", elem_id="user"))
+            else:
+                text_boxes.append(gr.Markdown(visible=False, label="A:", elem_id="chatbot"))
+
+        txt = gr.Textbox(
+            show_label=False,
+            placeholder="Enter text and press send.",
+        )
+        button = gr.Button("Send")
+
+        button.click(predict, [txt, state], [state] + text_boxes)
+        demo.queue().launch()
+
+
+if __name__ == "__main__":
+    main()
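Two details of the new app.py are worth a note. Since HfArgumentParser builds the CLI from the dataclasses, each ChatbotArguments field becomes a flag when launching the script (e.g. --temperature 0.5, --max_new_tokens 512), though model_name_or_path and lora_model_path are overwritten after parsing. The delta bookkeeping in chat_stream() is the core of the streaming UI; below is a minimal, self-contained sketch of that logic (not part of the commit, with fake_stream_inference as a stand-in for inferencer.stream_inference):

def fake_stream_inference():
    # Stand-in for inferencer.stream_inference(): yields the full text
    # generated so far plus a "done" flag, a few characters at a time.
    text = "Hello! I am an LMFlow chatbot."
    for end in range(4, len(text) + 4, 4):
        partial = text[:min(end, len(text))]
        yield partial, len(partial) == len(text)

print_index = 0
for response, flag_break in fake_stream_inference():
    delta = response[print_index:]  # only the newly generated suffix
    print_index = len(response)
    print(delta, end="", flush=True)
    if flag_break:
        break
print()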
configs/ds_config_chatbot.json ADDED
@@ -0,0 +1,17 @@
+{
+    "fp16": {
+        "enabled": false
+    },
+    "bf16": {
+        "enabled": true
+    },
+    "comms_logger": {
+        "enabled": false,
+        "verbose": false,
+        "prof_all": false,
+        "debug": false
+    },
+    "steps_per_print": 20000000000000000,
+    "train_micro_batch_size_per_gpu": 1,
+    "wall_clock_breakdown": false
+}
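This config runs the chat model in bf16 with fp16 off, disables DeepSpeed's communication logging, and sets steps_per_print to an astronomically large value so periodic step logging effectively never fires. It is plain JSON; app.py loads it with the standard library and hands it to AutoModel.get_model() as ds_config, as this minimal sketch (not part of the commit) shows:

import json

with open("configs/ds_config_chatbot.json") as f:
    ds_config = json.load(f)

# bf16 inference is on and fp16 is off, matching the file above.
assert ds_config["bf16"]["enabled"] is True
assert ds_config["fp16"]["enabled"] is False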
configs/ds_config_zero2.json ADDED
@@ -0,0 +1,45 @@
+{
+    "fp16": {
+        "enabled": "auto",
+        "loss_scale": 0,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 16,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
+
+    "bf16": {
+        "enabled": "auto"
+    },
+
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "betas": "auto",
+            "eps": "auto",
+            "weight_decay": "auto"
+        }
+    },
+
+    "zero_optimization": {
+        "stage": 2,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "allgather_partitions": true,
+        "allgather_bucket_size": 2e8,
+        "overlap_comm": true,
+        "reduce_scatter": true,
+        "reduce_bucket_size": 2e8,
+        "contiguous_gradients": true
+    },
+
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 2000,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "wall_clock_breakdown": false
+}
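The "auto" placeholders are not valid raw DeepSpeed values; they are resolved by the Hugging Face Trainer integration, which fills in the learning rate, batch sizes, and precision flags from TrainingArguments when the config is passed via its deepspeed argument. A hedged sketch of that assumed usage (the output_dir and hyperparameters below are placeholders, not from this commit):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",               # placeholder
    per_device_train_batch_size=1,  # fills "train_micro_batch_size_per_gpu"
    learning_rate=2e-5,             # fills the optimizer "lr"
    bf16=True,                      # resolves the bf16 "enabled": "auto"
    deepspeed="configs/ds_config_zero2.json",
)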
configs/ds_config_zero3.json ADDED
@@ -0,0 +1,52 @@
+{
+    "fp16": {
+        "enabled": "auto",
+        "loss_scale": 0,
+        "loss_scale_window": 1000,
+        "initial_scale_power": 16,
+        "hysteresis": 2,
+        "min_loss_scale": 1
+    },
+
+    "bf16": {
+        "enabled": "auto"
+    },
+
+    "optimizer": {
+        "type": "AdamW",
+        "params": {
+            "lr": "auto",
+            "betas": "auto",
+            "eps": "auto",
+            "weight_decay": "auto"
+        }
+    },
+
+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "offload_param": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "sub_group_size": 1e9,
+        "reduce_bucket_size": "auto",
+        "stage3_prefetch_bucket_size": "auto",
+        "stage3_param_persistence_threshold": "auto",
+        "stage3_max_live_parameters": 1e9,
+        "stage3_max_reuse_distance": 1e9,
+        "stage3_gather_16bit_weights_on_model_save": true
+    },
+
+    "gradient_accumulation_steps": "auto",
+    "gradient_clipping": "auto",
+    "steps_per_print": 2000,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "wall_clock_breakdown": false
+}
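Compared with the ZeRO-2 config, this one additionally shards and CPU-offloads the parameters themselves (offload_param and the stage3_* knobs). Because parameters then live scattered across ranks, "stage3_gather_16bit_weights_on_model_save": true matters: it tells DeepSpeed to reassemble a consolidated 16-bit state_dict at save time, so the saved checkpoint loads without DeepSpeed. A hedged continuation of the Trainer sketch above (model and train_dataset are placeholders, hence commented out):

from transformers import Trainer, TrainingArguments

args = TrainingArguments(output_dir="out", deepspeed="configs/ds_config_zero3.json")
# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()
# trainer.save_model("out/final")  # gathers the sharded 16-bit weights on rank 0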
configs/ds_config_zero3_for_eval.json ADDED
@@ -0,0 +1,29 @@
+{
+    "bf16": {
+        "enabled": true
+    },
+    "zero_optimization": {
+        "stage": 3,
+        "offload_optimizer": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "offload_param": {
+            "device": "cpu",
+            "pin_memory": true
+        },
+        "overlap_comm": true,
+        "contiguous_gradients": true,
+        "sub_group_size": 1e9,
+        "reduce_bucket_size": "auto",
+        "stage3_prefetch_bucket_size": "auto",
+        "stage3_param_persistence_threshold": "auto",
+        "stage3_max_live_parameters": 1e9,
+        "stage3_max_reuse_distance": 1e9,
+        "stage3_gather_16bit_weights_on_model_save": true
+    },
+
+    "steps_per_print": 2000,
+    "train_micro_batch_size_per_gpu": 1,
+    "wall_clock_breakdown": false
+}
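This trimmed ZeRO-3 config drops the optimizer and scheduler sections entirely; it only describes how to shard and offload parameters for forward passes, which is all evaluation needs. train_micro_batch_size_per_gpu is presumably kept just to satisfy DeepSpeed's config validation. A hedged sketch of eval-only initialization (the model is a placeholder module; this call pattern is standard deepspeed.initialize usage, not something shown in this commit):

import json

import deepspeed
import torch

with open("configs/ds_config_zero3_for_eval.json") as f:
    ds_config = json.load(f)

model = torch.nn.Linear(8, 8)  # placeholder for a real HF model
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
engine.eval()
# outputs = engine(inputs)  # forward passes stream offloaded params to the GPU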