This is our best attempt at reproducing RankT5 Enc-Softmax, with a few important differences:

  1. We use a SPLADE first stage to provide the negatives, whereas the paper uses GTR.
  2. We train using PyTorch, whereas the paper uses Flax.
  3. We use the original t5-3b; we initially thought the paper used Flan-T5-3b, but it actually also uses t5-3b.
  4. The head is not exactly the same: we add Linear -> LayerNorm -> Linear, and we made the mistake of not including a nonlinearity between the linear layers, while the original paper uses just a single dense layer. Fixing this should improve our performance, since we currently add extra layers without using them effectively (see the sketch right after this list).
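
For clarity, here is a minimal sketch of the heads under discussion. The ReLU in the "fixed" variant is our assumption of what the missing nonlinearity could be, chosen purely for illustration:

import torch

d_model = 1024  # hidden size of t5-3b

# RankT5 paper: a single dense layer on top of the encoder output
paper_head = torch.nn.Linear(d_model, 1)

# This checkpoint: Linear -> LayerNorm -> Linear with no activation in
# between, so the first linear layer adds little useful capacity
our_head = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_model),
    torch.nn.LayerNorm(d_model, eps=1e-12),
    torch.nn.Linear(d_model, 1),
)

# Hypothetical fix: insert a nonlinearity (ReLU here is an assumption)
fixed_head = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_model),
    torch.nn.ReLU(),
    torch.nn.LayerNorm(d_model, eps=1e-12),
    torch.nn.Linear(d_model, 1),
)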

This leads to what seems to be slightly worse performance (42.8 vs. 43.? in the paper), and results on BEIR also appear slightly worse.

To use this model, first clone the Hugging Face repo:

git clone https://huggingface.co/naver/trecdl22-crossencoder-rankT53b-repro
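
The checkpoint is stored with Git LFS, so if Git LFS is not already set up on your machine you will likely need to run this first:

git lfs install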

Then we suggest loading it as follows:

import torch
from transformers import T5EncoderModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput

class T5EncoderRerank(torch.nn.Module):
    def __init__(self, model_type_or_dir):
        super().__init__()
        # Encoder-only T5 backbone
        self.model = T5EncoderModel.from_pretrained(model_type_or_dir)
        self.config = self.model.config
        # Scoring head: Linear -> LayerNorm -> Linear (no nonlinearity,
        # cf. difference 4 above)
        self.first_transform = torch.nn.Linear(self.config.d_model, self.config.d_model)
        self.layer_norm = torch.nn.LayerNorm(self.config.d_model, eps=1e-12)
        self.linear = torch.nn.Linear(self.config.d_model, 1)

    def forward(self, **kwargs):
        # Pool the encoder output at the first token position
        result = self.model(**kwargs).last_hidden_state[:, 0, :]
        first_transformed = self.first_transform(result)
        layer_normed = self.layer_norm(first_transformed)
        logits = self.linear(layer_normed)
        return SequenceClassifierOutput(logits=logits)


original_model = "t5-3b"
path_checkpoint = "trecdl22-crossencoder-rankT53b-repro/pytorch_model.bin"

print("Loading")
# Build the architecture from the base t5-3b weights, then overwrite them
# with the fine-tuned reranker checkpoint from the cloned repo
model = T5EncoderRerank(original_model)
model.load_state_dict(torch.load(path_checkpoint, map_location=torch.device("cpu")))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
tokenizer = AutoTokenizer.from_pretrained(original_model)
print("loaded")