Our best attempt at reproducing RankT5 Enc-Softmax, with a few important differences:
- We use a SPLADE first stage to mine negatives, vs. GTR in the paper
- We train with PyTorch, vs. Flax in the paper
- We use the original t5-3b; the paper also uses t5-3b (not Flan-T5-3B), so this is not actually a difference
- The head is not exactly the same. The paper uses a single dense layer, whereas we use Linear -> LayerNorm -> Linear and, by mistake, omit a nonlinearity between the layers. We expect fixing this to improve performance, since the extra layers are not currently being used effectively
This seems to lead to slightly worse performance (42.8 vs 43.? in the paper), and performance on BEIR also seems slightly worse.
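To make the head difference concrete, here is a minimal PyTorch sketch of three head variants applied to a stand-in for the first-token encoder output. `d_model = 1024` matches t5-3b. `fixed_head` is an assumption: one plausible fix that borrows the dense -> activation -> LayerNorm ordering of BERT-style prediction heads, with GELU chosen arbitrarily. It has not been trained or evaluated.

```python
import torch

d_model = 1024  # hidden size of t5-3b

# Head described in the paper: a single dense layer on the first-token embedding.
paper_head = torch.nn.Linear(d_model, 1)

# Head used in this checkpoint: Linear -> LayerNorm -> Linear,
# with no nonlinearity between the two projections.
repro_head = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_model),
    torch.nn.LayerNorm(d_model, eps=1e-12),
    torch.nn.Linear(d_model, 1),
)

# ASSUMPTION: a hypothetical fix that adds an activation after the first projection.
fixed_head = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_model),
    torch.nn.GELU(),
    torch.nn.LayerNorm(d_model, eps=1e-12),
    torch.nn.Linear(d_model, 1),
)

# Random stand-in for last_hidden_state[:, 0, :] of a batch of 4 pairs.
first_token = torch.randn(4, d_model)
for name, head in [("paper", paper_head), ("repro", repro_head), ("fixed", fixed_head)]:
    print(name, tuple(head(first_token).shape))  # each prints (4, 1)
```

These `Sequential` modules use different parameter names than the released checkpoint, so they cannot load its weights. They are only meant to compare the architectures.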
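To use this model, first clone the Hugging Face repository. The weights are stored with Git LFS, so Git LFS must be installed for `pytorch_model.bin` to be downloaded:

```bash
git lfs install
git clone https://huggingface.co/naver/trecdl22-crossencoder-rankT53b-repro
```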
Then we suggest loading it as follows:
```python
import torch
from transformers import T5EncoderModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput


class T5EncoderRerank(torch.nn.Module):
    """T5 encoder with a Linear -> LayerNorm -> Linear scoring head on the first token."""

    def __init__(self, model_type_or_dir):
        super().__init__()
        self.model = T5EncoderModel.from_pretrained(model_type_or_dir)
        self.config = self.model.config
        self.first_transform = torch.nn.Linear(self.config.d_model, self.config.d_model)
        self.layer_norm = torch.nn.LayerNorm(self.config.d_model, eps=1e-12)
        self.linear = torch.nn.Linear(self.config.d_model, 1)

    def forward(self, **kwargs):
        # Encoder representation of the first token: (batch, d_model)
        result = self.model(**kwargs).last_hidden_state[:, 0, :]
        first_transformed = self.first_transform(result)
        layer_normed = self.layer_norm(first_transformed)
        # One relevance score per (query, passage) pair: (batch, 1)
        logits = self.linear(layer_normed)
        return SequenceClassifierOutput(logits=logits)


original_model = "t5-3b"
path_checkpoint = "trecdl22-crossencoder-rankT53b-repro/pytorch_model.bin"

print("Loading")
# Build the architecture from the original t5-3b weights, then overwrite
# every parameter (encoder and head) with the fine-tuned checkpoint.
model = T5EncoderRerank(original_model)
model.load_state_dict(torch.load(path_checkpoint, map_location=torch.device("cpu")))
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()  # disable dropout for inference
tokenizer = AutoTokenizer.from_pretrained(original_model)
print("loaded")
```
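Once `model`, `tokenizer` and `device` are loaded, reranking a list of passages could look like the minimal sketch below. The function name `rerank`, `batch_size=16`, `max_length=512` and especially the input template `"Query: {q} Document: {d}"` are assumptions. The template is modeled on the one commonly used with T5-based rerankers, and this card does not state which template the checkpoint was trained with, so check it against the training setup before relying on the scores.

```python
import torch


def rerank(model, tokenizer, query, passages, device, max_length=512, batch_size=16):
    """Score (query, passage) pairs and return them sorted by descending score.

    ASSUMPTION: the "Query: ... Document: ..." input template and the
    max_length/batch_size defaults are illustrative, not taken from this card.
    """
    texts = [f"Query: {query} Document: {p}" for p in passages]
    scores = []
    model.eval()
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = tokenizer(
                texts[start:start + batch_size],
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            # Move input_ids / attention_mask to the model's device.
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(**batch).logits  # shape: (batch, 1)
            scores.extend(logits.squeeze(-1).tolist())
    # Pair each passage with its score, highest score first.
    return sorted(zip(passages, scores), key=lambda pair: pair[1], reverse=True)
```

For example, `rerank(model, tokenizer, "what is splade", passages, device)` returns `(passage, score)` tuples, best match first. Scores are raw logits, so they are only meaningful for ranking passages against the same query.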