---
license: cc-by-nc-4.0
---

## License

This model is released under a non-commercial license (CC BY-NC 4.0).

## Chat Vector

Tora-7B-v0.1 was created by adding a chat vector (the weight difference between an instruction-tuned model and its English base model) to a Japanese continually pretrained model:

```
Tora-7B-v0.1 = NTQAI/chatntq-ja-7b-v1.0 + (openchat/openchat-3.5-0106 - mistralai/Mistral-7B-v0.1)
```
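
Conceptually, the merge is plain element-wise arithmetic over the model weights. Below is a toy sketch of the idea with illustrative tensors only; the real merge iterates over the full state_dict, as shown in the implementation section:

```python
# Toy illustration of the chat-vector arithmetic on a single weight tensor.
# Values are made up for illustration.
import torch

w_base = torch.tensor([[0.10, 0.20], [0.30, 0.40]])    # English base model (Mistral-7B-v0.1)
w_inst = torch.tensor([[0.15, 0.18], [0.33, 0.39]])    # instruction-tuned model (openchat-3.5-0106)
w_target = torch.tensor([[0.12, 0.22], [0.28, 0.41]])  # Japanese continually pretrained model (chatntq-ja-7b-v1.0)

chat_vector = w_inst - w_base      # what instruction tuning added to the base weights
w_merged = w_target + chat_vector  # graft the chat ability onto the Japanese model
print(w_merged)
```
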
## Implementation

The model was created with the code below, based on the implementation by @jovyan.

```python
import torch
from transformers import AutoModelForCausalLM


def build_chat_vector_model(
    base_model_name,
    inst_model_name,
    target_model_name,
    skip_layers,
):
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    inst_model = AutoModelForCausalLM.from_pretrained(
        inst_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )

    # Inspect the English base model weights
    for k, v in base_model.state_dict().items():
        print(k, v.shape)

    # Inspect the Japanese continually pretrained model weights
    for k, v in target_model.state_dict().items():
        print(k, v.shape)

    # Add the chat vector (inst - base) to every weight of the target model,
    # skipping the vocabulary-dependent layers and the layer norms
    for k, v in target_model.state_dict().items():
        if (k in skip_layers) or ("layernorm" in k):
            continue
        chat_vector = inst_model.state_dict()[k] - base_model.state_dict()[k]
        new_v = v + chat_vector.to(v.device)
        v.copy_(new_v)

    target_model.save_pretrained("./chat_model")


if __name__ == "__main__":
    base_model_name = "mistralai/Mistral-7B-v0.1"
    inst_model_name = "openchat/openchat-3.5-0106"
    target_model_name = "NTQAI/chatntq-ja-7b-v1.0"

    # Layers excluded from the addition (their shapes depend on the vocabulary)
    skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

    build_chat_vector_model(
        base_model_name=base_model_name,
        inst_model_name=inst_model_name,
        target_model_name=target_model_name,
        skip_layers=skip_layers,
    )
```
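
The script above saves only the model weights, so as a rough usage sketch the tokenizer is assumed to come from the original Japanese model (NTQAI/chatntq-ja-7b-v1.0); adjust the prompt format to whatever chat template you use:

```python
# Minimal usage sketch (not part of the original build script).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# The tokenizer is assumed to be the one from the Japanese source model,
# since build_chat_vector_model() saves only the model weights.
tokenizer = AutoTokenizer.from_pretrained("NTQAI/chatntq-ja-7b-v1.0")
model = AutoModelForCausalLM.from_pretrained(
    "./chat_model",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)

prompt = "日本の首都はどこですか？"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
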
## Benchmark (Japanese MT-Bench)

- Only single-turn responses were evaluated.

|model|category|score|ver|
|:---|:---|:---|:---|
|Tora-7B-v0.1|Writing|5.4|single-turn|
|Tora-7B-v0.1|Roleplay|6.6|single-turn|
|Tora-7B-v0.1|Reasoning|7.3|single-turn|
|Tora-7B-v0.1|Math|3.5|single-turn|
|Tora-7B-v0.1|Coding|4.7|single-turn|
|Tora-7B-v0.1|Extraction|6.3|single-turn|
|Tora-7B-v0.1|STEM|7.2|single-turn|
|Tora-7B-v0.1|Humanities|8.5|single-turn|

![image/png](https://cdn-uploads.huggingface.co/production/uploads/651e3f30ca333f3c8df692b8/tuFTNH1t65lqgpnS3TuiA.png)

## Benchmark (Nejumi leaderboard)

- Results transcribed from runs.summary["mtbench_leaderboard_table"].

|model|category|score|
|:---|:---|:---|
|Tora-7B-v0.1|Writing|7.55|
|Tora-7B-v0.1|Roleplay|7.5|
|Tora-7B-v0.1|Reasoning|4.35|
|Tora-7B-v0.1|Math|2.95|
|Tora-7B-v0.1|Coding|3.7|
|Tora-7B-v0.1|Extraction|7.0|
|Tora-7B-v0.1|STEM|7.85|
|Tora-7B-v0.1|Humanities|9.65|
|Tora-7B-v0.1|AVG_mtbench|6.319|

- Results transcribed from runs.summary["jaster_radar_table"] (see the retrieval sketch after the table below).

|model|category|score|
|:---|:---|:---|
|Tora-7B-v0.1|NLI|0.588|
|Tora-7B-v0.1|QA|0.1708|
|Tora-7B-v0.1|RC|0.798|
|Tora-7B-v0.1|MC|0.25|
|Tora-7B-v0.1|EL|0.0|
|Tora-7B-v0.1|FA|0.1359|
|Tora-7B-v0.1|MR|0.2|
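
A minimal sketch of how these summary values can be read back through the W&B public API ("entity/project/run_id" is a placeholder run path; logged tables may be returned as table-file references rather than raw rows):

```python
# Minimal sketch for reading the Nejumi leaderboard results back from W&B.
import wandb

api = wandb.Api()
run = api.run("entity/project/run_id")  # placeholder; substitute the actual run path

# Logged tables may come back as table-file references rather than plain values.
print(run.summary["mtbench_leaderboard_table"])
print(run.summary["jaster_radar_table"])
```
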
## Acknowledgements

Many thanks to @jovyan for writing the Chat Vector article.

## References

[Chat Vectorを使って日本語LLMをチャットモデルに改造する (Turning a Japanese LLM into a chat model using Chat Vector)](https://qiita.com/jovyan/items/ee6affa5ee5bdaada6b4)