Tora-7B-v0.1

license: cc-by-nc-4.0

License

This model is released under a non-commercial license.

Chat Vector

Tora-7B-v0.1 = NTQAI/chatntq-ja-7b-v1.0 + (openchat/openchat-3.5-0106 - mistralai/Mistral-7B-v0.1)
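The formula above is applied parameter-wise: the "chat vector" is the difference between an instruction-tuned model and its base model, and adding it to the Japanese model transfers the chat behavior. A minimal sketch with toy lists standing in for the real 7B weight tensors (the numbers are purely illustrative):

```python
# Toy illustration of the chat-vector arithmetic; each list stands in for
# one weight tensor of the corresponding model (values are made up).
base = [1.0, 2.0, 3.0]    # mistralai/Mistral-7B-v0.1
inst = [1.5, 2.5, 2.0]    # openchat/openchat-3.5-0106
target = [4.0, 5.0, 6.0]  # NTQAI/chatntq-ja-7b-v1.0

# chat vector = instruction-tuned weights - base weights
chat_vector = [i - b for i, b in zip(inst, base)]
# merged model = Japanese model weights + chat vector
merged = [t + c for t, c in zip(target, chat_vector)]
print(merged)  # [4.5, 5.5, 5.0]
```

In the real script below, the same element-wise addition is done with `torch` tensors over every parameter of the model.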

Implementation

The model was built with the code below, based on @jovyan's implementation.

```python
import torch
from transformers import AutoModelForCausalLM


def build_chat_vector_model(
    base_model_name,
    inst_model_name,
    target_model_name,
    skip_layers,
):
    # English base model
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    # Instruction-tuned model
    inst_model = AutoModelForCausalLM.from_pretrained(
        inst_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    # Japanese continually pre-trained model (the merge target)
    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )

    for k, v in target_model.state_dict().items():
        # Skip the embedding/output layers and every layernorm weight
        if (k in skip_layers) or ("layernorm" in k):
            continue
        # chat vector = instruction-tuned weights - base weights
        chat_vector = inst_model.state_dict()[k] - base_model.state_dict()[k]
        new_v = v + chat_vector.to(v.device)
        v.copy_(new_v)

    target_model.save_pretrained("./chat_model")


if __name__ == '__main__':

    base_model_name = "mistralai/Mistral-7B-v0.1"
    inst_model_name = "openchat/openchat-3.5-0106"
    target_model_name = "NTQAI/chatntq-ja-7b-v1.0"

    # Exclude the embedding and output layers from the merge
    skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

    build_chat_vector_model(
        base_model_name=base_model_name,
        inst_model_name=inst_model_name,
        target_model_name=target_model_name,
        skip_layers=skip_layers,
    )
```
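The skip condition inside the loop can be checked in isolation. A small sketch with hypothetical parameter names in the Mistral naming scheme, showing which keys receive the chat vector and which are left untouched:

```python
# Same filtering rule as in build_chat_vector_model: skip the embedding and
# output layers listed in skip_layers, plus any layernorm parameter.
skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

# Hypothetical sample of state_dict keys (illustrative, not exhaustive)
keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.input_layernorm.weight",
    "lm_head.weight",
]

patched = [k for k in keys if k not in skip_layers and "layernorm" not in k]
print(patched)  # ['model.layers.0.self_attn.q_proj.weight']
```

Only the attention/MLP weights are modified; embeddings and layernorms keep the Japanese model's original values, since their statistics (and the tokenizer vocabulary) differ between the models.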

Benchmark (Japanese MT-Bench)

| model | category | score | ver |
|:--|:--|--:|:--|
| Tora-7B-v0.1 | Writing | 5.4 | single-turn |
| Tora-7B-v0.1 | Roleplay | 6.6 | single-turn |
| Tora-7B-v0.1 | Reasoning | 7.3 | single-turn |
| Tora-7B-v0.1 | Math | 3.5 | single-turn |
| Tora-7B-v0.1 | Coding | 4.7 | single-turn |
| Tora-7B-v0.1 | Extraction | 6.3 | single-turn |
| Tora-7B-v0.1 | STEM | 7.2 | single-turn |
| Tora-7B-v0.1 | Humanities | 8.5 | single-turn |


Acknowledgements

Many thanks to @jovyan for writing the Chat Vector article.

References

Chat Vectorを使って日本語LLMをチャットモデルに改造する (Turning a Japanese LLM into a chat model using Chat Vector)