---
license: cc-by-sa-4.0
datasets:
  - unicamp-dl/mmarco
  - bclavie/mmarco-japanese-hard-negatives
language:
  - ja
---

# SPLADE-japanese-v2

## Difference between splade-japanese-v1 and v2

You need to install the Japanese tokenization dependencies first:

```
pip install fugashi ipadic unidic-lite
```

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

model = AutoModelForMaskedLM.from_pretrained("aken12/splade-japanesev2-epoch5")
tokenizer = AutoTokenizer.from_pretrained("aken12/splade-japanesev2-epoch5")
# Map token ids back to token strings so the sparse representation can be inspected
vocab_dict = {v: k for k, v in tokenizer.get_vocab().items()}

def encode_query(query):
    # Tokenize the query and run it through the masked-LM head
    query = tokenizer(query, return_tensors="pt")
    output = model(**query, return_dict=True).logits
    # SPLADE activation: log(1 + ReLU(logits)), masked by attention,
    # then max-pooled over the sequence dimension to get one weight per vocabulary term
    output, _ = torch.max(torch.log(1 + torch.relu(output)) * query['attention_mask'].unsqueeze(-1), dim=1)
    return output

with torch.no_grad():
    model_output = encode_query(query="筑波大学では何の研究が行われているか?")

reps = model_output
# Indices of the non-zero vocabulary weights in the sparse query vector
idx = torch.nonzero(reps[0], as_tuple=False)

dict_splade = {}
for i in idx:
    token_value = reps[0][i[0]].item()
    if token_value > 0:
        token = vocab_dict[int(i[0])]
        dict_splade[token] = float(token_value)

# Print the expansion terms sorted by weight, highest first
sorted_dict_splade = sorted(dict_splade.items(), key=lambda item: item[1], reverse=True)
for token, value in sorted_dict_splade:
    print(token, value)
```

Output:

```
筑波 2.0035860538482666
つくば 1.6586617231369019
研究 1.6227693557739258
大学 1.3798155784606934
実験 0.5522942543029785
学生 0.42351895570755005
分析 0.37844282388687134
国立 0.3685397505760193
キャンパス 0.36495038866996765
茨城 0.3056415021419525
科学 0.2876652181148529
関東 0.24301066994667053
地域 0.21340851485729218
実施 0.1976248174905777
先端 0.192025288939476
サイト 0.11629197001457214
調査 0.09159307181835175
プロジェクト 0.08552580326795578
議論 0.07484486699104309
検討 0.007034890353679657
```
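
For retrieval, the sparse vectors can be used directly: encode documents the same way and rank them by the dot product with the query vector. Below is a minimal sketch that reuses `encode_query` from above for documents as well; the sample documents are hypothetical and only illustrate the scoring step, not an official API of this model card.

```python
# Minimal scoring sketch (assumption: documents are encoded with the same
# function as queries; the example documents below are hypothetical).
docs = [
    "筑波大学は茨城県つくば市にある国立大学で、幅広い分野の研究が行われている。",
    "東京は日本の首都である。",
]

with torch.no_grad():
    query_vec = encode_query("筑波大学では何の研究が行われているか?")          # shape (1, vocab)
    doc_vecs = torch.cat([encode_query(d) for d in docs], dim=0)               # shape (num_docs, vocab)

# Relevance score = dot product between the sparse query and document vectors
scores = doc_vecs @ query_vec.squeeze(0)
for doc, score in sorted(zip(docs, scores.tolist()), key=lambda x: x[1], reverse=True):
    print(round(score, 4), doc)
```

In practice you would precompute and index the document vectors (e.g. as an inverted index over the non-zero terms) rather than encoding them at query time.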