# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """
import warnings
from collections import OrderedDict
from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
logger = logging.get_logger(__name__)
TF_MODEL_MAPPING_NAMES = OrderedDict(
[
# Base model mapping
("deberta-v2", "TFDebertaV2Model"),
("deberta", "TFDebertaModel"),
("rembert", "TFRemBertModel"),
("roformer", "TFRoFormerModel"),
("convbert", "TFConvBertModel"),
("led", "TFLEDModel"),
("lxmert", "TFLxmertModel"),
("mt5", "TFMT5Model"),
("t5", "TFT5Model"),
("distilbert", "TFDistilBertModel"),
("albert", "TFAlbertModel"),
("bart", "TFBartModel"),
("camembert", "TFCamembertModel"),
("xlm-roberta", "TFXLMRobertaModel"),
("longformer", "TFLongformerModel"),
("roberta", "TFRobertaModel"),
("layoutlm", "TFLayoutLMModel"),
("bert", "TFBertModel"),
("openai-gpt", "TFOpenAIGPTModel"),
("gpt2", "TFGPT2Model"),
("mobilebert", "TFMobileBertModel"),
("transfo-xl", "TFTransfoXLModel"),
("xlnet", "TFXLNetModel"),
("flaubert", "TFFlaubertModel"),
("xlm", "TFXLMModel"),
("ctrl", "TFCTRLModel"),
("electra", "TFElectraModel"),
("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
("dpr", "TFDPRQuestionEncoder"),
("mpnet", "TFMPNetModel"),
("mbart", "TFMBartModel"),
("marian", "TFMarianModel"),
("pegasus", "TFPegasusModel"),
("blenderbot", "TFBlenderbotModel"),
("blenderbot-small", "TFBlenderbotSmallModel"),
("wav2vec2", "TFWav2Vec2Model"),
("hubert", "TFHubertModel"),
]
)
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
[
# Model for pre-training mapping
("lxmert", "TFLxmertForPreTraining"),
("t5", "TFT5ForConditionalGeneration"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForPreTraining"),
("bart", "TFBartForConditionalGeneration"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForPreTraining"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("mobilebert", "TFMobileBertForPreTraining"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("electra", "TFElectraForPreTraining"),
("funnel", "TFFunnelForPreTraining"),
("mpnet", "TFMPNetForMaskedLM"),
]
)
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
[
# Model with LM heads mapping
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("led", "TFLEDForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("marian", "TFMarianMTModel"),
("bart", "TFBartForConditionalGeneration"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("mobilebert", "TFMobileBertForMaskedLM"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
]
)
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Causal LM mapping
("rembert", "TFRemBertForCausalLM"),
("roformer", "TFRoFormerForCausalLM"),
("bert", "TFBertLMHeadModel"),
("openai-gpt", "TFOpenAIGPTLMHeadModel"),
("gpt2", "TFGPT2LMHeadModel"),
("transfo-xl", "TFTransfoXLLMHeadModel"),
("xlnet", "TFXLNetLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("ctrl", "TFCTRLLMHeadModel"),
]
)
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Masked LM mapping
("deberta-v2", "TFDebertaV2ForMaskedLM"),
("deberta", "TFDebertaForMaskedLM"),
("rembert", "TFRemBertForMaskedLM"),
("roformer", "TFRoFormerForMaskedLM"),
("convbert", "TFConvBertForMaskedLM"),
("distilbert", "TFDistilBertForMaskedLM"),
("albert", "TFAlbertForMaskedLM"),
("camembert", "TFCamembertForMaskedLM"),
("xlm-roberta", "TFXLMRobertaForMaskedLM"),
("longformer", "TFLongformerForMaskedLM"),
("roberta", "TFRobertaForMaskedLM"),
("layoutlm", "TFLayoutLMForMaskedLM"),
("bert", "TFBertForMaskedLM"),
("mobilebert", "TFMobileBertForMaskedLM"),
("flaubert", "TFFlaubertWithLMHeadModel"),
("xlm", "TFXLMWithLMHeadModel"),
("electra", "TFElectraForMaskedLM"),
("funnel", "TFFunnelForMaskedLM"),
("mpnet", "TFMPNetForMaskedLM"),
]
)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
("led", "TFLEDForConditionalGeneration"),
("mt5", "TFMT5ForConditionalGeneration"),
("t5", "TFT5ForConditionalGeneration"),
("marian", "TFMarianMTModel"),
("mbart", "TFMBartForConditionalGeneration"),
("pegasus", "TFPegasusForConditionalGeneration"),
("blenderbot", "TFBlenderbotForConditionalGeneration"),
("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
("bart", "TFBartForConditionalGeneration"),
]
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Sequence Classification mapping
("deberta-v2", "TFDebertaV2ForSequenceClassification"),
("deberta", "TFDebertaForSequenceClassification"),
("rembert", "TFRemBertForSequenceClassification"),
("roformer", "TFRoFormerForSequenceClassification"),
("convbert", "TFConvBertForSequenceClassification"),
("distilbert", "TFDistilBertForSequenceClassification"),
("albert", "TFAlbertForSequenceClassification"),
("camembert", "TFCamembertForSequenceClassification"),
("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
("longformer", "TFLongformerForSequenceClassification"),
("roberta", "TFRobertaForSequenceClassification"),
("layoutlm", "TFLayoutLMForSequenceClassification"),
("bert", "TFBertForSequenceClassification"),
("xlnet", "TFXLNetForSequenceClassification"),
("mobilebert", "TFMobileBertForSequenceClassification"),
("flaubert", "TFFlaubertForSequenceClassification"),
("xlm", "TFXLMForSequenceClassification"),
("electra", "TFElectraForSequenceClassification"),
("funnel", "TFFunnelForSequenceClassification"),
("gpt2", "TFGPT2ForSequenceClassification"),
("mpnet", "TFMPNetForSequenceClassification"),
("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
("transfo-xl", "TFTransfoXLForSequenceClassification"),
("ctrl", "TFCTRLForSequenceClassification"),
]
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
[
# Model for Question Answering mapping
("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
("deberta", "TFDebertaForQuestionAnswering"),
("rembert", "TFRemBertForQuestionAnswering"),
("roformer", "TFRoFormerForQuestionAnswering"),
("convbert", "TFConvBertForQuestionAnswering"),
("distilbert", "TFDistilBertForQuestionAnswering"),
("albert", "TFAlbertForQuestionAnswering"),
("camembert", "TFCamembertForQuestionAnswering"),
("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
("longformer", "TFLongformerForQuestionAnswering"),
("roberta", "TFRobertaForQuestionAnswering"),
("bert", "TFBertForQuestionAnswering"),
("xlnet", "TFXLNetForQuestionAnsweringSimple"),
("mobilebert", "TFMobileBertForQuestionAnswering"),
("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
("xlm", "TFXLMForQuestionAnsweringSimple"),
("electra", "TFElectraForQuestionAnswering"),
("funnel", "TFFunnelForQuestionAnswering"),
("mpnet", "TFMPNetForQuestionAnswering"),
]
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
[
# Model for Token Classification mapping
("deberta-v2", "TFDebertaV2ForTokenClassification"),
("deberta", "TFDebertaForTokenClassification"),
("rembert", "TFRemBertForTokenClassification"),
("roformer", "TFRoFormerForTokenClassification"),
("convbert", "TFConvBertForTokenClassification"),
("distilbert", "TFDistilBertForTokenClassification"),
("albert", "TFAlbertForTokenClassification"),
("camembert", "TFCamembertForTokenClassification"),
("flaubert", "TFFlaubertForTokenClassification"),
("xlm", "TFXLMForTokenClassification"),
("xlm-roberta", "TFXLMRobertaForTokenClassification"),
("longformer", "TFLongformerForTokenClassification"),
("roberta", "TFRobertaForTokenClassification"),
("layoutlm", "TFLayoutLMForTokenClassification"),
("bert", "TFBertForTokenClassification"),
("mobilebert", "TFMobileBertForTokenClassification"),
("xlnet", "TFXLNetForTokenClassification"),
("electra", "TFElectraForTokenClassification"),
("funnel", "TFFunnelForTokenClassification"),
("mpnet", "TFMPNetForTokenClassification"),
]
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
[
# Model for Multiple Choice mapping
("rembert", "TFRemBertForMultipleChoice"),
("roformer", "TFRoFormerForMultipleChoice"),
("convbert", "TFConvBertForMultipleChoice"),
("camembert", "TFCamembertForMultipleChoice"),
("xlm", "TFXLMForMultipleChoice"),
("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
("longformer", "TFLongformerForMultipleChoice"),
("roberta", "TFRobertaForMultipleChoice"),
("bert", "TFBertForMultipleChoice"),
("distilbert", "TFDistilBertForMultipleChoice"),
("mobilebert", "TFMobileBertForMultipleChoice"),
("xlnet", "TFXLNetForMultipleChoice"),
("flaubert", "TFFlaubertForMultipleChoice"),
("albert", "TFAlbertForMultipleChoice"),
("electra", "TFElectraForMultipleChoice"),
("funnel", "TFFunnelForMultipleChoice"),
("mpnet", "TFMPNetForMultipleChoice"),
]
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
[
("bert", "TFBertForNextSentencePrediction"),
("mobilebert", "TFMobileBertForNextSentencePrediction"),
]
)
TF_MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES
)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _LazyAutoMapping(
CONFIG_MAPPING_NAMES, TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES
)
[docs]class TFAutoModel(_BaseAutoModelClass):
_model_mapping = TF_MODEL_MAPPING
TFAutoModel = auto_class_update(TFAutoModel)
[docs]class TFAutoModelForPreTraining(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING
TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")
# Private on purpose, the public class will add the deprecation warnings.
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
_model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING
_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
[docs]class TFAutoModelForCausalLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")
[docs]class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
TFAutoModelForSeq2SeqLM = auto_class_update(
TFAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)
[docs]class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
TFAutoModelForSequenceClassification = auto_class_update(
TFAutoModelForSequenceClassification, head_doc="sequence classification"
)
[docs]class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")
[docs]class TFAutoModelForTokenClassification(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
TFAutoModelForTokenClassification = auto_class_update(
TFAutoModelForTokenClassification, head_doc="token classification"
)
TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
_model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING
TFAutoModelForNextSentencePrediction = auto_class_update(
TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
@classmethod
def from_config(cls, config):
warnings.warn(
"The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
"`TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models and "
"`TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning,
)
return super().from_config(config)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
warnings.warn(
"The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
"`TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models and "
"`TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning,
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)