liamclarkza committed
Commit 6617c7e
1 Parent(s): d064dec

Fix AutoModel not loading model correctly due to config_class inconsistency


This fixes an issue where instantiating the model through AutoModel fails because the model's `config_class` points at the `BertConfig` from the transformers library rather than the `BertConfig` defined in this repository's module. The mismatch causes `from_pretrained` to fail with the error below. See [this GitHub issue](https://github.com/huggingface/transformers/issues/31068) for more details.
```
Traceback (most recent call last):
    model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 560, in from_pretrained
    cls.register(config.__class__, model_class, exist_ok=True)
  File ".../lib/python3.11/site-packages/transformers/models/auto/auto_factory.py", line 586, in register
    raise ValueError(
ValueError: The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has <class 'transformers.models.bert.configuration_bert.BertConfig'> and you passed <class 'transformers_modules.zhihan1996.DNABERT-2-117M.d064dece8a8b41d9fb8729fbe3435278786931f1.configuration_bert.BertConfig'>. Fix one of those so they match!
```
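
For context, `AutoModel.from_pretrained` re-registers remote-code classes via `cls.register(config.__class__, model_class, exist_ok=True)` (visible in the traceback), and that registration requires `model_class.config_class` to be exactly the config class that was just loaded from the repository. A minimal sketch of the contract, with `MyBertConfig`/`MyBertModel` as hypothetical stand-ins for this repo's `configuration_bert.BertConfig` and `bert_layers.BertModel`:
```python
from transformers import PretrainedConfig, PreTrainedModel


class MyBertConfig(PretrainedConfig):
    # Hypothetical stand-in for the BertConfig shipped in this repository.
    model_type = "bert"


class MyBertModel(PreTrainedModel):
    # Must reference the config class defined alongside the model. If this
    # were left pointing at transformers' built-in BertConfig, AutoModel's
    # registration check would raise the ValueError shown above.
    config_class = MyBertConfig

    def __init__(self, config):
        super().__init__(config)
```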

Files changed (1)
  1. bert_layers.py +3 -1
bert_layers.py CHANGED
````diff
@@ -24,6 +24,7 @@ from transformers.modeling_utils import PreTrainedModel
 from .bert_padding import (index_first_axis,
                            index_put_first_axis, pad_input,
                            unpad_input, unpad_input_only)
+from .configuration_bert import BertConfig
 
 try:
     from .flash_attn_triton import flash_attn_qkvpacked_func
@@ -564,7 +565,8 @@ class BertModel(BertPreTrainedModel):
     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
     ```
     """
-
+    config_class = BertConfig
+
     def __init__(self, config, add_pooling_layer=True):
         super(BertModel, self).__init__(config)
         self.embeddings = BertEmbeddings(config)
````
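
With `config_class` now pinned to the module-local `BertConfig`, the call from the traceback should load cleanly. A quick sanity check (the same call that previously failed):
```python
from transformers import AutoModel

# trust_remote_code=True is required because the model and config classes
# live in this repository rather than in the transformers library.
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M",
                                  trust_remote_code=True)
```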