Safetensors
Japanese
bert
japanese-splade-v2 / README.md
hotchpotch's picture
Update README.md
c62178f verified
|
raw
history blame
8.73 kB
metadata
license: mit
datasets:
  - hpprc/emb
  - hotchpotch/japanese-splade-v1-hard-negatives
  - hpprc/msmarco-ja
language:
  - ja
base_model:
  - hotchpotch/japanese-splade-base-v1_5

高性能な日本語 SPLADE (Sparse Lexical and Expansion Model) モデルです。テキストからスパースベクトルへの変換デモで、どのようなスパースベクトルに変換できるか、WebUI から気軽にお試しいただけます。

  • ⭐️記事へのリンク

また、モデルの学習にはYAST - Yet Another SPLADE or Sparse Trainerを使っています。

利用方法

YASEM (Yet Another Splade|Sparse Embedder)

YASEM を利用することで、SPLADEの推論・単語トークンの確認を簡単に行えます。

pip install yasem
from yasem import SpladeEmbedder

model_name = "hotchpotch/japanese-splade-v2"
embedder = SpladeEmbedder(model_name)

query = "車の燃費を向上させる方法は?"
docs = [
    "急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。",
    "車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。",
    "車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。",
]

print(embedder.rank(query, docs, return_documents=True))
[
 { 'corpus_id': 0
 , 'score': 4.28
 , 'text': '急発進や急ブレーキを避け、一定速度で走行することで燃費が良くなります。' }
 ,
 { 'corpus_id': 2
 , 'score': 2.47
 , 'text': '車を長持ちさせるには、消耗品を適切なタイミングで交換することが重要です。' }
 ,
 { 'corpus_id': 1
 , 'score': 2.34
 , 'text': '車の運転時、急発進や急ブレーキをすると、燃費が悪くなります。' }
]
sentences = [query] + docs

embeddings = embedder.encode(sentences)
similarity = embedder.similarity(embeddings, embeddings)

print(similarity)
[[5.19151189, 4.28027662, 2.34164901, 2.47221905],
[4.28027662, 11.64426784, 5.00328318, 2.15031016],
[2.34164901, 5.00328318, 6.05594296, 1.33752085],
[2.47221905, 2.15031016, 1.33752085, 9.39414744]]
token_values = embedder.get_token_values(embeddings[0])
print(token_values)
{
 '燃費': 1.13,
 '方法': 1.07,
 '車': 1.05,
 '高める': 0.67,
 '向上': 0.56,
 '増加': 0.52,
 '都市': 0.44,
 'ガソリン': 0.32,
 '改善': 0.30,
 ...

transformers からの利用


from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

def splade_max_pooling(logits, attention_mask):
    relu_log = torch.log(1 + torch.relu(logits))
    weighted_log = relu_log * attention_mask.unsqueeze(-1)
    max_val, _ = torch.max(weighted_log, dim=1)
    return max_val

tokens = tokenizer(
    sentences, return_tensors="pt", padding=True, truncation=True, max_length=512
)
tokens = {k: v.to(model.device) for k, v in tokens.items()}

with torch.no_grad():
    outputs = model(**tokens)
embeddings = splade_max_pooling(outputs.logits, tokens["attention_mask"])

similarity = torch.matmul(embeddings.unsqueeze(0), embeddings.T).squeeze(0)
print(similarity)
tensor([
   [5.1872, 4.2792, 2.3440, 2.4680],
   [4.2792, 11.6327, 4.9983, 2.1470],
   [2.3440, 4.9983, 6.0517, 1.3377],
   [2.4680, 2.1470, 1.3377, 9.3801]
])

ベンチマークスコア

retrieval (JMTEB)

JMTEB の評価結果です。

japanese-splade-v2 は JMTEB をスパースベクトルで評価できるように変更したコードでの評価となっています。 なお、japanese-splade-v2 は JMTEB タスクである jaqket(や派生のjaqra), mrtydi(と派生のmiracl), jagovfaqs, nlp_jornal のデータセットのtrain,dev, testなどのデータは 学習に利用していません

モデル名 jagovfaqs jaqket mrtydi nlp_journal
title_abs
nlp_journal
abs_intro
nlp_journal
title_intro
Avg
<512
Avg
ALL
japanese-splade-v2 0.7313 0.6986 0.5106 0.9831 0.9067 0.8026 0.7309 0.7722
japanese-splade-base-v1 0.6499 0.6992 0.4365 0.8967 0.9766 0.8203 0.6906 0.7465
GLuCoSE-base-ja-v2 0.6979 0.6729 0.4186 0.9511 0.9029 0.7580 0.6851 0.7336
multilingual-e5-large 0.7030 0.5878 0.4363 0.9470 0.8600 0.7248 0.6685 0.7098
ruri-large 0.7668 0.6174 0.3803 0.9658 0.8712 0.7797 0.6826 0.7302
jinaai/jina-embeddings-v3 0.7150 0.4648 0.4545 0.9562 0.9843 0.9385 0.6476 0.7522
sarashina-embedding-v1-1b 0.7168 0.7279 0.4195 0.9696 0.9394 0.8833 0.7085 0.7761
OpenAI/text-embedding-3-large 0.7241 0.4821 0.3488 0.9655 0.9933 0.9547 0.6301 0.7448

スパース性

v1 ではスパース性が強すぎたので、v2 ではバランスをとったスパース性を持たせています。

で計測しています。

Target jaqket-query jaqket-docs mrtydi-query mrtydi-docs jagovfaqs_22k-query jagovfaqs_22k-docs nlp_journal_title_abs-query nlp_journal_title_abs-docs nlp_journal_title_intro-query nlp_journal_title_intro-docs nlp_journal_abs_intro-query nlp_journal_abs_intro-docs
v1 23.3 146.2 13.8 89.3 27.9 73.2 19 75.2 19 95.7 75.3 95.7
v1-mmarco-only 38.9 231.8 20.5 100.4 43.4 97.9 26.4 126.9 26.4 182 127.2 182
v1_5 36.7 268.7 22.8 237.6 47.9 237.3 34.9 225.6 34.9 235.2 224.5 235.2
v2 29.8 379.6 19.4 176.4 42 189.8 29 235.8 29 304.9 233.8 304.9

学習元データセット

ライセンス

MIT