军舰
commited on
Commit
·
e3678d1
1
Parent(s):
3f8673d
Fine tuning Text2SQL based on Mistral-7B using LoRA on MLX
Browse files- README.md +265 -0
- config.json +24 -0
- special_tokens_map.json +23 -0
- tokenizer.json +0 -0
- tokenizer.model +3 -0
- tokenizer_config.json +42 -0
- weights.00.safetensors +3 -0
README.md
CHANGED
@@ -1,3 +1,268 @@
|
|
1 |
---
|
2 |
license: mit
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
license: mit
|
3 |
---
|
4 |
+
|
5 |
+
## [mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL](https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL)
|
6 |
+
|
7 |
+
本次微调的模型我已经上传到了 HuggingFace Hub 上,大家可以直接使用。
|
8 |
+
|
9 |
+
### 安装
|
10 |
+
|
11 |
+
```bash
|
12 |
+
pip install mlx-lm
|
13 |
+
```
|
14 |
+
|
15 |
+
### 生成
|
16 |
+
```
|
17 |
+
python -m mlx_lm.generate --model mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL \
|
18 |
+
--max-tokens 50 \
|
19 |
+
--prompt "table: students
|
20 |
+
columns: Name, Age, School, Grade, Height, Weight
|
21 |
+
Q: Which school did Wang Junjian come from?
|
22 |
+
A: "
|
23 |
+
```
|
24 |
+
```
|
25 |
+
SELECT School FROM Students WHERE Name = 'Wang Junjian'
|
26 |
+
```
|
27 |
+
|
28 |
+
|
29 |
+
## [在 MLX 上使用 LoRA 基于 Mistral-7B 微调 Text2SQL(一)](https://wangjunjian.com/mlx/lora/2024/01/23/Fine-tuning-Text2SQL-based-on-Mistral-7B-using-LoRA-on-MLX-1.html)
|
30 |
+
|
31 |
+
📌 没有使用模型的标注格式生成数据集,导致不能结束,直到生成最大的 Tokens 数量。
|
32 |
+
|
33 |
+
这次我们来解决这个问题。
|
34 |
+
|
35 |
+
## 数据集 WikiSQL
|
36 |
+
|
37 |
+
- [WikiSQL](https://github.com/salesforce/WikiSQL)
|
38 |
+
- [sqllama/sqllama-V0](https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb)
|
39 |
+
|
40 |
+
### 修改脚本 mlx-examples/lora/data/wikisql.py
|
41 |
+
```py
|
42 |
+
if __name__ == "__main__":
|
43 |
+
# ......
|
44 |
+
for dataset, name, size in datasets:
|
45 |
+
with open(f"data/{name}.jsonl", "w") as fid:
|
46 |
+
for e, t in zip(range(size), dataset):
|
47 |
+
"""
|
48 |
+
t 变量的文本是这样的:
|
49 |
+
------------------------
|
50 |
+
<s>table: 1-1058787-1
|
51 |
+
columns: Approximate Age, Virtues, Psycho Social Crisis, Significant Relationship, Existential Question [ not in citation given ], Examples
|
52 |
+
Q: How many significant relationships list Will as a virtue?
|
53 |
+
A: SELECT COUNT Significant Relationship FROM 1-1058787-1 WHERE Virtues = 'Will'</s>
|
54 |
+
"""
|
55 |
+
t = t[3:] # 去掉开头的 <s>,因为 tokenizer 会自动添加 <s>
|
56 |
+
json.dump({"text": t}, fid)
|
57 |
+
fid.write("\n")
|
58 |
+
```
|
59 |
+
|
60 |
+
执行脚本 `data/wikisql.py` 生成数据集。
|
61 |
+
|
62 |
+
### 样本示例
|
63 |
+
|
64 |
+
```json
|
65 |
+
table: 1-10753917-1
|
66 |
+
columns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat
|
67 |
+
Q: Which podiums did the alfa romeo team have?
|
68 |
+
A: SELECT Podiums FROM 1-10753917-1 WHERE Team = 'Alfa Romeo'</s>
|
69 |
+
```
|
70 |
+
|
71 |
+
|
72 |
+
## 微调
|
73 |
+
|
74 |
+
- 预训练模型 [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
|
75 |
+
|
76 |
+
### LoRA 微调
|
77 |
+
|
78 |
+
```bash
|
79 |
+
python lora.py --model mistralai/Mistral-7B-v0.1 \
|
80 |
+
--train \
|
81 |
+
--iters 600
|
82 |
+
```
|
83 |
+
```
|
84 |
+
Total parameters 7243.436M
|
85 |
+
Trainable parameters 1.704M
|
86 |
+
python lora.py --model mistralai/Mistral-7B-v0.1 --train --iters 600 50.58s user 214.71s system 21% cpu 20:26.04 total
|
87 |
+
```
|
88 |
+
|
89 |
+
微调万分之 2.35 (1.704M / 7243.436M * 10000)的模型参数。
|
90 |
+
|
91 |
+
LoRA 微调 600 次迭代,耗时 20 分 26 秒,占用内存 46G。
|
92 |
+
|
93 |
+
## 评估
|
94 |
+
|
95 |
+
计算测试集困惑度(PPL)和交叉熵损失(Loss)。
|
96 |
+
|
97 |
+
```bash
|
98 |
+
python lora.py --model mistralai/Mistral-7B-v0.1 \
|
99 |
+
--adapter-file adapters.npz \
|
100 |
+
--test
|
101 |
+
```
|
102 |
+
```
|
103 |
+
Iter 100: Test loss 1.351, Test ppl 3.862.
|
104 |
+
Iter 200: Test loss 1.327, Test ppl 3.770.
|
105 |
+
Iter 300: Test loss 1.353, Test ppl 3.869.
|
106 |
+
Iter 400: Test loss 1.355, Test ppl 3.875.
|
107 |
+
Iter 500: Test loss 1.294, Test ppl 3.646.
|
108 |
+
Iter 600: Test loss 1.351, Test ppl 3.863.
|
109 |
+
```
|
110 |
+
|
111 |
+
| Iter | Test loss | Test ppl |
|
112 |
+
| :--: | --------: | -------: |
|
113 |
+
| 100 | 1.351 | 3.862 |
|
114 |
+
| 200 | 1.327 | 3.770 |
|
115 |
+
| 300 | 1.353 | 3.869 |
|
116 |
+
| 400 | 1.355 | 3.875 |
|
117 |
+
| 500 | 1.294 | 3.646 |
|
118 |
+
| 600 | 1.351 | 3.863 |
|
119 |
+
|
120 |
+
评估占用内存 26G。
|
121 |
+
|
122 |
+
|
123 |
+
## 融合(Fuse)
|
124 |
+
|
125 |
+
```bash
|
126 |
+
python fuse.py --model mistralai/Mistral-7B-v0.1 \
|
127 |
+
--adapter-file adapters.npz \
|
128 |
+
--save-path lora_fused_model
|
129 |
+
```
|
130 |
+
|
131 |
+
|
132 |
+
## 生成
|
133 |
+
|
134 |
+
### 王军建的姓名是什么?
|
135 |
+
|
136 |
+
```bash
|
137 |
+
python -m mlx_lm.generate --model lora_fused_model \
|
138 |
+
--max-tokens 50 \
|
139 |
+
--prompt "table: students
|
140 |
+
columns: Name, Age, School, Grade, Height, Weight
|
141 |
+
Q: What is Wang Junjian's name?
|
142 |
+
A: "
|
143 |
+
```
|
144 |
+
```
|
145 |
+
SELECT Name FROM students WHERE Name = 'Wang Junjian'
|
146 |
+
```
|
147 |
+
|
148 |
+
### 王军建的年龄是多少?
|
149 |
+
|
150 |
+
```bash
|
151 |
+
python -m mlx_lm.generate --model lora_fused_model \
|
152 |
+
--max-tokens 50 \
|
153 |
+
--prompt "table: students
|
154 |
+
columns: Name, Age, School, Grade, Height, Weight
|
155 |
+
Q: How old is Wang Junjian?
|
156 |
+
A: "
|
157 |
+
```
|
158 |
+
```
|
159 |
+
SELECT Age FROM Students WHERE Name = 'Wang Junjian'
|
160 |
+
```
|
161 |
+
|
162 |
+
### 王军建来自哪所学校?
|
163 |
+
|
164 |
+
```bash
|
165 |
+
python -m mlx_lm.generate --model lora_fused_model \
|
166 |
+
--max-tokens 50 \
|
167 |
+
--prompt "table: students
|
168 |
+
columns: Name, Age, School, Grade, Height, Weight
|
169 |
+
Q: Which school did Wang Junjian come from?
|
170 |
+
A: "
|
171 |
+
```
|
172 |
+
```
|
173 |
+
SELECT School FROM Students WHERE Name = 'Wang Junjian'
|
174 |
+
```
|
175 |
+
|
176 |
+
### 查���王军建的姓名、年龄、学校信息。
|
177 |
+
|
178 |
+
```bash
|
179 |
+
python -m mlx_lm.generate --model lora_fused_model \
|
180 |
+
--max-tokens 50 \
|
181 |
+
--prompt "table: students
|
182 |
+
columns: Name, Age, School, Grade, Height, Weight
|
183 |
+
Q: Query Wang Junjian’s name, age, and school information.
|
184 |
+
A: "
|
185 |
+
```
|
186 |
+
```
|
187 |
+
SELECT Name, Age, School FROM Students WHERE Name = 'Wang Junjian'
|
188 |
+
```
|
189 |
+
|
190 |
+
### 查询王军建的所有信息。
|
191 |
+
|
192 |
+
```bash
|
193 |
+
python -m mlx_lm.generate --model lora_fused_model \
|
194 |
+
--max-tokens 50 \
|
195 |
+
--prompt "table: students
|
196 |
+
columns: Name, Age, School, Grade, Height, Weight
|
197 |
+
Q: Query all information about Wang Junjian.
|
198 |
+
A: "
|
199 |
+
```
|
200 |
+
```
|
201 |
+
SELECT Name FROM students WHERE Name = 'Wang Junjian'
|
202 |
+
```
|
203 |
+
|
204 |
+
可能训练数据不足。
|
205 |
+
|
206 |
+
### 统计一下九年级有多少学生。
|
207 |
+
|
208 |
+
```bash
|
209 |
+
python -m mlx_lm.generate --model lora_fused_model \
|
210 |
+
--max-tokens 50 \
|
211 |
+
--prompt "table: students
|
212 |
+
columns: Name, Age, School, Grade, Height, Weight
|
213 |
+
Q: Count how many students there are in ninth grade.
|
214 |
+
A: "
|
215 |
+
```
|
216 |
+
```
|
217 |
+
SELECT COUNT Name FROM Students WHERE Grade = '9th'
|
218 |
+
```
|
219 |
+
|
220 |
+
### 统计一下九年级有多少学生(九年级的值是9)。
|
221 |
+
|
222 |
+
```bash
|
223 |
+
python -m mlx_lm.generate --model lora_fused_model \
|
224 |
+
--max-tokens 50 \
|
225 |
+
--prompt "table: students
|
226 |
+
columns: Name, Age, School, Grade, Height, Weight
|
227 |
+
The value for ninth grade is 9.
|
228 |
+
Q: Count how many students there are in ninth grade.
|
229 |
+
A: "
|
230 |
+
```
|
231 |
+
|
232 |
+
```bash
|
233 |
+
python -m mlx_lm.generate --model lora_fused_model \
|
234 |
+
--max-tokens 50 \
|
235 |
+
--prompt "table: students
|
236 |
+
columns: Name, Age, School, Grade, Height, Weight
|
237 |
+
Q: Count how many students there are in ninth grade.(The value for ninth grade is 9.)
|
238 |
+
A: "
|
239 |
+
```
|
240 |
+
|
241 |
+
```
|
242 |
+
SELECT COUNT Name FROM students WHERE Grade = 9
|
243 |
+
```
|
244 |
+
|
245 |
+
附加的提示信息可以轻松添加,不用太在意放置的位置。
|
246 |
+
|
247 |
+
## 上传模型
|
248 |
+
|
249 |
+
```bash
|
250 |
+
python -m mlx_lm.convert \
|
251 |
+
--mlx-path lora_fused_model/ \
|
252 |
+
--quantize \
|
253 |
+
--upload-repo mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
|
254 |
+
```
|
255 |
+
|
256 |
+
|
257 |
+
## 参考资料
|
258 |
+
- [MLX Community](https://huggingface.co/mlx-community)
|
259 |
+
- [Fine-Tuning with LoRA or QLoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora)
|
260 |
+
- [Generate Text with LLMs and MLX](https://github.com/ml-explore/mlx-examples/tree/main/llms)
|
261 |
+
- [Awesome Text2SQL](https://github.com/eosphoros-ai/Awesome-Text2SQL)
|
262 |
+
- [Awesome Text2SQL(中文)](https://github.com/eosphoros-ai/Awesome-Text2SQL/blob/main/README.zh.md)
|
263 |
+
- [Mistral AI](https://huggingface.co/mistralai)
|
264 |
+
- [A Beginner’s Guide to Fine-Tuning Mistral 7B Instruct Model](https://adithyask.medium.com/a-beginners-guide-to-fine-tuning-mistral-7b-instruct-model-0f39647b20fe)
|
265 |
+
- [Mistral Instruct 7B Finetuning on MedMCQA Dataset](https://saankhya.medium.com/mistral-instruct-7b-finetuning-on-medmcqa-dataset-6ec2532b1ff1)
|
266 |
+
- [Fine-tuning Mistral on your own data](https://github.com/brevdev/notebooks/blob/main/mistral-finetune-own-data.ipynb)
|
267 |
+
- [mlx-examples llms Mistral](https://github.com/ml-explore/mlx-examples/blob/main/llms/mistral/README.md)
|
268 |
+
|
config.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"MistralForCausalLM"
|
4 |
+
],
|
5 |
+
"bos_token_id": 1,
|
6 |
+
"eos_token_id": 2,
|
7 |
+
"hidden_act": "silu",
|
8 |
+
"hidden_size": 4096,
|
9 |
+
"initializer_range": 0.02,
|
10 |
+
"intermediate_size": 14336,
|
11 |
+
"max_position_embeddings": 32768,
|
12 |
+
"model_type": "mistral",
|
13 |
+
"num_attention_heads": 32,
|
14 |
+
"num_hidden_layers": 32,
|
15 |
+
"num_key_value_heads": 8,
|
16 |
+
"rms_norm_eps": 1e-05,
|
17 |
+
"rope_theta": 10000.0,
|
18 |
+
"sliding_window": 4096,
|
19 |
+
"tie_word_embeddings": false,
|
20 |
+
"torch_dtype": "bfloat16",
|
21 |
+
"transformers_version": "4.34.0.dev0",
|
22 |
+
"use_cache": true,
|
23 |
+
"vocab_size": 32000
|
24 |
+
}
|
special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"bos_token": {
|
3 |
+
"content": "<s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"eos_token": {
|
10 |
+
"content": "</s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"unk_token": {
|
17 |
+
"content": "<unk>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
}
|
23 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055
|
3 |
+
size 493443
|
tokenizer_config.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"add_bos_token": true,
|
3 |
+
"add_eos_token": false,
|
4 |
+
"added_tokens_decoder": {
|
5 |
+
"0": {
|
6 |
+
"content": "<unk>",
|
7 |
+
"lstrip": false,
|
8 |
+
"normalized": false,
|
9 |
+
"rstrip": false,
|
10 |
+
"single_word": false,
|
11 |
+
"special": true
|
12 |
+
},
|
13 |
+
"1": {
|
14 |
+
"content": "<s>",
|
15 |
+
"lstrip": false,
|
16 |
+
"normalized": false,
|
17 |
+
"rstrip": false,
|
18 |
+
"single_word": false,
|
19 |
+
"special": true
|
20 |
+
},
|
21 |
+
"2": {
|
22 |
+
"content": "</s>",
|
23 |
+
"lstrip": false,
|
24 |
+
"normalized": false,
|
25 |
+
"rstrip": false,
|
26 |
+
"single_word": false,
|
27 |
+
"special": true
|
28 |
+
}
|
29 |
+
},
|
30 |
+
"additional_special_tokens": [],
|
31 |
+
"bos_token": "<s>",
|
32 |
+
"clean_up_tokenization_spaces": false,
|
33 |
+
"eos_token": "</s>",
|
34 |
+
"legacy": true,
|
35 |
+
"model_max_length": 1000000000000000019884624838656,
|
36 |
+
"pad_token": null,
|
37 |
+
"sp_model_kwargs": {},
|
38 |
+
"spaces_between_special_tokens": false,
|
39 |
+
"tokenizer_class": "LlamaTokenizer",
|
40 |
+
"unk_token": "<unk>",
|
41 |
+
"use_default_system_prompt": false
|
42 |
+
}
|
weights.00.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4c03db3218a7f5af63da4226ecc751b87451068ba8421bd3ccb28f6ee87860e2
|
3 |
+
size 14483498189
|