splade-ja-310m-v2

sbintuitions/modernbert-ja-310m をベースとした日本語 SPLADE (Learned Sparse Retrieval) モデルです。大規模な汎用コーパスで学習した汎用ベースモデルを起点に、行政 FAQ・クイズ・学術論文・長文といった特定ドメインのデータで短期ファインチューニングを行い、検索精度を高めています。

モデルの概要

項目 内容
ベースモデル sbintuitions/modernbert-ja-310m
アーキテクチャ Transformer (fill-mask) + SpladePooling (max)
パラメータ数 約 315M
最大入力長 8,192 tokens (本モデルの訓練は 1,024 tokens で実施)
語彙サイズ (= 出力次元数) 102,400
トークナイザ SentencePiece (ModernBERT-Ja の語彙をそのまま使用)
主要フレームワーク sentence-transformers >= 5.4 (SparseEncoder)
損失関数 CachedSpladeLoss(SparseDistillKLDivLoss) による reranker 蒸留
混合精度 bf16

SPLADE は MLM head の logits を語彙次元ごとに max pooling して得られる、文ごとに非ゼロ次元が数十から数百しかないスパース埋め込みです。検索時は内積でスコアリングできるので、転置インデックス (Elasticsearch, OpenSearch, Vespa, Qdrant 等) との親和性が高いのが特徴です。各次元が語彙トークンに対応するため、どの語が検索シグナルとして効いているかを直接確認できる解釈性の高さも利点です。

使い方

from sentence_transformers import SparseEncoder

model = SparseEncoder("mahiyama/splade-ja-310m-v2")

queries = [
    "日本の首都はどこですか?",
    "機械学習とは何ですか",
]
documents = [
    "東京は日本の首都であり、政治・経済の中心地である。",
    "京都は古都として知られ、神社仏閣が多くある。",
    "機械学習は人工知能の一分野で、データからパターンを学習する手法である。",
]

q_emb = model.encode_query(queries)
d_emb = model.encode_document(documents)

# 内積で類似度スコアを計算
scores = model.similarity(q_emb, d_emb)
print(scores)

# 拡張トークンを覗く (top_k=15 の語彙とその重み)
for token, weight in model.decode(q_emb[0], top_k=15):
    print(f"  {token:>20}  {weight:.3f}")

クエリと文書で同じ encoder を使う設計です (asymmetric ではなく symmetric)。スコアは内積で計算してください (cosine ではなく dot)。

トレーニング方法

ベースモデル

sbintuitions/modernbert-ja-310m を基盤としています。このモデルの MLM head を SPLADE の語彙次元射影として利用し、SpladePooling (max) を組み合わせてスパース埋め込みを得ます。

学習データ

学習には、いずれも query / positive / negative_1〜5 の n-tuples 形式に、教師スコアを label として付与したデータセットを使用しています。各データセットは次のパイプラインで構築されています。

# 工程 内容
1 Hard Negative Mining テキスト埋め込みモデルである cl-nagoya/ruri-v3-310m による kNN 検索で、正解に紛らわしい Hard Negatives を採掘しています。具体的には、query とコーパス全文書のコサイン類似度が最も高いトップ 15 件を取得し、正解 (positive) を除いた候補に対して下記のリランカーでスコアを再計算したうえで、positive スコアとの差が一定以上 (確実に分離できる) 候補のみをスコア降順で negative として採用しています。
2 蒸留スコアリング リランカーである cl-nagoya/ruri-v3-reranker-310m で (query, positive) と (query, negative_1〜5) の各ペアを採点し、その raw logit を label として付与しています。この label を教師信号として KLDiv 蒸留 (SparseDistillKLDivLoss を CachedSpladeLoss でラップ) を行っています。
3 品質スコアフィルタ label から算出した quality_score を基準に、学習価値が高い行のみを残した n-tuples-filtered config を使用しています。positive が信頼できる (positive スコアが十分高い)、偽 negative が混入していない (negative スコアが positive を超えない)、negative マージンが過度でない、という条件を満たす valid 行のみを残しています。

具体的なデータセットと採用件数は、後述の段階的トレーニングの各テーブルに記載しています。

