# Source code for transformers.models.auto.modeling_tf_auto

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """


import warnings
from collections import OrderedDict

from ...utils import logging

# Add modeling imports here
from ..albert.modeling_tf_albert import (
    TFAlbertForMaskedLM,
    TFAlbertForMultipleChoice,
    TFAlbertForPreTraining,
    TFAlbertForQuestionAnswering,
    TFAlbertForSequenceClassification,
    TFAlbertForTokenClassification,
    TFAlbertModel,
)
from ..bart.modeling_tf_bart import TFBartForConditionalGeneration, TFBartModel
from ..bert.modeling_tf_bert import (
    TFBertForMaskedLM,
    TFBertForMultipleChoice,
    TFBertForNextSentencePrediction,
    TFBertForPreTraining,
    TFBertForQuestionAnswering,
    TFBertForSequenceClassification,
    TFBertForTokenClassification,
    TFBertLMHeadModel,
    TFBertModel,
)
from ..blenderbot.modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
from ..blenderbot_small.modeling_tf_blenderbot_small import (
    TFBlenderbotSmallForConditionalGeneration,
    TFBlenderbotSmallModel,
)
from ..camembert.modeling_tf_camembert import (
    TFCamembertForMaskedLM,
    TFCamembertForMultipleChoice,
    TFCamembertForQuestionAnswering,
    TFCamembertForSequenceClassification,
    TFCamembertForTokenClassification,
    TFCamembertModel,
)
from ..convbert.modeling_tf_convbert import (
    TFConvBertForMaskedLM,
    TFConvBertForMultipleChoice,
    TFConvBertForQuestionAnswering,
    TFConvBertForSequenceClassification,
    TFConvBertForTokenClassification,
    TFConvBertModel,
)
from ..ctrl.modeling_tf_ctrl import TFCTRLForSequenceClassification, TFCTRLLMHeadModel, TFCTRLModel
from ..distilbert.modeling_tf_distilbert import (
    TFDistilBertForMaskedLM,
    TFDistilBertForMultipleChoice,
    TFDistilBertForQuestionAnswering,
    TFDistilBertForSequenceClassification,
    TFDistilBertForTokenClassification,
    TFDistilBertModel,
)
from ..dpr.modeling_tf_dpr import TFDPRQuestionEncoder
from ..electra.modeling_tf_electra import (
    TFElectraForMaskedLM,
    TFElectraForMultipleChoice,
    TFElectraForPreTraining,
    TFElectraForQuestionAnswering,
    TFElectraForSequenceClassification,
    TFElectraForTokenClassification,
    TFElectraModel,
)
from ..flaubert.modeling_tf_flaubert import (
    TFFlaubertForMultipleChoice,
    TFFlaubertForQuestionAnsweringSimple,
    TFFlaubertForSequenceClassification,
    TFFlaubertForTokenClassification,
    TFFlaubertModel,
    TFFlaubertWithLMHeadModel,
)
from ..funnel.modeling_tf_funnel import (
    TFFunnelBaseModel,
    TFFunnelForMaskedLM,
    TFFunnelForMultipleChoice,
    TFFunnelForPreTraining,
    TFFunnelForQuestionAnswering,
    TFFunnelForSequenceClassification,
    TFFunnelForTokenClassification,
    TFFunnelModel,
)
from ..gpt2.modeling_tf_gpt2 import TFGPT2ForSequenceClassification, TFGPT2LMHeadModel, TFGPT2Model
from ..hubert.modeling_tf_hubert import TFHubertModel
from ..layoutlm.modeling_tf_layoutlm import (
    TFLayoutLMForMaskedLM,
    TFLayoutLMForSequenceClassification,
    TFLayoutLMForTokenClassification,
    TFLayoutLMModel,
)
from ..led.modeling_tf_led import TFLEDForConditionalGeneration, TFLEDModel
from ..longformer.modeling_tf_longformer import (
    TFLongformerForMaskedLM,
    TFLongformerForMultipleChoice,
    TFLongformerForQuestionAnswering,
    TFLongformerForSequenceClassification,
    TFLongformerForTokenClassification,
    TFLongformerModel,
)
from ..lxmert.modeling_tf_lxmert import TFLxmertForPreTraining, TFLxmertModel
from ..marian.modeling_tf_marian import TFMarianModel, TFMarianMTModel
from ..mbart.modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
from ..mobilebert.modeling_tf_mobilebert import (
    TFMobileBertForMaskedLM,
    TFMobileBertForMultipleChoice,
    TFMobileBertForNextSentencePrediction,
    TFMobileBertForPreTraining,
    TFMobileBertForQuestionAnswering,
    TFMobileBertForSequenceClassification,
    TFMobileBertForTokenClassification,
    TFMobileBertModel,
)
from ..mpnet.modeling_tf_mpnet import (
    TFMPNetForMaskedLM,
    TFMPNetForMultipleChoice,
    TFMPNetForQuestionAnswering,
    TFMPNetForSequenceClassification,
    TFMPNetForTokenClassification,
    TFMPNetModel,
)
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from ..openai.modeling_tf_openai import TFOpenAIGPTForSequenceClassification, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
from ..roberta.modeling_tf_roberta import (
    TFRobertaForMaskedLM,
    TFRobertaForMultipleChoice,
    TFRobertaForQuestionAnswering,
    TFRobertaForSequenceClassification,
    TFRobertaForTokenClassification,
    TFRobertaModel,
)
from ..roformer.modeling_tf_roformer import (
    TFRoFormerForCausalLM,
    TFRoFormerForMaskedLM,
    TFRoFormerForMultipleChoice,
    TFRoFormerForQuestionAnswering,
    TFRoFormerForSequenceClassification,
    TFRoFormerForTokenClassification,
    TFRoFormerModel,
)
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from ..transfo_xl.modeling_tf_transfo_xl import (
    TFTransfoXLForSequenceClassification,
    TFTransfoXLLMHeadModel,
    TFTransfoXLModel,
)
from ..wav2vec2.modeling_tf_wav2vec2 import TFWav2Vec2Model
from ..xlm.modeling_tf_xlm import (
    TFXLMForMultipleChoice,
    TFXLMForQuestionAnsweringSimple,
    TFXLMForSequenceClassification,
    TFXLMForTokenClassification,
    TFXLMModel,
    TFXLMWithLMHeadModel,
)
from ..xlm_roberta.modeling_tf_xlm_roberta import (
    TFXLMRobertaForMaskedLM,
    TFXLMRobertaForMultipleChoice,
    TFXLMRobertaForQuestionAnswering,
    TFXLMRobertaForSequenceClassification,
    TFXLMRobertaForTokenClassification,
    TFXLMRobertaModel,
)
from ..xlnet.modeling_tf_xlnet import (
    TFXLNetForMultipleChoice,
    TFXLNetForQuestionAnsweringSimple,
    TFXLNetForSequenceClassification,
    TFXLNetForTokenClassification,
    TFXLNetLMHeadModel,
    TFXLNetModel,
)
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
    AlbertConfig,
    BartConfig,
    BertConfig,
    BlenderbotConfig,
    BlenderbotSmallConfig,
    CamembertConfig,
    ConvBertConfig,
    CTRLConfig,
    DistilBertConfig,
    DPRConfig,
    ElectraConfig,
    FlaubertConfig,
    FunnelConfig,
    GPT2Config,
    HubertConfig,
    LayoutLMConfig,
    LEDConfig,
    LongformerConfig,
    LxmertConfig,
    MarianConfig,
    MBartConfig,
    MobileBertConfig,
    MPNetConfig,
    MT5Config,
    OpenAIGPTConfig,
    PegasusConfig,
    RobertaConfig,
    RoFormerConfig,
    T5Config,
    TransfoXLConfig,
    Wav2Vec2Config,
    XLMConfig,
    XLMRobertaConfig,
    XLNetConfig,
)


# Module-level logger named after this module, obtained from the package's own
# `logging` helper (imported from `...utils`). Nothing in the visible part of
# this file calls it.
logger = logging.get_logger(__name__)


# Configuration class -> base TensorFlow model class (no task-specific head).
# Used as `_model_mapping` by `TFAutoModel`.
#
# Each configuration class appears exactly once. A second
# `(BartConfig, TFBartModel)` pair used to follow the MPNet entry; it only
# rewrote the same key with the same value, so it has been dropped without
# changing the resulting mapping or its order.
#
# NOTE(review): OrderedDict preserves the insertion order below; the consumer
# lives in `.auto_factory` (not shown here) and may rely on it -- confirm before
# reordering entries.
TF_MODEL_MAPPING = OrderedDict(
    [
        # Base model mapping
        (RoFormerConfig, TFRoFormerModel),
        (ConvBertConfig, TFConvBertModel),
        (LEDConfig, TFLEDModel),
        (LxmertConfig, TFLxmertModel),
        (MT5Config, TFMT5Model),
        (T5Config, TFT5Model),
        (DistilBertConfig, TFDistilBertModel),
        (AlbertConfig, TFAlbertModel),
        (BartConfig, TFBartModel),
        (CamembertConfig, TFCamembertModel),
        (XLMRobertaConfig, TFXLMRobertaModel),
        (LongformerConfig, TFLongformerModel),
        (RobertaConfig, TFRobertaModel),
        (LayoutLMConfig, TFLayoutLMModel),
        (BertConfig, TFBertModel),
        (OpenAIGPTConfig, TFOpenAIGPTModel),
        (GPT2Config, TFGPT2Model),
        (MobileBertConfig, TFMobileBertModel),
        (TransfoXLConfig, TFTransfoXLModel),
        (XLNetConfig, TFXLNetModel),
        (FlaubertConfig, TFFlaubertModel),
        (XLMConfig, TFXLMModel),
        (CTRLConfig, TFCTRLModel),
        (ElectraConfig, TFElectraModel),
        # Funnel is the only entry whose value is a tuple of two candidate classes.
        (FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
        (DPRConfig, TFDPRQuestionEncoder),
        (MPNetConfig, TFMPNetModel),
        (MBartConfig, TFMBartModel),
        (MarianConfig, TFMarianModel),
        (PegasusConfig, TFPegasusModel),
        (BlenderbotConfig, TFBlenderbotModel),
        (BlenderbotSmallConfig, TFBlenderbotSmallModel),
        (Wav2Vec2Config, TFWav2Vec2Model),
        (HubertConfig, TFHubertModel),
    ]
)

# Configuration class -> TensorFlow class used for pre-training.
# Used as `_model_mapping` by `TFAutoModelForPreTraining`.
#
# Only some architectures map to a dedicated `...ForPreTraining` class; the
# others point at their masked-LM, LM-head or conditional-generation class.
TF_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
    [
        # Model for pre-training mapping
        (LxmertConfig, TFLxmertForPreTraining),
        (T5Config, TFT5ForConditionalGeneration),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (AlbertConfig, TFAlbertForPreTraining),
        (BartConfig, TFBartForConditionalGeneration),
        (CamembertConfig, TFCamembertForMaskedLM),
        (XLMRobertaConfig, TFXLMRobertaForMaskedLM),
        (RobertaConfig, TFRobertaForMaskedLM),
        (LayoutLMConfig, TFLayoutLMForMaskedLM),
        (BertConfig, TFBertForPreTraining),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
        (GPT2Config, TFGPT2LMHeadModel),
        (MobileBertConfig, TFMobileBertForPreTraining),
        (TransfoXLConfig, TFTransfoXLLMHeadModel),
        (XLNetConfig, TFXLNetLMHeadModel),
        (FlaubertConfig, TFFlaubertWithLMHeadModel),
        (XLMConfig, TFXLMWithLMHeadModel),
        (CTRLConfig, TFCTRLLMHeadModel),
        (ElectraConfig, TFElectraForPreTraining),
        (FunnelConfig, TFFunnelForPreTraining),
        (MPNetConfig, TFMPNetForMaskedLM),
    ]
)

# Configuration class -> TensorFlow class with a language-modeling head.
# Used as `_model_mapping` by the private `_TFAutoModelWithLMHead`, which the
# deprecated public `TFAutoModelWithLMHead` subclasses.
#
# The table mixes masked-LM, causal LM-head and conditional-generation classes
# in one mapping.
TF_MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
    [
        # Model with LM heads mapping
        (RoFormerConfig, TFRoFormerForMaskedLM),
        (ConvBertConfig, TFConvBertForMaskedLM),
        (LEDConfig, TFLEDForConditionalGeneration),
        (T5Config, TFT5ForConditionalGeneration),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (AlbertConfig, TFAlbertForMaskedLM),
        (MarianConfig, TFMarianMTModel),
        (BartConfig, TFBartForConditionalGeneration),
        (CamembertConfig, TFCamembertForMaskedLM),
        (XLMRobertaConfig, TFXLMRobertaForMaskedLM),
        (LongformerConfig, TFLongformerForMaskedLM),
        (RobertaConfig, TFRobertaForMaskedLM),
        (LayoutLMConfig, TFLayoutLMForMaskedLM),
        (BertConfig, TFBertForMaskedLM),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
        (GPT2Config, TFGPT2LMHeadModel),
        (MobileBertConfig, TFMobileBertForMaskedLM),
        (TransfoXLConfig, TFTransfoXLLMHeadModel),
        (XLNetConfig, TFXLNetLMHeadModel),
        (FlaubertConfig, TFFlaubertWithLMHeadModel),
        (XLMConfig, TFXLMWithLMHeadModel),
        (CTRLConfig, TFCTRLLMHeadModel),
        (ElectraConfig, TFElectraForMaskedLM),
        (FunnelConfig, TFFunnelForMaskedLM),
        (MPNetConfig, TFMPNetForMaskedLM),
    ]
)

# Configuration class -> TensorFlow class for causal language modeling.
# Used as `_model_mapping` by `TFAutoModelForCausalLM`.
#
# `XLMConfig` maps to the same `TFXLMWithLMHeadModel` class here and in
# `TF_MODEL_FOR_MASKED_LM_MAPPING`; the inline comment below records why.
TF_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
    [
        # Model for Causal LM mapping
        (RoFormerConfig, TFRoFormerForCausalLM),
        (BertConfig, TFBertLMHeadModel),
        (OpenAIGPTConfig, TFOpenAIGPTLMHeadModel),
        (GPT2Config, TFGPT2LMHeadModel),
        (TransfoXLConfig, TFTransfoXLLMHeadModel),
        (XLNetConfig, TFXLNetLMHeadModel),
        (
            XLMConfig,
            TFXLMWithLMHeadModel,
        ),  # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
        (CTRLConfig, TFCTRLLMHeadModel),
    ]
)

# Configuration class -> TensorFlow class for masked language modeling.
# Used as `_model_mapping` by `TFAutoModelForMaskedLM`.
#
# Flaubert and XLM have no `...ForMaskedLM` entry here; they map to their
# `...WithLMHeadModel` classes instead.
TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
    [
        # Model for Masked LM mapping
        (RoFormerConfig, TFRoFormerForMaskedLM),
        (ConvBertConfig, TFConvBertForMaskedLM),
        (DistilBertConfig, TFDistilBertForMaskedLM),
        (AlbertConfig, TFAlbertForMaskedLM),
        (CamembertConfig, TFCamembertForMaskedLM),
        (XLMRobertaConfig, TFXLMRobertaForMaskedLM),
        (LongformerConfig, TFLongformerForMaskedLM),
        (RobertaConfig, TFRobertaForMaskedLM),
        (LayoutLMConfig, TFLayoutLMForMaskedLM),
        (BertConfig, TFBertForMaskedLM),
        (MobileBertConfig, TFMobileBertForMaskedLM),
        (FlaubertConfig, TFFlaubertWithLMHeadModel),
        (XLMConfig, TFXLMWithLMHeadModel),
        (ElectraConfig, TFElectraForMaskedLM),
        (FunnelConfig, TFFunnelForMaskedLM),
        (MPNetConfig, TFMPNetForMaskedLM),
    ]
)


# Configuration class -> TensorFlow encoder-decoder class for
# sequence-to-sequence language modeling.
# Used as `_model_mapping` by `TFAutoModelForSeq2SeqLM`.
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        (LEDConfig, TFLEDForConditionalGeneration),
        (MT5Config, TFMT5ForConditionalGeneration),
        (T5Config, TFT5ForConditionalGeneration),
        (MarianConfig, TFMarianMTModel),
        (MBartConfig, TFMBartForConditionalGeneration),
        (PegasusConfig, TFPegasusForConditionalGeneration),
        (BlenderbotConfig, TFBlenderbotForConditionalGeneration),
        (BlenderbotSmallConfig, TFBlenderbotSmallForConditionalGeneration),
        (BartConfig, TFBartForConditionalGeneration),
    ]
)

# Configuration class -> TensorFlow class with a sequence-classification head.
# Used as `_model_mapping` by `TFAutoModelForSequenceClassification`.
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
    [
        # Model for Sequence Classification mapping
        (RoFormerConfig, TFRoFormerForSequenceClassification),
        (ConvBertConfig, TFConvBertForSequenceClassification),
        (DistilBertConfig, TFDistilBertForSequenceClassification),
        (AlbertConfig, TFAlbertForSequenceClassification),
        (CamembertConfig, TFCamembertForSequenceClassification),
        (XLMRobertaConfig, TFXLMRobertaForSequenceClassification),
        (LongformerConfig, TFLongformerForSequenceClassification),
        (RobertaConfig, TFRobertaForSequenceClassification),
        (LayoutLMConfig, TFLayoutLMForSequenceClassification),
        (BertConfig, TFBertForSequenceClassification),
        (XLNetConfig, TFXLNetForSequenceClassification),
        (MobileBertConfig, TFMobileBertForSequenceClassification),
        (FlaubertConfig, TFFlaubertForSequenceClassification),
        (XLMConfig, TFXLMForSequenceClassification),
        (ElectraConfig, TFElectraForSequenceClassification),
        (FunnelConfig, TFFunnelForSequenceClassification),
        (GPT2Config, TFGPT2ForSequenceClassification),
        (MPNetConfig, TFMPNetForSequenceClassification),
        (OpenAIGPTConfig, TFOpenAIGPTForSequenceClassification),
        (TransfoXLConfig, TFTransfoXLForSequenceClassification),
        (CTRLConfig, TFCTRLForSequenceClassification),
    ]
)

# Configuration class -> TensorFlow class with a question-answering head.
# Used as `_model_mapping` by `TFAutoModelForQuestionAnswering`.
#
# XLNet, Flaubert and XLM map to their `...ForQuestionAnsweringSimple` classes.
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
    [
        # Model for Question Answering mapping
        (RoFormerConfig, TFRoFormerForQuestionAnswering),
        (ConvBertConfig, TFConvBertForQuestionAnswering),
        (DistilBertConfig, TFDistilBertForQuestionAnswering),
        (AlbertConfig, TFAlbertForQuestionAnswering),
        (CamembertConfig, TFCamembertForQuestionAnswering),
        (XLMRobertaConfig, TFXLMRobertaForQuestionAnswering),
        (LongformerConfig, TFLongformerForQuestionAnswering),
        (RobertaConfig, TFRobertaForQuestionAnswering),
        (BertConfig, TFBertForQuestionAnswering),
        (XLNetConfig, TFXLNetForQuestionAnsweringSimple),
        (MobileBertConfig, TFMobileBertForQuestionAnswering),
        (FlaubertConfig, TFFlaubertForQuestionAnsweringSimple),
        (XLMConfig, TFXLMForQuestionAnsweringSimple),
        (ElectraConfig, TFElectraForQuestionAnswering),
        (FunnelConfig, TFFunnelForQuestionAnswering),
        (MPNetConfig, TFMPNetForQuestionAnswering),
    ]
)

# Configuration class -> TensorFlow class with a token-classification head.
# Used as `_model_mapping` by `TFAutoModelForTokenClassification`.
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
    [
        # Model for Token Classification mapping
        (RoFormerConfig, TFRoFormerForTokenClassification),
        (ConvBertConfig, TFConvBertForTokenClassification),
        (DistilBertConfig, TFDistilBertForTokenClassification),
        (AlbertConfig, TFAlbertForTokenClassification),
        (CamembertConfig, TFCamembertForTokenClassification),
        (FlaubertConfig, TFFlaubertForTokenClassification),
        (XLMConfig, TFXLMForTokenClassification),
        (XLMRobertaConfig, TFXLMRobertaForTokenClassification),
        (LongformerConfig, TFLongformerForTokenClassification),
        (RobertaConfig, TFRobertaForTokenClassification),
        (LayoutLMConfig, TFLayoutLMForTokenClassification),
        (BertConfig, TFBertForTokenClassification),
        (MobileBertConfig, TFMobileBertForTokenClassification),
        (XLNetConfig, TFXLNetForTokenClassification),
        (ElectraConfig, TFElectraForTokenClassification),
        (FunnelConfig, TFFunnelForTokenClassification),
        (MPNetConfig, TFMPNetForTokenClassification),
    ]
)

# Configuration class -> TensorFlow class with a multiple-choice head.
# Used as `_model_mapping` by `TFAutoModelForMultipleChoice`.
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
    [
        # Model for Multiple Choice mapping
        (RoFormerConfig, TFRoFormerForMultipleChoice),
        (ConvBertConfig, TFConvBertForMultipleChoice),
        (CamembertConfig, TFCamembertForMultipleChoice),
        (XLMConfig, TFXLMForMultipleChoice),
        (XLMRobertaConfig, TFXLMRobertaForMultipleChoice),
        (LongformerConfig, TFLongformerForMultipleChoice),
        (RobertaConfig, TFRobertaForMultipleChoice),
        (BertConfig, TFBertForMultipleChoice),
        (DistilBertConfig, TFDistilBertForMultipleChoice),
        (MobileBertConfig, TFMobileBertForMultipleChoice),
        (XLNetConfig, TFXLNetForMultipleChoice),
        (FlaubertConfig, TFFlaubertForMultipleChoice),
        (AlbertConfig, TFAlbertForMultipleChoice),
        (ElectraConfig, TFElectraForMultipleChoice),
        (FunnelConfig, TFFunnelForMultipleChoice),
        (MPNetConfig, TFMPNetForMultipleChoice),
    ]
)

# Configuration class -> TensorFlow class with a next-sentence-prediction head.
# Used as `_model_mapping` by `TFAutoModelForNextSentencePrediction`.
# Only BERT and MobileBERT have an entry in this table.
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        (BertConfig, TFBertForNextSentencePrediction),
        (MobileBertConfig, TFMobileBertForNextSentencePrediction),
    ]
)


# Auto class for base TensorFlow models: all it declares is which mapping to
# use (`TF_MODEL_MAPPING`); everything else is inherited from
# `_BaseAutoModelClass`.
class TFAutoModel(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_MAPPING


# The class is rebound to whatever `auto_class_update` returns. That helper is
# imported from `.auto_factory` and is not shown here -- presumably it fills in
# the generated documentation; confirm there.
TFAutoModel = auto_class_update(TFAutoModel)
# Auto class for pre-training models; it only selects
# `TF_MODEL_FOR_PRETRAINING_MAPPING`.
class TFAutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING


# `head_doc` is forwarded to `auto_class_update` (defined in `.auto_factory`,
# not shown here) -- presumably used in the generated documentation; confirm.
TFAutoModelForPreTraining = auto_class_update(TFAutoModelForPreTraining, head_doc="pretraining")


# Private on purpose, the public class will add the deprecation warnings.
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING


_TFAutoModelWithLMHead = auto_class_update(_TFAutoModelWithLMHead, head_doc="language modeling")
# Auto class for causal language models; it only selects
# `TF_MODEL_FOR_CAUSAL_LM_MAPPING`.
class TFAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING


# `head_doc` is forwarded to `auto_class_update` (defined in `.auto_factory`,
# not shown here) -- presumably used in the generated documentation; confirm.
TFAutoModelForCausalLM = auto_class_update(TFAutoModelForCausalLM, head_doc="causal language modeling")
# Auto class for masked language models; it only selects
# `TF_MODEL_FOR_MASKED_LM_MAPPING`.
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING


# `head_doc` is forwarded to `auto_class_update` (defined in `.auto_factory`,
# not shown here) -- presumably used in the generated documentation; confirm.
TFAutoModelForMaskedLM = auto_class_update(TFAutoModelForMaskedLM, head_doc="masked language modeling")
# Auto class for sequence-to-sequence language models; it only selects
# `TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING`.
class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# This is the only auto class in the file that passes `checkpoint_for_example`
# (here "t5-base"). Both keyword arguments go to `auto_class_update`, defined in
# `.auto_factory` and not shown here -- presumably they shape the generated
# documentation and its usage example; confirm.
TFAutoModelForSeq2SeqLM = auto_class_update(
    TFAutoModelForSeq2SeqLM, head_doc="sequence-to-sequence language modeling", checkpoint_for_example="t5-base"
)
# Auto class for sequence-classification models; it only selects
# `TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING`.
class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# `head_doc` is forwarded to `auto_class_update` (defined in `.auto_factory`,
# not shown here) -- presumably used in the generated documentation; confirm.
TFAutoModelForSequenceClassification = auto_class_update(
    TFAutoModelForSequenceClassification, head_doc="sequence classification"
)
# Auto class for question-answering models; it only selects
# `TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING`.
class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING


# `head_doc` is forwarded to `auto_class_update` (defined in `.auto_factory`,
# not shown here) -- presumably used in the generated documentation; confirm.
TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering")
# Auto class for token-classification models; it only selects
# `TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING`.
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# `head_doc` is forwarded to `auto_class_update` (defined in `.auto_factory`,
# not shown here) -- presumably used in the generated documentation; confirm.
TFAutoModelForTokenClassification = auto_class_update(
    TFAutoModelForTokenClassification, head_doc="token classification"
)
# Auto class for multiple-choice models; it only selects
# `TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING`.
class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# `head_doc` is forwarded to `auto_class_update` (defined in `.auto_factory`,
# not shown here) -- presumably used in the generated documentation; confirm.
TFAutoModelForMultipleChoice = auto_class_update(TFAutoModelForMultipleChoice, head_doc="multiple choice")


# Auto class for next-sentence-prediction models; it only selects
# `TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING`.
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


TFAutoModelForNextSentencePrediction = auto_class_update(
    TFAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)


# Deprecated public wrapper around the private `_TFAutoModelWithLMHead`.
# Both constructors emit the same FutureWarning pointing users at
# `TFAutoModelForCausalLM`, `TFAutoModelForMaskedLM` and
# `TFAutoModelForSeq2SeqLM`, then delegate unchanged to the parent class.
# `#` comments are used instead of docstrings so that `__doc__` on the class
# and on the overriding methods stays exactly as it was.
class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
    @classmethod
    def from_config(cls, config):
        # Warn first, then defer to the inherited implementation.
        warnings.warn(
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models and "
            "`TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Same warning as `from_config`; all arguments are passed through untouched.
        warnings.warn(
            "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
            "`TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models and "
            "`TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
            FutureWarning,
        )
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)