Edit model card

Swallow-MoE-4x7B-lisa

概要

tokyotech-llm/Swallow-7b-hfをベースに、以下の4モデルをgate_mode=randomでMoEし、その後LISAという手法でインストラクションチューニングを施したモデルです。

お試しで作ってみたものなので、性能にはあまり期待しないでください。以下にベンチマーク結果も記載しております。

なお、この学習で使ったLISAの実装には不具合がある可能性が指摘されており、正常に学習できていない可能性があります。

データセット

以下の合計14327件のデータを学習に利用しました。プロンプトフォーマットはAlpacaを利用しています。

なお、ichikara-instructionの利用によりCC-BY-NC-SAを継承します。

学習の設定

RunpodでGPUサーバを借り、A6000x8で学習を行いました。主な学習パラメータは以下の通りです。なお、学習途中でのエラーのため2epochs程度しか学習できておりません。

  • lisa_activated_layers: 8
  • lisa_interval_steps: 13
  • learning_rate: 5e-5
  • num_train_epochs: 約2epochs
  • batch_size: 64
  • max_seq_length: 2048

評価

マージに利用したモデル群と本モデルのjapanese-mt-benchの結果は以下の通りです。(シングルターン)

Swallow-instructよりはスコアが高く、Superswallowよりは低いという何とも言えない結果になっております。 とはいえ、少量のデータセット・たった2epochsの学習でSwallow-instructを超えられているのは一定の成果とも言えるかもしれません。

Model Size Coding Extraction Humanities Math Reasoning Roleplay STEM Writing avg_score
Swallow-7b-instruct-hf 7B 2.0 4.6 5.4 1.7 2.8 5.0 5.9 6.9 4.2875
Superswallow-7b-v0.1 7B 2.0 5.1 7.8 2.1 3.6 6.2 7.3 7.5 5.2000
Superswallow-7b-v0.2 7B 2.2 5.8 6.7 2.5 4.3 5.5 6.6 5.8 4.9250
Superswallow-7b-v0.3 7B 2.1 4.6 8.3 2.1 5.0 6.3 7.7 8.9 5.6250
This model 4x7B 2.0 3.4 7.5 1.9 2.6 5.5 6.3 7.5 4.5875

レーダーチャート

同様に、jsquad(jsquad-1.1-0.3, 2-shots)、jcommonsenseqa(jcommonsenseqa-1.1-0.3, 3-shots)、jnli(jnli-1.3-0.3, 3-shots)、marc_ja(marc_ja-1.1-0.3, 3-shots)結果は以下の通りです。(jsquadは100で割り、それぞれ小数点以下第4位を四捨五入) ここでもSwallow-instructよりはスコアが高く、Superswallowよりは低い結果になっています。なお、こちらは参考として本モデルのインストラクションチューニング前(MoEのみ)のモデルのスコアも載せてあります。

Model Size jsquad(exact_match) jcommonsenseqa(acc) jnli(acc) marc_ja(acc) average
Swallow-7b-instruct-hf 7B 0.757 0.831 0.212 0.945 0.686
Superswallow-7b-v0.1 7B 0.441 0.846 0.374 0.966 0.657
Superswallow-7b-v0.2 7B 0.722 0.846 0.381 0.964 0.728
Superswallow-7b-v0.3 7B 0.721 0.850 0.362 0.964 0.724
This model without fine-tuning 4x7B 0.674 0.809 0.333 0.952 0.692
This model 4x7B 0.741 0.806 0.385 0.948 0.719

評価にはlm-evaluation-harnessを利用しました。

Downloads last month
1
Safetensors
Model size
19.8B params
Tensor type
BF16
·
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Merge of

Datasets used to train Aratako/Swallow-MoE-4x7B-lisa