段階的トレーニング

汎用性とドメイン適合を両立させるため、2 段階で学習しています。第 1 段階で大規模マルチソースコーパスから汎用 SPLADE ベースを構築し、第 2 段階でそのモデルを起点に特定ドメインデータで短期ファインチューニングを行うことで、汎用ベースの表現を大きく崩さずにドメイン精度を引き上げています。

第 1 段階: 汎用ベースの構築

約 2.2M 行のマルチソースコーパスで KLDiv 蒸留を行い、幅広いドメインに対応する汎用 SPLADE ベースを構築しました。

学習データ

データセット Config 採用件数 合計トークン数
mahiyama/auto-wiki-qa n-tuples-1m 993,749 704.83M
mahiyama/mqa-ja n-tuples-1m 994,263 389.89M
mahiyama/mmarco-ja n-tuples 186,747 77.86M
mahiyama/miracl-retrieval n-tuples-filtered 4,431 3.60M
mahiyama/mrtydi n-tuples-filtered 3,083 3.51M
合計 2,182,273 1,179.69M

ハイパーパラメータ

項目
learning_rate 5e-6
query_regularizer_weight 1e-5
document_regularizer_weight 2.5e-4
scheduler_type quadratic
warmup_ratio 0.2
num_train_epochs 1
max_seq_length 512
per_device_train_batch_size 128
mini_batch_size 32
temperature (SparseDistillKLDivLoss) 2.0

訓練時間

RTX PRO 6000 Blackwell Workstation Edition (96GB) x 1 で約 32.5 時間。

第 2 段階: 特定ドメイン短期ファインチューニング

第 1 段階のモデルを起点に、行政 FAQ・クイズ・学術論文・長文の各ドメインデータで短期 FT を行いました。学習率と正則化を抑えることで、汎用性を保ったままドメイン精度を引き上げています。

学習データ

データセット Config 採用件数 合計トークン数 役割
mahiyama/JaGovFaqs-22k n-tuples-filtered 11,107 8.13M 行政 FAQ
mahiyama/amagasaki-qna n-tuples-filtered 11,069 5.51M 自治体 QA
mahiyama/quiz-works n-tuples-filtered 12,502 16.67M クイズ (短文)
mahiyama/quiz-no-mori n-tuples-filtered 13,422 20.37M クイズ (短〜中文)
mahiyama/anlp-meeting-retrieval title-abs_n-tuples-filtered 1,926 1.72M 学術論文 (title-abs)
mahiyama/anlp-meeting-retrieval abs-intro_n-tuples-filtered 2,014 6.59M 学術論文 (abs-intro)
mahiyama/anlp-meeting-retrieval title-intro_n-tuples-filtered 1,952 6.97M 学術論文 (title-intro)
mahiyama/mldr-retrieval n-tuples-filtered 349 1.97M 長文ドキュメント検索
合計 54,341 67.94M

これらの学習データは、いずれも JMTEB v2 の各データセットの train split のみから構築しています。評価に用いる dev / test split は学習には一切含めていないため、評価結果へのテストデータの混入 (リーク) はありません。

ハイパーパラメータ

項目
learning_rate 2e-6
query_regularizer_weight 1.5e-5
document_regularizer_weight 3e-4
scheduler_type quadratic
warmup_ratio 0.1
num_train_epochs 2
max_seq_length 1024
per_device_train_batch_size 32
mini_batch_size 8
temperature (SparseDistillKLDivLoss) 2.0

訓練時間

RTX PRO 6000 Blackwell Workstation Edition (96GB) x 1 で約 5.0 時間。

正則化重みはいずれの段階でも quadratic warmup でスケジューリングし、訓練序盤に正則化が過剰に効いて埋め込みが潰れるのを防いでいます。

ノイズトークンの抑止

自然な語彙拡張を実現するため、検索シグナルとして無価値なトークンを訓練段階で抑止する仕組みを組み込んでいます。

