Allow loading via AutoModelForSequenceClassification

#1
by tomaarsen - opened

Hello!

Pull Request overview

  • Allow loading via AutoModelForSequenceClassification

Intro

This model is looking awesome! I'm looking forward to learning more about LoCo as well. I'd be more than happy to help make sure that these models load well for your users.
I also sent a message in Slack, but I'm not sure if you all are in there, so I'll repeat it here too:

As a quick introduction, I'm Tom & I am in charge of Sentence Transformers nowadays. I encountered a few slight issues in your model configurations, and I took some time to address them on togethercomputer/m2-bert-80M-8k-retrieval:

  1. Allow loading via AutoModelForSequenceClassification (#1): There was a bug preventing your README snippet from working.
  2. Allow loading via AutoModel (#2): The configuration to load with AutoModel was missing.
  3. Allow loading via AutoTokenizer (#3): The configuration that deferred the AutoTokenizer to bert-base-cased did not work - auto_map can't be used like that, sadly. This PR allows loading the tokenizer for this model directly, without having to override model_max_length (see the sketch below this list).
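
For reference, here is a minimal sketch of what loading through the Auto classes should look like once all three PRs are applied - an illustration under those assumptions, not the canonical README snippet:

from transformers import AutoModel, AutoTokenizer

# trust_remote_code=True is needed so that the custom M2-BERT model code
# registered in the Hub repository is used instead of the stock BERT classes.
model = AutoModel.from_pretrained(
    "togethercomputer/m2-bert-80M-8k-retrieval",
    trust_remote_code=True
)
# After #3, the tokenizer can be loaded from the model repository itself,
# with model_max_length already set in its configuration.
tokenizer = AutoTokenizer.from_pretrained("togethercomputer/m2-bert-80M-8k-retrieval")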

Feel free to check these out & distribute the fixes across your models if you wish, and don't hesitate to reach out if you need any assistance (we can also add you to a Slack channel for direct contact with us, if you're not in one already).

Additionally, I would certainly recommend including the MTEB results for these models in the model README metadata - it could be great for additional visibility.
Lastly, I'm looking into 1st party support for Sentence Transformers, allowing your models to be loaded directly with ST as well! It might allow your models to reach an even larger audience.
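
To give a rough idea, first-party loading could eventually look something like the sketch below; the trust_remote_code argument and the exact behaviour are assumptions until that integration actually lands:

from sentence_transformers import SentenceTransformer

# Hypothetical usage once Sentence Transformers supports these models directly;
# trust_remote_code=True would be required for the custom model code on the Hub.
model = SentenceTransformer(
    "togethercomputer/m2-bert-80M-8k-retrieval",
    trust_remote_code=True
)
embeddings = model.encode(["Every morning, I make a cup of coffee to start my day."])
print(embeddings.shape)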

Details

I wanted to experiment with this model using:

from transformers import AutoTokenizer, AutoModelForSequenceClassification

max_seq_length = 8192
testing_string = "Every morning, I make a cup of coffee to start my day."
model = AutoModelForSequenceClassification.from_pretrained(
    "togethercomputer/m2-bert-80M-8k-retrieval",
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
    model_max_length=max_seq_length
)
input_ids = tokenizer(
    [testing_string],
    return_tensors="pt",
    padding="max_length",
    return_token_type_ids=False,
    truncation=True,
    max_length=max_seq_length
)

outputs = model(**input_ids)
embeddings = outputs['sentence_embedding']
print(embeddings[0,:10], embeddings[0].sum())

But I ran into this issue:

You are using a model of type m2_bert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Traceback (most recent call last):
  File "c:\code\m2-bert-80M-8k-retrieval\demo.py", line 5, in <module>
    model = AutoModelForSequenceClassification.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom\.conda\envs\sentence-transformers\Lib\site-packages\transformers\models\auto\auto_factory.py", line 511, in from_pretrained
    cls.register(config.__class__, model_class, exist_ok=True)
  File "C:\Users\tom\.conda\envs\sentence-transformers\Lib\site-packages\transformers\models\auto\auto_factory.py", line 537, in register
    raise ValueError(
ValueError: The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has <class 'transformers.models.bert.configuration_bert.BertConfig'> and you passed <class 'transformers_modules.togethercomputer.m2-bert-80M-8k-retrieval.66ea5d6b12ab7e3d332bba708d76f83ce2909b2e.configuration_bert.BertConfig'>. Fix one of those so they match!

In short, the classes that I'm trying to initialize (e.g. your BertForTextEncoding) are configured to work with the BertConfig from transformers, rather than with your own custom BertConfig. This PR resolves the problem by overriding the config_class that is inherited from the BertPreTrainedModel superclass.
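
Concretely, the change boils down to a single class attribute. Below is a simplified, self-contained sketch of the mechanism with stand-in definitions, not the repository's exact code:

from transformers import BertConfig as HFBertConfig
from transformers import BertPreTrainedModel


class BertConfig(HFBertConfig):
    # Stand-in for the repository's own configuration_bert.BertConfig
    model_type = "m2_bert"


class BertForTextEncoding(BertPreTrainedModel):
    # The fix in this PR: override the config_class inherited from
    # BertPreTrainedModel so it matches the custom config class, which is
    # what AutoModelForSequenceClassification passes in at load time.
    config_class = BertConfig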

After this PR

The above script now returns:

You are using a model of type m2_bert to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
-- Bidirectional: True
-- Using Long Conv Residual: True
-- Hyena w: 10
-- Hyena w mod: 1
-- Hyena filter order: 128
-- Hyena filter dropout: 0.2
-- Hyena filter wd: 0.1
-- Hyena filter emb dim: 5
-- Hyena filter lr: 0.001
-- Hyena filter lr pos emb: 1e-05
tensor([ 0.0399,  0.2460,  0.4248,  0.1803, -0.0941, -0.1501,  0.0705,  0.0478,
         0.0119, -0.0807], grad_fn=<SliceBackward0>) tensor(-1.7552, grad_fn=<SumBackward0>)

πŸŽ‰
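
As a quick sanity check of the resulting embeddings (not part of this PR, just an illustrative sketch reusing the model, tokenizer and max_seq_length from the script above), you can compare two sentences with cosine similarity:

import torch
import torch.nn.functional as F

sentences = [
    "Every morning, I make a cup of coffee to start my day.",
    "I brew myself a coffee first thing each morning.",
]
inputs = tokenizer(
    sentences,
    return_tensors="pt",
    padding="max_length",
    return_token_type_ids=False,
    truncation=True,
    max_length=max_seq_length
)
# Inference only, so no gradient tracking is needed.
with torch.no_grad():
    embeddings = model(**inputs)["sentence_embedding"]
print(F.cosine_similarity(embeddings[0], embeddings[1], dim=0).item())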

  • Tom Aarsen
tomaarsen changed pull request status to open
tomaarsen changed pull request status to closed
tomaarsen changed pull request status to open
Together org

Thank you!!

danfu09 changed pull request status to merged
