---
license: cc-by-nc-4.0
---
## License

This model is released under a non-commercial license (CC BY-NC 4.0).

## Chat Vector

```
Tora-7B-v0.1 = NTQAI/chatntq-ja-7b-v1.0 + (openchat/openchat-3.5-0106 - mistralai/Mistral-7B-v0.1)
```
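To make the arithmetic concrete, here is a minimal sketch of the same formula on toy data. Plain Python lists stand in for weight tensors, and the function name `apply_chat_vector`, the parameter names, and all numbers are made up for illustration. Like the implementation in this README, it leaves the skipped parameters and any layernorm weights untouched.

```python
# Toy illustration: plain lists stand in for weight tensors.
# All names and values below are hypothetical.

def apply_chat_vector(target, inst, base, skip_names=()):
    """Return target + (inst - base) for every parameter that is not skipped."""
    merged = {}
    for name, weights in target.items():
        # Skipped parameters and layernorm weights keep the target's values.
        if name in skip_names or "layernorm" in name:
            merged[name] = list(weights)
            continue
        # Chat vector: instruction-tuned weights minus base weights.
        chat_vector = [i - b for i, b in zip(inst[name], base[name])]
        merged[name] = [w + d for w, d in zip(weights, chat_vector)]
    return merged


base = {"layers.0.mlp.weight": [1.0, 2.0], "lm_head.weight": [0.5, 0.5]}
inst = {"layers.0.mlp.weight": [1.5, 1.0], "lm_head.weight": [0.75, 0.25]}
target = {"layers.0.mlp.weight": [3.0, 4.0], "lm_head.weight": [2.0, 2.0]}

merged = apply_chat_vector(target, inst, base, skip_names=("lm_head.weight",))
print(merged)
```

Only `layers.0.mlp.weight` changes here, because `lm_head.weight` is in the skip list.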

## Implementation

The model was built with the code below, based on @jovyan's implementation.

```python
import torch
from transformers import AutoModelForCausalLM


def build_chat_vector_model(
    base_model_name,
    inst_model_name,
    target_model_name,
    skip_layers,
    ):

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    inst_model = AutoModelForCausalLM.from_pretrained(
        inst_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )

    # English base model: print parameter names and shapes
    for k, v in base_model.state_dict().items():
        print(k, v.shape)

    # Japanese continued pre-training model: print parameter names and shapes
    for k, v in target_model.state_dict().items():
        print(k, v.shape)

    # Parameters to exclude are supplied by the caller via skip_layers

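    # Build each state dict once instead of on every loop iteration
    base_state = base_model.state_dict()
    inst_state = inst_model.state_dict()

    for k, v in target_model.state_dict().items():
        # Also skip layernorm weights
        if (k in skip_layers) or ("layernorm" in k):
            continue
        # Chat vector: instruction-tuned weights minus base weights
        chat_vector = inst_state[k] - base_state[k]
        new_v = v + chat_vector.to(v.device)
        v.copy_(new_v)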

    target_model.save_pretrained("./chat_model")

    return


if __name__ == '__main__':

    base_model_name = "mistralai/Mistral-7B-v0.1"
    inst_model_name = "openchat/openchat-3.5-0106"
    target_model_name = "NTQAI/chatntq-ja-7b-v1.0"

    skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

    build_chat_vector_model(
        base_model_name=base_model_name,
        inst_model_name=inst_model_name,
        target_model_name=target_model_name,
        skip_layers=skip_layers
    )

```
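The script prints every parameter name and shape so the base and target models can be compared by eye. That comparison could also be automated. The sketch below is one possible approach, not part of the original script: `find_shape_mismatches` is a hypothetical helper, and the shapes in the example (including the extended vocabulary size) are illustrative values rather than measurements from these models.

```python
def find_shape_mismatches(base_shapes, target_shapes):
    """Compare two {parameter name: shape} mappings and list every difference.

    Each entry is (name, base shape or None if missing, target shape).
    """
    problems = []
    for name, shape in target_shapes.items():
        if name not in base_shapes:
            problems.append((name, None, shape))
        elif base_shapes[name] != shape:
            problems.append((name, base_shapes[name], shape))
    return problems


# Hypothetical shapes: the target model has an extended vocabulary.
base_shapes = {
    "model.embed_tokens.weight": (32000, 4096),
    "model.layers.0.mlp.up_proj.weight": (14336, 4096),
}
target_shapes = {
    "model.embed_tokens.weight": (32768, 4096),
    "model.layers.0.mlp.up_proj.weight": (14336, 4096),
}

mismatches = find_shape_mismatches(base_shapes, target_shapes)
print(mismatches)
```

With real models, the mappings could be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}`. Any parameter reported here cannot take part in the element-wise subtraction and addition, which is one reason to list it in `skip_layers`.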

## Benchmark (Japanese MT bench)

|model|category|score|ver|
|:---|:---|:---|:---|
|Tora-7B-v0.1|Writing|5.4|single-turn|
|Tora-7B-v0.1|Roleplay|6.6|single-turn|
|Tora-7B-v0.1|Reasoning|7.3|single-turn|
|Tora-7B-v0.1|Math|3.5|single-turn|
|Tora-7B-v0.1|Coding|4.7|single-turn|
|Tora-7B-v0.1|Extraction|6.3|single-turn|
|Tora-7B-v0.1|STEM|7.2|single-turn|
|Tora-7B-v0.1|Humanities|8.5|single-turn|
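
For a quick summary of the table, the per-category scores can be averaged and the strongest and weakest categories picked out. This is only a convenience sketch over the numbers listed above. An unweighted mean is an assumption on my part and may differ from how the benchmark aggregates its scores.

```python
from statistics import mean

# Single-turn scores copied from the table above.
scores = {
    "Writing": 5.4,
    "Roleplay": 6.6,
    "Reasoning": 7.3,
    "Math": 3.5,
    "Coding": 4.7,
    "Extraction": 6.3,
    "STEM": 7.2,
    "Humanities": 8.5,
}

average = mean(scores.values())      # unweighted mean across categories
best = max(scores, key=scores.get)   # category with the highest score
worst = min(scores, key=scores.get)  # category with the lowest score

print(f"average={average:.2f}, best={best}, worst={worst}")
```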

![image/png](https://cdn-uploads.huggingface.co/production/uploads/651e3f30ca333f3c8df692b8/tuFTNH1t65lqgpnS3TuiA.png)

## Acknowledgments

Many thanks to @jovyan for writing the Chat Vector article.

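## References

[Chat Vectorを使って日本語LLMをチャットモデルに改造する](https://qiita.com/jovyan/items/ee6affa5ee5bdaada6b4) (Qiita article in Japanese; the title translates roughly as "Using Chat Vector to turn a Japanese LLM into a chat model")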