军舰 committed on
Commit e3678d1 · 1 Parent(s): 3f8673d

Fine-tuning Text2SQL based on Mistral-7B using LoRA on MLX
README.md CHANGED

---
license: mit
---

## [mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL](https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL)

I have uploaded the fine-tuned model to the HuggingFace Hub, so you can use it directly.

### Installation

```bash
pip install mlx-lm
```

### Generation

```bash
python -m mlx_lm.generate --model mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL \
       --max-tokens 50 \
       --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
```
```
SELECT School FROM Students WHERE Name = 'Wang Junjian'
```
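The model can also be called from Python. A minimal sketch, assuming the `mlx_lm` Python API (`load`/`generate`) installed via `pip install mlx-lm`; `make_prompt` is a hypothetical helper that reproduces the `table:` / `columns:` / `Q:` / `A:` layout the model was trained on:

```python
def make_prompt(table, columns, question):
    """Build a Text2SQL prompt in the layout used during fine-tuning."""
    return (
        f"table: {table}\n"
        f"columns: {', '.join(columns)}\n"
        f"Q: {question}\n"
        "A: "
    )

prompt = make_prompt(
    "students",
    ["Name", "Age", "School", "Grade", "Height", "Weight"],
    "Which school did Wang Junjian come from?",
)

# Requires an Apple-silicon Mac and downloads the model on first use;
# the generate() signature may differ between mlx-lm versions.
# from mlx_lm import load, generate
# model, tokenizer = load("mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL")
# sql = generate(model, tokenizer, prompt=prompt, max_tokens=50)
```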

## [Fine-tuning Text2SQL based on Mistral-7B using LoRA on MLX (Part 1)](https://wangjunjian.com/mlx/lora/2024/01/23/Fine-tuning-Text2SQL-based-on-Mistral-7B-using-LoRA-on-MLX-1.html)

📌 The dataset was not generated in the model's annotation format, so generation could not stop and ran on until the maximum number of tokens was reached.

This time we solve that problem.

## The WikiSQL dataset

- [WikiSQL](https://github.com/salesforce/WikiSQL)
- [sqllama/sqllama-V0](https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb)

### Modifying the script mlx-examples/lora/data/wikisql.py

```py
if __name__ == "__main__":
    # ......
    for dataset, name, size in datasets:
        with open(f"data/{name}.jsonl", "w") as fid:
            for e, t in zip(range(size), dataset):
                """
                The text in the variable t looks like this:
                ------------------------
                <s>table: 1-1058787-1
                columns: Approximate Age, Virtues, Psycho Social Crisis, Significant Relationship, Existential Question [ not in citation given ], Examples
                Q: How many significant relationships list Will as a virtue?
                A: SELECT COUNT Significant Relationship FROM 1-1058787-1 WHERE Virtues = 'Will'</s>
                """
                t = t[3:]  # Drop the leading <s>; the tokenizer adds <s> automatically
                json.dump({"text": t}, fid)
                fid.write("\n")
```

Run the script `data/wikisql.py` to generate the dataset.
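The effect of the `t[3:]` fix can be checked in isolation. A self-contained sketch with a shortened sample (the real script iterates over WikiSQL):

```python
import io
import json

# A shortened sample in the same shape as the t variable above
sample = (
    "<s>table: students\n"
    "columns: Name, Age, School\n"
    "Q: Which school did Wang Junjian come from?\n"
    "A: SELECT School FROM students WHERE Name = 'Wang Junjian'</s>"
)

fid = io.StringIO()  # stands in for data/train.jsonl
t = sample[3:]       # drop the leading <s>; the tokenizer adds it back
json.dump({"text": t}, fid)
fid.write("\n")

record = json.loads(fid.getvalue())
# no <s> prefix, but </s> is kept so the model learns where to stop
```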

### Sample

```
table: 1-10753917-1
columns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat
Q: Which podiums did the alfa romeo team have?
A: SELECT Podiums FROM 1-10753917-1 WHERE Team = 'Alfa Romeo'</s>
```

## Fine-tuning

- Pre-trained model: [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)

### LoRA fine-tuning

```bash
python lora.py --model mistralai/Mistral-7B-v0.1 \
       --train \
       --iters 600
```
```
Total parameters 7243.436M
Trainable parameters 1.704M
python lora.py --model mistralai/Mistral-7B-v0.1 --train --iters 600  50.58s user 214.71s system 21% cpu 20:26.04 total
```

Only about 0.024% of the model's parameters (1.704M out of 7243.436M) are trained.

The 600 iterations of LoRA fine-tuning took 20 minutes 26 seconds and used 46 GB of memory.
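The 1.704M figure is consistent with what the mlx-examples defaults at the time would give: rank-8 LoRA adapters on the q_proj and v_proj projections of the last 16 layers. A back-of-the-envelope check, assuming those defaults:

```python
# Mistral-7B shapes (see config.json): hidden size 4096, 32 heads, 8 KV heads
d_model = 4096
head_dim = d_model // 32       # 128
kv_dim = 8 * head_dim          # 1024 (grouped-query attention)

rank = 8                       # assumed default LoRA rank
lora_layers = 16               # assumed default --lora-layers

# A LoRA pair on a d_in -> d_out linear layer adds rank * (d_in + d_out) params
q_proj = rank * (d_model + d_model)   # 4096 -> 4096 projection
v_proj = rank * (d_model + kv_dim)    # 4096 -> 1024 projection
trainable = lora_layers * (q_proj + v_proj)

print(trainable)  # 1703936, i.e. the reported 1.704M
```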

## Evaluation

Compute the perplexity (PPL) and cross-entropy loss on the test set.

```bash
python lora.py --model mistralai/Mistral-7B-v0.1 \
       --adapter-file adapters.npz \
       --test
```
```
Iter 100: Test loss 1.351, Test ppl 3.862.
Iter 200: Test loss 1.327, Test ppl 3.770.
Iter 300: Test loss 1.353, Test ppl 3.869.
Iter 400: Test loss 1.355, Test ppl 3.875.
Iter 500: Test loss 1.294, Test ppl 3.646.
Iter 600: Test loss 1.351, Test ppl 3.863.
```

| Iter | Test loss | Test ppl |
| :--: | --------: | -------: |
| 100  |     1.351 |    3.862 |
| 200  |     1.327 |    3.770 |
| 300  |     1.353 |    3.869 |
| 400  |     1.355 |    3.875 |
| 500  |     1.294 |    3.646 |
| 600  |     1.351 |    3.863 |

Evaluation used 26 GB of memory.
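The two columns are redundant: perplexity is just the exponential of the cross-entropy loss, which the reported values confirm to within rounding:

```python
import math

test_loss = [1.351, 1.327, 1.353, 1.355, 1.294, 1.351]
test_ppl = [3.862, 3.770, 3.869, 3.875, 3.646, 3.863]

for loss, ppl in zip(test_loss, test_ppl):
    # ppl = exp(loss), up to the rounding printed by lora.py
    assert abs(math.exp(loss) - ppl) < 0.01
```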

## Fuse

```bash
python fuse.py --model mistralai/Mistral-7B-v0.1 \
       --adapter-file adapters.npz \
       --save-path lora_fused_model
```
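Conceptually, what fuse.py does per linear layer is fold the trained low-rank update into the frozen base weight, so inference needs no separate adapter file. A toy sketch of that idea (tiny hand-written matrices; real layers are 4096-wide, and mlx may apply an extra scaling factor):

```python
def matmul(X, Y):
    """Naive matrix multiply for small nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]

def matadd(X, Y):
    return [[a + b for a, b in zip(rx, ry)] for rx, ry in zip(X, Y)]

W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0]]      # frozen base weight (d_out=2, d_in=3)
A = [[0.5, -1.0, 0.0]]     # lora_a (rank r=1, d_in=3), trained
B = [[2.0], [4.0]]         # lora_b (d_out=2, r=1), trained

W_fused = matadd(W, matmul(B, A))   # fuse: W + B @ A

x = [[1.0], [2.0], [3.0]]  # a column input
with_adapter = matadd(matmul(W, x), matmul(matmul(B, A), x))
fused_only = matmul(W_fused, x)
# the fused layer reproduces base + adapter exactly
assert fused_only == with_adapter
```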

## Generation

### What is Wang Junjian's name?

```bash
python -m mlx_lm.generate --model lora_fused_model \
       --max-tokens 50 \
       --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: What is Wang Junjian's name?
A: "
```
```
SELECT Name FROM students WHERE Name = 'Wang Junjian'
```

### How old is Wang Junjian?

```bash
python -m mlx_lm.generate --model lora_fused_model \
       --max-tokens 50 \
       --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: How old is Wang Junjian?
A: "
```
```
SELECT Age FROM Students WHERE Name = 'Wang Junjian'
```

### Which school did Wang Junjian come from?

```bash
python -m mlx_lm.generate --model lora_fused_model \
       --max-tokens 50 \
       --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
```
```
SELECT School FROM Students WHERE Name = 'Wang Junjian'
```

### Query Wang Junjian's name, age, and school information.

```bash
python -m mlx_lm.generate --model lora_fused_model \
       --max-tokens 50 \
       --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query Wang Junjian’s name, age, and school information.
A: "
```
```
SELECT Name, Age, School FROM Students WHERE Name = 'Wang Junjian'
```

### Query all information about Wang Junjian.

```bash
python -m mlx_lm.generate --model lora_fused_model \
       --max-tokens 50 \
       --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query all information about Wang Junjian.
A: "
```
```
SELECT Name FROM students WHERE Name = 'Wang Junjian'
```

The query should have selected all columns; the training data was probably insufficient.

### Count how many students there are in ninth grade.

```bash
python -m mlx_lm.generate --model lora_fused_model \
       --max-tokens 50 \
       --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.
A: "
```
```
SELECT COUNT Name FROM Students WHERE Grade = '9th'
```

### Count how many students there are in ninth grade (the value for ninth grade is 9).

```bash
python -m mlx_lm.generate --model lora_fused_model \
       --max-tokens 50 \
       --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
The value for ninth grade is 9.
Q: Count how many students there are in ninth grade.
A: "
```

```bash
python -m mlx_lm.generate --model lora_fused_model \
       --max-tokens 50 \
       --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade. (The value for ninth grade is 9.)
A: "
```

```
SELECT COUNT Name FROM students WHERE Grade = 9
```

Extra hint information is easy to add, and its placement in the prompt does not matter much.

## Uploading the model

```bash
python -m mlx_lm.convert \
       --mlx-path lora_fused_model/ \
       --quantize \
       --upload-repo mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
```

## References

- [MLX Community](https://huggingface.co/mlx-community)
- [Fine-Tuning with LoRA or QLoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora)
- [Generate Text with LLMs and MLX](https://github.com/ml-explore/mlx-examples/tree/main/llms)
- [Awesome Text2SQL](https://github.com/eosphoros-ai/Awesome-Text2SQL)
- [Awesome Text2SQL (Chinese)](https://github.com/eosphoros-ai/Awesome-Text2SQL/blob/main/README.zh.md)
- [Mistral AI](https://huggingface.co/mistralai)
- [A Beginner’s Guide to Fine-Tuning Mistral 7B Instruct Model](https://adithyask.medium.com/a-beginners-guide-to-fine-tuning-mistral-7b-instruct-model-0f39647b20fe)
- [Mistral Instruct 7B Finetuning on MedMCQA Dataset](https://saankhya.medium.com/mistral-instruct-7b-finetuning-on-medmcqa-dataset-6ec2532b1ff1)
- [Fine-tuning Mistral on your own data](https://github.com/brevdev/notebooks/blob/main/mistral-finetune-own-data.ipynb)
- [mlx-examples llms Mistral](https://github.com/ml-explore/mlx-examples/blob/main/llms/mistral/README.md)
config.json ADDED
{
  "architectures": [
    "MistralForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.34.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}
special_tokens_map.json ADDED
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  }
}
tokenizer.json ADDED
The diff for this file is too large to render.
tokenizer.model ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
size 493443
tokenizer_config.json ADDED
{
  "add_bos_token": true,
  "add_eos_token": false,
  "added_tokens_decoder": {
    "0": {
      "content": "<unk>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "1": {
      "content": "<s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "2": {
      "content": "</s>",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "additional_special_tokens": [],
  "bos_token": "<s>",
  "clean_up_tokenization_spaces": false,
  "eos_token": "</s>",
  "legacy": true,
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": null,
  "sp_model_kwargs": {},
  "spaces_between_special_tokens": false,
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": "<unk>",
  "use_default_system_prompt": false
}
weights.00.safetensors ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:4c03db3218a7f5af63da4226ecc751b87451068ba8421bd3ccb28f6ee87860e2
size 14483498189