---
license: cc-by-sa-4.0
datasets:
- unicamp-dl/mmarco
- bclavie/mmarco-japanese-hard-negatives
language:
- ja
---

# SPLADE-japanese-v2

Differences between splade-japanese v1 and v2:
- initialized from [tohoku-nlp/bert-base-japanese-v3](https://huggingface.co/tohoku-nlp/bert-base-japanese-v3)
- knowledge distillation from a cross-encoder (see the sketch after this list)
- trained on the Japanese portion of [mMARCO](https://github.com/unicamp-dl/mMARCO), with bclavie/mmarco-japanese-hard-negatives as hard negatives

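The model card does not include the training code, but a minimal sketch of what cross-encoder distillation commonly looks like is below, assuming a MarginMSE-style objective over precomputed teacher scores. All names here are illustrative, not taken from the actual training setup:

```python
import torch
import torch.nn.functional as F

def margin_mse_loss(q_reps, pos_reps, neg_reps, teacher_pos, teacher_neg):
    """MarginMSE distillation sketch (hypothetical, not the released code):
    the SPLADE student's score margin between a positive and a hard-negative
    passage is regressed onto the cross-encoder teacher's margin."""
    # Student scores are dot products of sparse SPLADE vectors
    student_pos = (q_reps * pos_reps).sum(dim=-1)  # shape: (batch,)
    student_neg = (q_reps * neg_reps).sum(dim=-1)  # shape: (batch,)
    # Match the student's margin to the teacher's margin
    return F.mse_loss(student_pos - student_neg, teacher_pos - teacher_neg)
```
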
You also need to install the Japanese tokenizer dependencies:

```
pip install fugashi ipadic unidic-lite
```

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

model = AutoModelForMaskedLM.from_pretrained("aken12/splade-japanesev2-epoch5")
tokenizer = AutoTokenizer.from_pretrained("aken12/splade-japanesev2-epoch5")
# Map token ids back to token strings for readable output
vocab_dict = {v: k for k, v in tokenizer.get_vocab().items()}

def encode_query(query):
    # MLM logits over the vocabulary: shape (1, seq_len, vocab_size)
    inputs = tokenizer(query, return_tensors="pt")
    logits = model(**inputs, return_dict=True).logits
    # SPLADE activation: log(1 + ReLU(logits)), masked by attention,
    # then max-pooled over the sequence to get one sparse vector per query
    reps, _ = torch.max(
        torch.log(1 + torch.relu(logits)) * inputs["attention_mask"].unsqueeze(-1),
        dim=1,
    )
    return reps

with torch.no_grad():
    model_output = encode_query(query="筑波大学では何の研究が行われているか?")

reps = model_output
# Indices of the non-zero vocabulary dimensions of the sparse representation
idx = torch.nonzero(reps[0], as_tuple=False)

dict_splade = {}
for i in idx:
    token_value = reps[0][i[0]].item()
    if token_value > 0:
        token = vocab_dict[int(i[0])]
        dict_splade[token] = float(token_value)

# Print expansion terms sorted by weight, highest first
sorted_dict_splade = sorted(dict_splade.items(), key=lambda item: item[1], reverse=True)
for token, value in sorted_dict_splade:
    print(token, value)
```
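
The same encoder works for passages, and relevance is simply the dot product of the two sparse vectors. A small usage sketch reusing `encode_query` from above (the passage texts here are made up for illustration):

```python
docs = [
    "筑波大学は茨城県つくば市にある国立大学で、先端的な研究が行われている。",
    "関東地方は日本の東部に位置する。",
]

with torch.no_grad():
    query_rep = encode_query("筑波大学では何の研究が行われているか?")
    doc_reps = torch.cat([encode_query(d) for d in docs], dim=0)

# Sparse dot product between query and passage vectors gives relevance scores
scores = query_rep @ doc_reps.T  # shape (1, num_docs)
for doc, score in zip(docs, scores[0]):
    print(f"{score.item():.4f}\t{doc}")
```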

# output
```
筑波 2.0035860538482666
つくば 1.6586617231369019
研究 1.6227693557739258
大学 1.3798155784606934
実験 0.5522942543029785
学生 0.42351895570755005
分析 0.37844282388687134
国立 0.3685397505760193
キャンパス 0.36495038866996765
茨城 0.3056415021419525
科学 0.2876652181148529
関東 0.24301066994667053
地域 0.21340851485729218
実施 0.1976248174905777
先端 0.192025288939476
サイト 0.11629197001457214
調査 0.09159307181835175
プロジェクト 0.08552580326795578
議論 0.07484486699104309
検討 0.007034890353679657
```