抑止対象は、句読点や記号、SPM の語頭マーカ ▁ 単独、バイトフォールバック (<0xHH>)、装飾プレースホルダ (○○, ※, ■, → 等)、特殊トークン ([CLS], [SEP], [PAD], [MASK])、コーパス由来のメタトークンなどです。具体的には、語彙全体に対して以下のルールで自動判定したものに、コーパス固有のメタトークンを加えています。

  • Unicode カテゴリが Punctuation / Symbol / Separator / Other / Mark のみで構成されるトークン
  • バイトフォールバック (<0xHH> 形式)
  • SPM 語頭マーカ ▁ 単独
  • 同一文字の繰り返し (ーー, ・・, !! など)
  • modifier letter (Unicode カテゴリ Lm) のみで構成されるトークン
  • ASCII 1 文字の英字・数字
  • 装飾プレースホルダのリテラルリスト (○○, ××, △△, ■, ※, ▼, ◆, …, ★, ☆, ♪, →, ・, ー など)
  • 特殊トークン (CLS / SEP / PAD / MASK / UNK / BOS / EOS)

抑止は次の 2 段階で恒久化しています。

  1. 訓練時の forward pre-hook: SpladePooling の直前で MLM logits の抑止対象次元へ -1e4 を加算し、勾配がほぼ立たないようにします。bf16 環境で NaN を生まないよう -inf ではなく -1e4 を使用しています (relu(log(1+exp(-1e4))) ≈ 0 で実効ゼロ)。

  2. 訓練後の bias 焼込: MLM head の bias パラメータの抑止対象次元へ -1e4 を恒久的に書き込みます。hook を外したあとも推論時にノイズ次元が復活せず、保存モデルは素の SparseEncoder.from_pretrained() で読み込めます (trust_remote_code や追加モジュール不要)。

評価結果 (JMTEB v2)

JMTEB v2 の retrieval 11 タスクで nDCG@10 を測定しました。なお、JMTEB v2 は Sparse Model をサポートしていないため、測定には JMTEB v2 Retrieval タスクで Sparse Embedding モデルを評価するための評価ハーネスである jmteb-v2-sparse-eval を使用しています。

また、評価時には model.max_seq_length をタスクごとに下表の通り設定しています。本来は model.max_seq_length=8192 として全入力を扱いたいところですが、GPU リソースの制約から、各タスクで必要となる最低限のトークン数に抑えて評価しています。max_seq_length を小さくするとそれを超える入力トークンは切り捨てられるため、長文を含むタスクでは評価スコアが下がりやすい点に注意してください。

タスク max_seq_length nDCG@10 Recall@10 MAP@10
NLPJournalTitleAbsRetrieval.V2 512 0.9557 0.9882 0.9448
NLPJournalTitleIntroRetrieval.V2 2456 0.9251 0.9824 0.9059
NLPJournalAbsIntroRetrieval.V2 2456 0.9891 0.9961 0.9866
NLPJournalAbsArticleRetrieval.V2 8192 0.9926 0.9980 0.9907
MintakaRetrieval 512 0.2050 0.2980 0.1763
JaGovFaqsRetrieval 512 0.7603 0.8755 0.7228
JaqketRetrieval 8192 0.7805 0.8957 0.7428
JaCWIRRetrieval 512 0.7358 0.8308 0.7052
MultiLongDocRetrieval 8192 0.5291 0.6400 0.4941
MIRACLRetrieval 512 0.7444 0.8599 0.6655
MrTidyRetrieval 512 0.6340 0.7863 0.5706
平均 0.7501 0.8319 0.7187

Lessons Learned

1. 大規模スケールアップは教師が固定なら飽和する

本プロジェクトは教師リランカーの出力 (logits) を生徒に蒸留する構成です。この構成で訓練データを 300K から 2.2M へ約 7 倍に増やし約 33 GPU 時間を投じても、JMTEB v2 5 タスク平均は 0.8258 から 0.8262 (+0.0004) とほぼ横ばいでした。教師が固定だと生徒が学べる情報量は教師の知識容量に律速され、データを増やしても途中から「同じ知識の反復学習」になり、ある規模 (本件では 300K 前後) で飽和するように見えました。

