Update README.md
---
tags:
- transformers
- information-retrieval
language: pl
license: apache-2.0
---

<h1 align="center">Polish-SPLADE</h1>

This is a Polish version of the SPLADE++ (EnsembleDistil) model described in the paper [From distillation to hard negative sampling: Making sparse neural IR models more effective](https://arxiv.org/abs/2205.04733). Sparse Lexical and Expansion (SPLADE) is a family of modern term-based retrieval methods employing Transformer language models. In this approach, the masked language modeling (MLM) head is optimized to generate a vocabulary-sized weight vector adapted for text retrieval. SPLADE is a highly effective sparse retrieval ranking algorithm, achieving results better than classic methods such as BM25 and comparable to high-quality dense encoders.
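
As a short sketch of the idea (it matches the paper and the example code below): for token position $i$ of the input and vocabulary entry $j$, the weight of term $j$ is obtained by max-pooling the log-saturated MLM logits over the sequence,

$$w_j = \max_i \log\left(1 + \mathrm{ReLU}(\mathrm{logit}_{ij})\right),$$

and queries and documents encoded this way are compared with a dot product or cosine similarity over the resulting sparse vectors.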

This model was fine-tuned from the [polish-distilroberta](https://huggingface.co/sdadas/polish-distilroberta) checkpoint on the Polish translation of the MS MARCO dataset. We used the default training hyperparameters from the official [SPLADE repository](https://github.com/naver/splade).

Below is an example of using SPLADE with no dependencies other than Hugging Face Transformers:

```python
import math

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_name = "sdadas/polish-splade"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
# Reverse vocabulary: token id -> token string, used to name the non-zero dimensions.
vocab = {v: k for k, v in tokenizer.get_vocab().items()}

def encode_splade(text: str):
    # Tokenize and run the MLM head to get per-token logits over the whole vocabulary.
    batch = tokenizer([text], padding="longest", truncation=True, return_tensors="pt", max_length=512)
    output = model(**batch)
    logits, attention_mask = output["logits"].detach(), batch["attention_mask"].detach()
    attention_mask = attention_mask.unsqueeze(-1)
    # SPLADE aggregation: log(1 + ReLU(logits)), masked by attention, max-pooled over the sequence.
    vector = torch.max(torch.log1p(torch.relu(logits)) * attention_mask, dim=1)
    vector = vector[0].detach().squeeze()
    # Keep only the non-zero dimensions and return them as a {term: weight} dictionary.
    idx = np.nonzero(vector.cpu().numpy())[0]
    vector = vector[idx]
    return {vocab[k]: float(v) for k, v in zip(list(idx), list(vector))}

def cos_sim(vec1, vec2):
    # Cosine similarity between two sparse {term: weight} dictionaries.
    intersection = set(vec1.keys()) & set(vec2.keys())
    numerator = sum(vec1[x] * vec2[x] for x in intersection)
    sum1 = sum(v ** 2 for v in vec1.values())
    sum2 = sum(v ** 2 for v in vec2.values())
    denominator = math.sqrt(sum1) * math.sqrt(sum2)
    return (numerator / denominator) if denominator else 0.0

question = encode_splade("Jak dożyć 100 lat?")  # "How to live to 100?"
answer = encode_splade("Trzeba zdrowo się odżywiać i uprawiać sport.")  # "You need to eat healthily and do sports."
print(cos_sim(question, answer))
```
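
Because the output is a sparse bag of weighted terms, it is easy to inspect which vocabulary entries the model activates for a given text. As a small illustrative follow-up (reusing the `question` vector from the example above):

```python
# Print the ten highest-weighted expansion terms for the question.
top_terms = sorted(question.items(), key=lambda kv: kv[1], reverse=True)[:10]
for term, weight in top_terms:
    print(f"{term}\t{weight:.2f}")
```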

An example of use with the [PIRB](https://github.com/sdadas/pirb) library:

```python
from search import SpladeEncoder
from sentence_transformers.util import cos_sim

# Load the model through PIRB's SPLADE encoder.
config = {"name": "sdadas/polish-splade", "fp16": True}
encoder = SpladeEncoder(config, True)
# Encode both texts and compare the resulting vectors with cosine similarity.
results = encoder.encode_batch(["Jak dożyć 100 lat?", "Trzeba zdrowo się odżywiać i uprawiać sport."])
print(cos_sim(results[0], results[1]))
```

Using SPLADE to index and search large datasets is a more complex task and requires integration with a term-based index such as Lucene. For this purpose, you can use the official [SPLADE implementation](https://github.com/naver/splade) or the reimplementation of this model in our [PIRB library](https://github.com/sdadas/pirb).
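
To make the connection to term-based indexing concrete, the toy sketch below builds a purely in-memory inverted index on top of the `encode_splade` function from the first example. It only illustrates the idea and is not a replacement for Lucene or the implementations linked above; the documents are made up for the example.

```python
from collections import defaultdict

# Toy corpus; in practice documents would be indexed with Lucene or similar.
documents = [
    "Trzeba zdrowo się odżywiać i uprawiać sport.",
    "Stolicą Polski jest Warszawa.",
]

# Inverted index: term -> {doc_id: weight}, filled from the SPLADE sparse vectors.
index = defaultdict(dict)
for doc_id, doc in enumerate(documents):
    for term, weight in encode_splade(doc).items():
        index[term][doc_id] = weight

def search(query: str, top_k: int = 5):
    # Score documents with a dot product over the terms shared with the query,
    # which is the operation a term-based index performs efficiently.
    scores = defaultdict(float)
    for term, q_weight in encode_splade(query).items():
        for doc_id, d_weight in index.get(term, {}).items():
            scores[doc_id] += q_weight * d_weight
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]

print(search("Jak dożyć 100 lat?"))  # the health-related document should score highest
```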