---
tags:
  - transformers
  - information-retrieval
language: pl
license: apache-2.0
---

Polish-SPLADE

This is a Polish version of the SPLADE++ (EnsembleDistil) model described in the paper "From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective". SPLADE (Sparse Lexical and Expansion) is a family of modern term-based retrieval methods that employ Transformer language models. In this approach, the masked language modeling (MLM) head is optimized to produce a vocabulary-sized weight vector adapted for text retrieval. SPLADE is a highly effective sparse retrieval method: it outperforms classic approaches such as BM25 and achieves results comparable to high-quality dense encoders.
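For reference, the SPLADE++ weighting can be summarized as follows (our own notation, summarizing the formulation from the SPLADE papers). For an input with tokens $t$ and MLM logits $w_{ij}$ for token $i$ and vocabulary term $j \in V$, each term weight is max-pooled over the input tokens, and query-document relevance is the dot product of the resulting sparse vectors:

$$
w_j = \max_{i \in t} \log\bigl(1 + \mathrm{ReLU}(w_{ij})\bigr), \qquad s(q, d) = \sum_{j \in V} w_j^{(q)} \, w_j^{(d)}
$$

Because the ReLU and logarithm send most logits to exactly zero, only a small fraction of the $|V|$ weights is non-zero, which is what makes the representation sparse.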

This model was fine-tuned from the polish-distilroberta checkpoint on the Polish translation of the MS MARCO dataset, using the default training hyperparameters from the official SPLADE repository.

Below is an example of using SPLADE with no dependencies other than Hugging Face Transformers:

import math

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

model_name = "sdadas/polish-splade"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
model.eval()
# Map token ids back to vocabulary terms
vocab = {v: k for k, v in tokenizer.get_vocab().items()}

def encode_splade(text: str) -> dict:
    inputs = tokenizer([text], padding="longest", truncation=True, return_tensors="pt", max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits  # (1, seq_len, vocab_size)
    attention_mask = inputs["attention_mask"].unsqueeze(-1)  # (1, seq_len, 1)
    # SPLADE weighting: log(1 + ReLU(logits)), max-pooled over non-padding tokens
    weights = torch.log1p(torch.relu(logits)) * attention_mask
    vector = torch.max(weights, dim=1).values.squeeze(0)  # (vocab_size,)
    # Keep only non-zero weights as a sparse {term: weight} dictionary
    idx = torch.nonzero(vector, as_tuple=True)[0]
    return {vocab[i]: w for i, w in zip(idx.tolist(), vector[idx].tolist())}

def cos_sim(vec1: dict, vec2: dict) -> float:
    # Only terms present in both vectors contribute to the dot product
    common = vec1.keys() & vec2.keys()
    numerator = sum(vec1[t] * vec2[t] for t in common)
    norm1 = math.sqrt(sum(w ** 2 for w in vec1.values()))
    norm2 = math.sqrt(sum(w ** 2 for w in vec2.values()))
    denominator = norm1 * norm2
    return numerator / denominator if denominator else 0.0

question = encode_splade("Jak dożyć 100 lat?")
answer = encode_splade("Trzeba zdrowo się odżywiać i uprawiać sport.")
print(cos_sim(question, answer))
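The dictionaries returned by `encode_splade` can also be scored and inspected without the model. The following is a minimal sketch: `dot_score` and `top_terms` are hypothetical helpers (not part of Transformers or PIRB), and the toy vectors use made-up terms and weights standing in for real `encode_splade` output. Since SPLADE is trained with a dot-product relevance score, `dot_score` is closer to how the model ranks documents than cosine similarity.

```python
from typing import Dict, List, Tuple

SparseVector = Dict[str, float]

def dot_score(query: SparseVector, doc: SparseVector) -> float:
    """Dot product over the terms both vectors share (other terms contribute zero)."""
    # Iterate over the smaller dictionary and look terms up in the larger one
    small, large = (query, doc) if len(query) <= len(doc) else (doc, query)
    return sum(w * large[t] for t, w in small.items() if t in large)

def top_terms(vec: SparseVector, k: int = 10) -> List[Tuple[str, float]]:
    """Return the k highest-weighted terms, e.g. to inspect expansion terms."""
    return sorted(vec.items(), key=lambda item: item[1], reverse=True)[:k]

# Hypothetical toy vectors standing in for encode_splade(...) output
query = {"dożyć": 1.5, "100": 1.0, "lat": 0.8}
doc = {"lat": 0.5, "zdrowo": 1.2, "sport": 0.9}
print(dot_score(query, doc))   # 0.4 (only "lat" is shared)
print(top_terms(doc, k=2))     # [('zdrowo', 1.2), ('sport', 0.9)]
```

Inspecting `top_terms` on real outputs is a quick way to see which expansion terms the model adds beyond the words that literally appear in the text.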

Example usage with the PIRB library:

from search import SpladeEncoder
from sentence_transformers.util import cos_sim

config = {"name": "sdadas/polish-splade", "fp16": True}
encoder = SpladeEncoder(config, True)
results = encoder.encode_batch(["Jak dożyć 100 lat?", "Trzeba zdrowo się odżywiać i uprawiać sport."])
print(cos_sim(results[0], results[1]))

Using SPLADE to index and search large datasets is a more complex task that requires integration with a term-based index such as Lucene. For this purpose, you can use the official SPLADE implementation or the reimplementation of this model in our PIRB library.
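To illustrate why sparse vectors fit a term-based index, here is a minimal in-memory inverted index over `encode_splade`-style dictionaries. It is a toy stand-in for Lucene, not the SPLADE or PIRB implementation; the `SparseIndex` class is hypothetical, and the quantization factor of 100 is an assumption (Lucene-based impact indexes commonly store integer term weights, so float weights are scaled and rounded first).

```python
from collections import defaultdict
from typing import Dict, List, Tuple

SparseVector = Dict[str, float]

class SparseIndex:
    """Hypothetical toy inverted index: term -> list of (doc_id, integer impact)."""

    def __init__(self, scale: float = 100.0):
        self.scale = scale  # assumed quantization factor for float -> int weights
        self.postings: Dict[str, List[Tuple[int, int]]] = defaultdict(list)

    def add(self, doc_id: int, vec: SparseVector) -> None:
        for term, weight in vec.items():
            impact = round(weight * self.scale)
            if impact > 0:  # terms whose weight quantizes to zero are dropped
                self.postings[term].append((doc_id, impact))

    def search(self, query: SparseVector, k: int = 10) -> List[Tuple[int, float]]:
        # Accumulate dot-product scores by walking only the query terms' posting lists
        scores: Dict[int, float] = defaultdict(float)
        for term, q_weight in query.items():
            for doc_id, impact in self.postings.get(term, []):
                scores[doc_id] += q_weight * impact / self.scale
        return sorted(scores.items(), key=lambda item: item[1], reverse=True)[:k]

index = SparseIndex()
index.add(0, {"lat": 0.5, "zdrowo": 1.2, "sport": 0.9})
index.add(1, {"pogoda": 1.1, "deszcz": 0.7})
print(index.search({"dożyć": 1.5, "100": 1.0, "lat": 0.8}))
```

The point of the inverted index is that query cost depends on the posting lists of the few non-zero query terms, not on the size of the vocabulary or the corpus; production engines add compression and dynamic pruning on top of this idea.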

Citation

@article{dadas2024pirb,
  title={{PIRB}: A Comprehensive Benchmark of Polish Dense and Hybrid Text Retrieval Methods}, 
  author={Sławomir Dadas and Michał Perełkiewicz and Rafał Poświata},
  year={2024},
  eprint={2402.13350},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}