max model size / max seq length
Hi,
I see that 'By default, input text longer than 384 word pieces is truncated'. However, in the tokenizer config I see model_max_length is 512. Does the model respect this? Or do I need to set the max seq length somewhere? Thanks,
Hello!
The model indeed respects the token length of 384 via this configuration setting: https://huggingface.co/sentence-transformers/all-mpnet-base-v2/blob/main/sentence_bert_config.json#L2
This parameter has priority over the tokenizer one. I do recognize that it is a bit confusing to have two separate values for the same setting in the model.
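If it helps, you can inspect both values on a loaded model; a quick sketch (assuming a recent sentence-transformers version):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
print(model.max_seq_length)              # 384 -- the value from sentence_bert_config.json
print(model.tokenizer.model_max_length)  # 512 -- the value from tokenizer_config.json

The first value is the one that encode() actually truncates to.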
- Tom Aarsen
Thank you for the response.
Is it possible to set the max_seq_length to 512 via transformers?
You could, but the performance of the model will likely be worse than if you kept it at 384. Feel free to experiment with it:
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
transformer = Transformer("sentence-transformers/all-mpnet-base-v2", max_seq_length=512)
pooling = Pooling(transformer.get_word_embedding_dimension(), "mean")
model = SentenceTransformer(modules=[transformer, pooling])
embedding = model.encode("My text!")
print(embedding.shape)
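If I recall correctly, you can also just override the attribute on an already loaded model instead of rebuilding the modules by hand:

model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
model.max_seq_length = 512  # lifts the 384 default; quality beyond 384 tokens isn't guaranteed
print(model.max_seq_length)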
Hi, is it possible to specify max_seq_length if we are using AutoTokenizer and AutoModel? I can pass max_length at tokenisation time, but I doubt that stops the model truncating at 384. Thank you
from transformers import AutoTokenizer, AutoModel
import torch
#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask
#Sentences we want sentence embeddings for
sentences = ['This framework generates embeddings for each input sentence']
#Load AutoModel from huggingface model repository
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
#Tokenize sentences
encoded_input = tokenizer(sentences, padding=True, truncation=True, max_length=128, return_tensors='pt')
#Compute token embeddings
with torch.no_grad():
    model_output = model(**encoded_input)
#Perform pooling. In this case, mean pooling
sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
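For what it's worth, checking what the tokenizer itself enforces (the 512 is what I'd expect from the tokenizer_config.json of this checkpoint):

print(tokenizer.model_max_length)        # 512 -- plain transformers never reads sentence_bert_config.json
print(encoded_input['input_ids'].shape)  # sequence length capped at 128 here by max_length=128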
Some simple testing would indicate that the model processes up to 512. For the same text string, if I truncate at 384, and then again at greater than 384 but less than 512, I get different vectors back. If I try greater than 512 the model throws an error.
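Roughly what I mean, as a sketch (reusing the mean_pooling helper from my snippet above; long_text stands in for any passage longer than 384 tokens):

from transformers import AutoTokenizer, AutoModel
import torch

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-mpnet-base-v2")
model = AutoModel.from_pretrained("sentence-transformers/all-mpnet-base-v2")

long_text = "..."  # placeholder: any passage longer than 384 tokens
enc_384 = tokenizer(long_text, truncation=True, max_length=384, return_tensors='pt')
enc_448 = tokenizer(long_text, truncation=True, max_length=448, return_tensors='pt')
with torch.no_grad():
    emb_384 = mean_pooling(model(**enc_384), enc_384['attention_mask'])
    emb_448 = mean_pooling(model(**enc_448), enc_448['attention_mask'])
print(torch.allclose(emb_384, emb_448))  # False -> tokens past 384 changed the embedding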
I saw someone mention that by default it should have a max_seq_length of 384; however, now it shows 512. I did nothing but SentenceTransformer("all-mpnet-base-v2").
Could you explain why the default changed from 384 to 512?
SentenceTransformer("all-mpnet-base-v2")
output:
Transformer({'max_seq_length': 512, })
@paulmoonraker
The model indeed crashes after 512, but it was trained to work with up to 384 tokens, so 384 is the recommended sequence length. You can set max_length at tokenization time like you've done, or, I believe, with AutoTokenizer.from_pretrained("...", model_max_length=384).
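Something along these lines (a sketch, not something I have verified against this exact checkpoint):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2", model_max_length=384)
# With truncation=True and no explicit max_length, truncation falls back to model_max_length
encoded_input = tokenizer(["some long document ..."], padding=True, truncation=True, return_tensors='pt')
print(encoded_input['input_ids'].shape[1])  # <= 384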
@keyuchen2020
Huh, that is odd. It should indeed be 384. What version of sentence-transformers are you using? With 2.5.1 I get:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("all-mpnet-base-v2")
print(model)
SentenceTransformer(
(0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: MPNetModel
(1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
(2): Normalize()
)
i.e., a max_seq_length of 384.
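You can double-check which version you have installed with:

import sentence_transformers
print(sentence_transformers.__version__)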
- Tom Aarsen