---
license: cc-by-sa-4.0
datasets:
- unicamp-dl/mmarco
- bclavie/mmarco-japanese-hard-negatives
language:
- ja
---

# SPLADE-japanese-v2

Differences between splade-japanese-v1 and v2:

- initialized from [tohoku-nlp/bert-base-japanese-v3](https://huggingface.co/tohoku-nlp/bert-base-japanese-v3)
- trained with knowledge distillation from a cross-encoder (a loss sketch follows this list)
- trained on the Japanese portion of the [mMARCO](https://github.com/unicamp-dl/mMARCO) dataset, using bclavie/mmarco-japanese-hard-negatives as hard negatives

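The card does not state which distillation objective was used; a common choice when distilling a sparse retriever from a cross-encoder is Margin-MSE, sketched below purely as an assumption. Here `q_reps`, `pos_reps`, and `neg_reps` stand for SPLADE vectors of the query, positive passage, and hard-negative passage, and `teacher_pos` / `teacher_neg` for the cross-encoder's scores; all of these names are hypothetical.

```python
import torch.nn.functional as F

def margin_mse_loss(q_reps, pos_reps, neg_reps, teacher_pos, teacher_neg):
    # Student relevance scores: dot products of sparse SPLADE vectors, shape (batch,)
    student_pos = (q_reps * pos_reps).sum(dim=-1)
    student_neg = (q_reps * neg_reps).sum(dim=-1)
    # Match the student's (positive - hard negative) margin to the teacher's margin
    return F.mse_loss(student_pos - student_neg, teacher_pos - teacher_neg)
```
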
You need to install the following packages for the Japanese tokenizer:

```
pip install fugashi ipadic unidic-lite
```

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

model = AutoModelForMaskedLM.from_pretrained("aken12/splade-japanesev2-epoch5")
tokenizer = AutoTokenizer.from_pretrained("aken12/splade-japanesev2-epoch5")
vocab_dict = {v: k for k, v in tokenizer.get_vocab().items()}

def encode_query(query):
    # Tokenize the query and run it through the masked-LM head
    query = tokenizer(query, return_tensors="pt")
    output = model(**query, return_dict=True).logits
    # SPLADE aggregation: log(1 + ReLU(logits)), masked by attention, max-pooled over tokens
    output, _ = torch.max(torch.log(1 + torch.relu(output)) * query['attention_mask'].unsqueeze(-1), dim=1)
    return output

with torch.no_grad():
    model_output = encode_query(query="筑波大学では何の研究が行われているか?")

reps = model_output
idx = torch.nonzero(reps[0], as_tuple=False)  # indices of non-zero vocabulary weights

# Map each active vocabulary id back to its token string and weight
dict_splade = {}
for i in idx:
    token_value = reps[0][i[0]].item()
    if token_value > 0:
        token = vocab_dict[int(i[0])]
        dict_splade[token] = float(token_value)

# Print the expansion terms sorted by weight
sorted_dict_splade = sorted(dict_splade.items(), key=lambda item: item[1], reverse=True)
for token, value in sorted_dict_splade:
    print(token, value)

```
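The snippet above only encodes a query. Under the usual SPLADE setup, passages are encoded the same way and relevance is the dot product of the two sparse vectors. A minimal sketch, assuming the same `encode_query` pipeline is reused for passages (the card itself only shows query encoding, and the passage text here is just an illustration):

```python
def encode_document(text):
    # Assumption: passages go through the same MLM + log(1 + ReLU) + max-pooling
    # pipeline as queries; the model card only demonstrates query encoding.
    return encode_query(text)

with torch.no_grad():
    q_rep = encode_query("筑波大学では何の研究が行われているか?")
    d_rep = encode_document("筑波大学は茨城県つくば市にある国立大学で、幅広い分野の研究が行われている。")

# Relevance score: dot product of the sparse query and passage vectors
score = torch.matmul(q_rep, d_rep.T)
print(score.item())
```
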
# output

```
筑波 2.0035860538482666
つくば 1.6586617231369019
研究 1.6227693557739258
大学 1.3798155784606934
実験 0.5522942543029785
学生 0.42351895570755005
分析 0.37844282388687134
国立 0.3685397505760193
キャンパス 0.36495038866996765
茨城 0.3056415021419525
科学 0.2876652181148529
関東 0.24301066994667053
地域 0.21340851485729218
実施 0.1976248174905777
先端 0.192025288939476
サイト 0.11629197001457214
調査 0.09159307181835175
プロジェクト 0.08552580326795578
議論 0.07484486699104309
検討 0.007034890353679657
```