mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
I have uploaded the model fine-tuned in this post to the HuggingFace Hub, so you can try it out yourself.
Install mlx-lm
pip install mlx-lm
Generate SQL
python -m mlx_lm.generate --model mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
SELECT School FROM Students WHERE Name = 'Wang Junjian'
Fine-tuning Text2SQL with LoRA on Mistral-7B using MLX (Part 1)
📌 In Part 1, the dataset was not generated in the model's annotation format, so generation could not stop and ran on until the maximum number of tokens was reached.
This time we fix that problem.
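The root cause involves how the special tokens are handled: the Mistral tokenizer prepends <s> on its own but by default does not append </s>, so the training text itself must end with </s> for the model to learn when to stop. A minimal check of this behavior, assuming the transformers package and access to the mistralai/Mistral-7B-v0.1 tokenizer:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
ids = tok.encode("A: SELECT Name FROM students")
print(tok.convert_ids_to_tokens(ids))
# Roughly ['<s>', '▁A', ':', ...] -- <s> is added automatically, and nothing appends </s>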
Dataset: WikiSQL
Modify the script mlx-examples/lora/data/wikisql.py
if __name__ == "__main__":
    # ......
    for dataset, name, size in datasets:
        with open(f"data/{name}.jsonl", "w") as fid:
            for e, t in zip(range(size), dataset):
                """
                The text in the variable t looks like this:
                ------------------------
                <s>table: 1-1058787-1
                columns: Approximate Age, Virtues, Psycho Social Crisis, Significant Relationship, Existential Question [ not in citation given ], Examples
                Q: How many significant relationships list Will as a virtue?
                A: SELECT COUNT Significant Relationship FROM 1-1058787-1 WHERE Virtues = 'Will'</s>
                """
                t = t[3:]  # Strip the leading <s>; the tokenizer adds <s> automatically
                json.dump({"text": t}, fid)
                fid.write("\n")
Run the script data/wikisql.py to generate the dataset.
Sample
table: 1-10753917-1
columns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat
Q: Which podiums did the alfa romeo team have?
A: SELECT Podiums FROM 1-10753917-1 WHERE Team = 'Alfa Romeo'</s>
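The sample starts with table: and ends with </s>, as intended. A quick programmatic check of the processed file (path assumed relative to the lora directory, and assuming the train split is written to data/train.jsonl):

import json

with open("data/train.jsonl") as f:
    sample = json.loads(f.readline())["text"]

print(sample.startswith("table:"))  # True: the leading <s> was stripped
print(sample.endswith("</s>"))      # True: </s> is kept so generation can stop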
Fine-tuning
LoRA fine-tuning
python lora.py --model mistralai/Mistral-7B-v0.1 \
--train \
--iters 600
Total parameters 7243.436M
Trainable parameters 1.704M
python lora.py --model mistralai/Mistral-7B-v0.1 --train --iters 600 50.58s user 214.71s system 21% cpu 20:26.04 total
Only about 2.35 out of every 10,000 model parameters are trained (1.704M / 7243.436M × 10000 ≈ 2.35).
LoRA fine-tuning for 600 iterations took 20 minutes 26 seconds and used 46 GB of memory.
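The trainable fraction reported above follows from simple arithmetic:

total = 7243.436e6      # total parameters reported by lora.py
trainable = 1.704e6     # trainable LoRA parameters
print(f"{trainable / total * 10000:.2f}")  # ≈ 2.35 per ten thousand
print(f"{trainable / total:.4%}")          # ≈ 0.0235%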
Evaluation
Compute the perplexity (PPL) and cross-entropy loss on the test set.
python lora.py --model mistralai/Mistral-7B-v0.1 \
--adapter-file adapters.npz \
--test
Iter 100: Test loss 1.351, Test ppl 3.862.
Iter 200: Test loss 1.327, Test ppl 3.770.
Iter 300: Test loss 1.353, Test ppl 3.869.
Iter 400: Test loss 1.355, Test ppl 3.875.
Iter 500: Test loss 1.294, Test ppl 3.646.
Iter 600: Test loss 1.351, Test ppl 3.863.
Iter | Test loss | Test ppl |
---|---|---|
100 | 1.351 | 3.862 |
200 | 1.327 | 3.770 |
300 | 1.353 | 3.869 |
400 | 1.355 | 3.875 |
500 | 1.294 | 3.646 |
600 | 1.351 | 3.863 |
Evaluation used 26 GB of memory.
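Perplexity here is simply the exponential of the cross-entropy loss, so the table values can be sanity-checked in a couple of lines:

import math

for loss in [1.351, 1.327, 1.353, 1.355, 1.294, 1.351]:
    print(f"loss={loss:.3f}  ppl={math.exp(loss):.3f}")  # matches the reported ppl values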
Fuse
python fuse.py --model mistralai/Mistral-7B-v0.1 \
--adapter-file adapters.npz \
--save-path lora_fused_model
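fuse.py merges the learned LoRA adapters back into the base weights, so the result can be loaded as an ordinary standalone model. Conceptually the update is W_fused = W + scale * (B @ A); the toy sketch below uses made-up shapes and scaling, not the actual Mistral-7B dimensions or the exact formula in fuse.py:

import numpy as np

d, r = 8, 2                    # hypothetical hidden size and LoRA rank
W = np.random.randn(d, d)      # base weight
A = np.random.randn(r, d)      # LoRA "down" projection
B = np.random.randn(d, r)      # LoRA "up" projection
scale = 2.0                    # hypothetical scaling factor

W_fused = W + scale * (B @ A)  # after fusing, no adapter file is needed at inference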
Generate SQL
What is Wang Junjian's name?
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: What is Wang Junjian's name?
A: "
SELECT Name FROM students WHERE Name = 'Wang Junjian'
How old is Wang Junjian?
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: How old is Wang Junjian?
A: "
SELECT Age FROM Students WHERE Name = 'Wang Junjian'
Which school did Wang Junjian come from?
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
SELECT School FROM Students WHERE Name = 'Wang Junjian'
Query Wang Junjian's name, age, and school information.
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query Wang Junjian’s name, age, and school information.
A: "
SELECT Name, Age, School FROM Students WHERE Name = 'Wang Junjian'
Query all information about Wang Junjian.
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query all information about Wang Junjian.
A: "
SELECT Name FROM students WHERE Name = 'Wang Junjian'
Only the Name column is selected here; possibly there is not enough training data for this type of query.
Count how many students there are in ninth grade.
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.
A: "
SELECT COUNT Name FROM Students WHERE Grade = '9th'
Count how many students there are in ninth grade (the value for ninth grade is 9).
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
The value for ninth grade is 9.
Q: Count how many students there are in ninth grade.
A: "
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.(The value for ninth grade is 9.)
A: "
SELECT COUNT Name FROM students WHERE Grade = 9
Extra hint information can be added easily, without worrying too much about where it is placed.
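The same generation can also be driven from Python instead of the CLI. A minimal sketch using the mlx_lm Python API (the exact signature of generate may vary between mlx-lm versions, so check the mlx-lm documentation):

from mlx_lm import load, generate

model, tokenizer = load("lora_fused_model")
prompt = (
    "table: students\n"
    "columns: Name, Age, School, Grade, Height, Weight\n"
    "Q: Which school did Wang Junjian come from?\n"
    "A: "
)
print(generate(model, tokenizer, prompt=prompt, max_tokens=50))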
Upload the model to the HuggingFace Hub
Join the MLX Community organization.
Create a new model in the MLX Community organization: mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
git clone https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
Copy the generated model files (everything under the lora_fused_model directory) into the cloned repository directory, then push them to the HuggingFace Hub:
git add .
git commit -m "Fine tuning Text2SQL based on Mistral-7B using LoRA on MLX"
git push
git push errors
- Cannot push
Error message:
Uploading LFS objects: 0% (0/2), 0 B | 0 B/s, done.
batch response: Authorization error.
error: failed to push some refs to 'https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL'
Solution: edit .git/config and embed a token with write permission in the remote URL (write_token below stands in for your actual token).
vim .git/config
[remote "origin"]
url = https://wangjunjian:write_token@huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
fetch = +refs/heads/*:refs/remotes/origin/*
- Cannot upload files larger than 5 GB
Error message:
warning: current Git remote contains credentials
batch response:
You need to configure your repository to enable upload of files > 5GB.
Run "huggingface-cli lfs-enable-largefiles ./path/to/your/repo" and try again.
Solution:
huggingface-cli login
huggingface-cli lfs-enable-largefiles /Users/junjian/HuggingFace/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
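As an alternative to the git workflow above, the huggingface_hub library can upload the fused model directory in one call (assuming you are logged in with a write token and the repository already exists; this is just another option, not the workflow used above):

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="lora_fused_model",
    repo_id="mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL",
    repo_type="model",
)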