対照的に、同じ 300K ベースに数千行のドメインデータで短期 Fine-Tuning (FT) を重ねると、1 時間ほどの訓練で平均が 0.8652 (+0.039) へ跳ね上がりました。7 倍のデータ増が +0.0004 だったのと比べ桁違いの効率です。スコアが頭打ちになったら、投資先をデータ量ではなくデータ質 (top-K フィルタ厳格化、ハードネガティブ強化、教師アンサンブル) や的を絞った短期 FT に振り替えるのが良いと感じました。

2. データ多様化はデータ量の増加よりも ROI が高い

ハイパーパラメータを固定したまま、訓練データを単一ソース (Wikipedia 250K) から、文体もクエリ長も異なる 5 ドメイン (Wikipedia QA / コミュニティ Q&A / パッセージ検索 / 雑学クイズ) のミックス (合計 274K、件数はほぼ同じ) に変えるだけで、Jaqket Retrieval の nDCG@10 が改善しました (corpus 8K で +0.018、corpus 65,802 のフル評価で +0.019)。

注目すべきは、同じ +0.018 が前段ではデータ量を 2.5 倍にして得られていた点です。件数をほぼ変えずソース数を 1 から 5 にするだけで、2.5 倍増と同等の伸びが出たことになります。さらにこの 274K モデルは、単一ソースを 1M 行まで増やした構成を約 1/3.6 のデータ量で上回りました。多様化でクエリ側の語彙拡張が広がる (query の平均アクティブ次元数が 15 から 19 に上昇) ためと見られます。実装は各ソースを共通スキーマに揃えて連結・シャッフルするだけで、特別な正規化は不要でした。

3. Catastrophic Forgetting 対策は「Multi-source の少量混合」が有効

特定ドメインだけで集中的に短期 FT をすると、狙ったタスクは伸びる一方、無関係なタスクが大きく退行する「シーソー現象」が起きます。実際、論文ドメインの単一ソースで FT した際は長文タスクが伸びた反面、短文クイズ系の Jaqket が -0.169 と急落しました。失われやすいドメイン (ここでは短文クエリのクイズ系) を訓練データに混ぜると保護でき、混合構成では Jaqket が退行を解消したうえ元を上回りました。

ただし「混ぜる量」が重要だと感じました。Web 一般文章を全体の 3〜4% (約 2,000 行) だけ薄く混ぜても回復せず、10〜20% 以上の比率で投入して初めて回復しました。保護したいドメインは「少量入れる」のではなく「一定比率まで厚く入れる」必要があると考えました。

4. 入力系列長は長文を扱う検索タスクで最大のレバーになる

論文など長文を文書側に持つ検索タスクでは、文書を読み切れるかがスコアを大きく左右します。model.max_seq_length (モデルが一度に読み込めるトークン数) を引き上げると、こうしたタスクでハイパーパラメータやデータ構成のどの変更よりも大きな改善が出ました。たとえば MultiLongDocRetrieval (文書長の中央値が約 6,700 トークン) では、1024 から 4096 に伸ばすだけで nDCG@10 が 0.22 向上しました。

ただし系列長を伸ばすほど訓練・評価・推論の時間とメモリは増えます (512→1024 で 1 ステップの訓練時間が約 1.8 倍)。本プロジェクトでは大規模な汎用ベース学習はスループット優先で 512、その上に短期 FT を重ねた公開モデルは長文対応で 1024 とし、4096 は診断実験で効果を確認しつつ常用はせず長文特化用の選択肢としました。

ライセンス

本モデルの重みは MIT License で提供します。ベースモデルである sbintuitions/modernbert-ja-310m も MIT License です。

ただし、学習データの一部に MS MARCO の日本語訳である mmarco-ja を使用している点に注意してください。MS MARCO は非商用の研究目的での利用を前提に公開されており、商用利用に関するライセンスは明確ではありません。そのため、本モデルを商用目的で利用する場合は、利用者ご自身の責任で MS MARCO の利用規約を確認したうえでご判断ください。

Downloads last month
43
Safetensors
Model size
0.3B params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for mahiyama/splade-ja-310m-v2

Finetuned
(13)
this model

Datasets used to train mahiyama/splade-ja-310m-v2