# Source snapshot metadata (was stray page residue): uploader "mart9992", commit 06ba6ce ("m").
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Auto Model classes for TensorFlow."""
import warnings
from collections import OrderedDict
from ...utils import logging
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
from .configuration_auto import CONFIG_MAPPING_NAMES
# Module-level logger for this file.
# NOTE(review): nothing in this file references `logger`; confirm it is still needed.
logger = logging.get_logger(__name__)
# Model-type key -> name of the bare TF base-model class, resolved by `TFAutoModel` through
# `TF_MODEL_MAPPING` (this table is passed to `_LazyAutoMapping` together with CONFIG_MAPPING_NAMES,
# so keys presumably must match CONFIG_MAPPING_NAMES keys -- confirm in configuration_auto).
# A tuple value (see "funnel") lists more than one candidate class for a single model type.
# NOTE(review): the rule for choosing among tuple candidates lives in auto_factory, not here.
# Entries are kept in alphabetical order of the key.
TF_MODEL_MAPPING_NAMES = OrderedDict(
    [
        # Base model mapping
        ("albert", "TFAlbertModel"),
        ("bart", "TFBartModel"),
        ("bert", "TFBertModel"),
        ("blenderbot", "TFBlenderbotModel"),
        ("blenderbot-small", "TFBlenderbotSmallModel"),
        ("blip", "TFBlipModel"),
        ("camembert", "TFCamembertModel"),
        ("clip", "TFCLIPModel"),
        ("convbert", "TFConvBertModel"),
        ("convnext", "TFConvNextModel"),
        ("ctrl", "TFCTRLModel"),
        ("cvt", "TFCvtModel"),
        ("data2vec-vision", "TFData2VecVisionModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("deit", "TFDeiTModel"),
        ("distilbert", "TFDistilBertModel"),
        ("dpr", "TFDPRQuestionEncoder"),  # DPR's bare entry is its question encoder
        ("efficientformer", "TFEfficientFormerModel"),
        ("electra", "TFElectraModel"),
        ("esm", "TFEsmModel"),
        ("flaubert", "TFFlaubertModel"),
        ("funnel", ("TFFunnelModel", "TFFunnelBaseModel")),
        ("gpt-sw3", "TFGPT2Model"),  # GPT-SW3 reuses the GPT-2 TF classes
        ("gpt2", "TFGPT2Model"),
        ("gptj", "TFGPTJModel"),
        ("groupvit", "TFGroupViTModel"),
        ("hubert", "TFHubertModel"),
        ("layoutlm", "TFLayoutLMModel"),
        ("layoutlmv3", "TFLayoutLMv3Model"),
        ("led", "TFLEDModel"),
        ("longformer", "TFLongformerModel"),
        ("lxmert", "TFLxmertModel"),
        ("marian", "TFMarianModel"),
        ("mbart", "TFMBartModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mobilevit", "TFMobileViTModel"),
        ("mpnet", "TFMPNetModel"),
        ("mt5", "TFMT5Model"),
        ("openai-gpt", "TFOpenAIGPTModel"),
        ("opt", "TFOPTModel"),
        ("pegasus", "TFPegasusModel"),
        ("regnet", "TFRegNetModel"),
        ("rembert", "TFRemBertModel"),
        ("resnet", "TFResNetModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("sam", "TFSamModel"),
        ("segformer", "TFSegformerModel"),
        ("speech_to_text", "TFSpeech2TextModel"),
        ("swin", "TFSwinModel"),
        ("t5", "TFT5Model"),
        ("tapas", "TFTapasModel"),
        ("transfo-xl", "TFTransfoXLModel"),
        ("vision-text-dual-encoder", "TFVisionTextDualEncoderModel"),
        ("vit", "TFViTModel"),
        ("vit_mae", "TFViTMAEModel"),
        ("wav2vec2", "TFWav2Vec2Model"),
        ("whisper", "TFWhisperModel"),
        ("xglm", "TFXGLMModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
        ("xlnet", "TFXLNetModel"),
    ]
)
# Model-type key -> TF class used by `TFAutoModelForPreTraining` (via `TF_MODEL_FOR_PRETRAINING_MAPPING`).
# Model types without a dedicated *ForPreTraining class fall back to their masked-LM, LM-head or
# conditional-generation class (e.g. "camembert", "gpt2", "t5").
TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    [
        # Model for pre-training mapping
        ("albert", "TFAlbertForPreTraining"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForPreTraining"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForPreTraining"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForPreTraining"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("lxmert", "TFLxmertForPreTraining"),
        ("mobilebert", "TFMobileBertForPreTraining"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("vit_mae", "TFViTMAEForPreTraining"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)
# Model-type key -> TF class with *some* language-modeling head, consumed by the private
# `_TFAutoModelWithLMHead` (and hence the deprecated `TFAutoModelWithLMHead`) via
# `TF_MODEL_WITH_LM_HEAD_MAPPING`. This table mixes causal, masked and encoder-decoder LM heads,
# which is why the public class points users to the task-specific auto classes instead.
TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    [
        # Model with LM heads mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bart", "TFBartForConditionalGeneration"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("led", "TFLEDForConditionalGeneration"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("marian", "TFMarianMTModel"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
        ("tapas", "TFTapasForMaskedLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("whisper", "TFWhisperForConditionalGeneration"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)
# Model-type key -> TF causal (left-to-right) LM class for `TFAutoModelForCausalLM`,
# resolved through `TF_MODEL_FOR_CAUSAL_LM_MAPPING`.
TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Causal LM mapping
        ("bert", "TFBertLMHeadModel"),
        ("camembert", "TFCamembertForCausalLM"),
        ("ctrl", "TFCTRLLMHeadModel"),
        ("gpt-sw3", "TFGPT2LMHeadModel"),
        ("gpt2", "TFGPT2LMHeadModel"),
        ("gptj", "TFGPTJForCausalLM"),
        ("openai-gpt", "TFOpenAIGPTLMHeadModel"),
        ("opt", "TFOPTForCausalLM"),
        ("rembert", "TFRemBertForCausalLM"),
        ("roberta", "TFRobertaForCausalLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForCausalLM"),
        ("roformer", "TFRoFormerForCausalLM"),
        ("transfo-xl", "TFTransfoXLLMHeadModel"),
        ("xglm", "TFXGLMForCausalLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForCausalLM"),
        ("xlnet", "TFXLNetLMHeadModel"),
    ]
)
# Model-type key -> TF masked-image-modeling class for `TFAutoModelForMaskedImageModeling`,
# resolved through `TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING`.
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked Image Modeling mapping
        ("deit", "TFDeiTForMaskedImageModeling"),
        ("swin", "TFSwinForMaskedImageModeling"),
    ]
)
# Model-type key -> TF image-classification class(es) for `TFAutoModelForImageClassification`,
# resolved through `TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING`. "deit" and "efficientformer" list
# both a plain head and a "WithTeacher" (distillation) variant as a tuple.
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Image Classification mapping
        ("convnext", "TFConvNextForImageClassification"),
        ("cvt", "TFCvtForImageClassification"),
        ("data2vec-vision", "TFData2VecVisionForImageClassification"),
        ("deit", ("TFDeiTForImageClassification", "TFDeiTForImageClassificationWithTeacher")),
        (
            "efficientformer",
            ("TFEfficientFormerForImageClassification", "TFEfficientFormerForImageClassificationWithTeacher"),
        ),
        ("mobilevit", "TFMobileViTForImageClassification"),
        ("regnet", "TFRegNetForImageClassification"),
        ("resnet", "TFResNetForImageClassification"),
        ("segformer", "TFSegformerForImageClassification"),
        ("swin", "TFSwinForImageClassification"),
        ("vit", "TFViTForImageClassification"),
    ]
)
# Model-type key -> TF class for `TFAutoModelForZeroShotImageClassification`. These are the same
# full image-text models listed in TF_MODEL_MAPPING_NAMES ("blip", "clip"), not dedicated heads.
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Zero Shot Image Classification mapping
        ("blip", "TFBlipModel"),
        ("clip", "TFCLIPModel"),
    ]
)
# Model-type key -> TF semantic-segmentation class for `TFAutoModelForSemanticSegmentation`,
# resolved through `TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING`.
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Semantic Segmentation mapping
        ("data2vec-vision", "TFData2VecVisionForSemanticSegmentation"),
        ("mobilevit", "TFMobileViTForSemanticSegmentation"),
        ("segformer", "TFSegformerForSemanticSegmentation"),
    ]
)
# Model-type key -> TF image-to-text generation class for `TFAutoModelForVision2Seq`,
# resolved through `TF_MODEL_FOR_VISION_2_SEQ_MAPPING`.
TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Vision-to-Text (image-to-sequence) mapping
        ("blip", "TFBlipForConditionalGeneration"),
        ("vision-encoder-decoder", "TFVisionEncoderDecoderModel"),
    ]
)
# Model-type key -> TF masked-LM class for `TFAutoModelForMaskedLM`,
# resolved through `TF_MODEL_FOR_MASKED_LM_MAPPING`.
TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Masked LM mapping
        ("albert", "TFAlbertForMaskedLM"),
        ("bert", "TFBertForMaskedLM"),
        ("camembert", "TFCamembertForMaskedLM"),
        ("convbert", "TFConvBertForMaskedLM"),
        ("deberta", "TFDebertaForMaskedLM"),
        ("deberta-v2", "TFDebertaV2ForMaskedLM"),
        ("distilbert", "TFDistilBertForMaskedLM"),
        ("electra", "TFElectraForMaskedLM"),
        ("esm", "TFEsmForMaskedLM"),
        ("flaubert", "TFFlaubertWithLMHeadModel"),
        ("funnel", "TFFunnelForMaskedLM"),
        ("layoutlm", "TFLayoutLMForMaskedLM"),
        ("longformer", "TFLongformerForMaskedLM"),
        ("mobilebert", "TFMobileBertForMaskedLM"),
        ("mpnet", "TFMPNetForMaskedLM"),
        ("rembert", "TFRemBertForMaskedLM"),
        ("roberta", "TFRobertaForMaskedLM"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMaskedLM"),
        ("roformer", "TFRoFormerForMaskedLM"),
        ("tapas", "TFTapasForMaskedLM"),
        ("xlm", "TFXLMWithLMHeadModel"),
        ("xlm-roberta", "TFXLMRobertaForMaskedLM"),
    ]
)
# Model-type key -> TF encoder-decoder (text-to-text) class for `TFAutoModelForSeq2SeqLM`,
# resolved through `TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING`. Despite "CAUSAL_LM" in the
# name, every entry here is a sequence-to-sequence model.
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    [
        # Model for Seq2Seq Causal LM mapping
        ("bart", "TFBartForConditionalGeneration"),
        ("blenderbot", "TFBlenderbotForConditionalGeneration"),
        ("blenderbot-small", "TFBlenderbotSmallForConditionalGeneration"),
        ("encoder-decoder", "TFEncoderDecoderModel"),
        ("led", "TFLEDForConditionalGeneration"),
        ("marian", "TFMarianMTModel"),
        ("mbart", "TFMBartForConditionalGeneration"),
        ("mt5", "TFMT5ForConditionalGeneration"),
        ("pegasus", "TFPegasusForConditionalGeneration"),
        ("t5", "TFT5ForConditionalGeneration"),
    ]
)
# Model-type key -> TF speech-to-text encoder-decoder class for `TFAutoModelForSpeechSeq2Seq`,
# resolved through `TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING`.
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    [
        # Model for Speech Seq2Seq mapping
        ("speech_to_text", "TFSpeech2TextForConditionalGeneration"),
        ("whisper", "TFWhisperForConditionalGeneration"),
    ]
)
# Model-type key -> TF sequence-classification class for `TFAutoModelForSequenceClassification`,
# resolved through `TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING`.
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Sequence Classification mapping
        ("albert", "TFAlbertForSequenceClassification"),
        ("bart", "TFBartForSequenceClassification"),
        ("bert", "TFBertForSequenceClassification"),
        ("camembert", "TFCamembertForSequenceClassification"),
        ("convbert", "TFConvBertForSequenceClassification"),
        ("ctrl", "TFCTRLForSequenceClassification"),
        ("deberta", "TFDebertaForSequenceClassification"),
        ("deberta-v2", "TFDebertaV2ForSequenceClassification"),
        ("distilbert", "TFDistilBertForSequenceClassification"),
        ("electra", "TFElectraForSequenceClassification"),
        ("esm", "TFEsmForSequenceClassification"),
        ("flaubert", "TFFlaubertForSequenceClassification"),
        ("funnel", "TFFunnelForSequenceClassification"),
        ("gpt-sw3", "TFGPT2ForSequenceClassification"),
        ("gpt2", "TFGPT2ForSequenceClassification"),
        ("gptj", "TFGPTJForSequenceClassification"),
        ("layoutlm", "TFLayoutLMForSequenceClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForSequenceClassification"),
        ("longformer", "TFLongformerForSequenceClassification"),
        ("mobilebert", "TFMobileBertForSequenceClassification"),
        ("mpnet", "TFMPNetForSequenceClassification"),
        ("openai-gpt", "TFOpenAIGPTForSequenceClassification"),
        ("rembert", "TFRemBertForSequenceClassification"),
        ("roberta", "TFRobertaForSequenceClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForSequenceClassification"),
        ("roformer", "TFRoFormerForSequenceClassification"),
        ("tapas", "TFTapasForSequenceClassification"),
        ("transfo-xl", "TFTransfoXLForSequenceClassification"),
        ("xlm", "TFXLMForSequenceClassification"),
        ("xlm-roberta", "TFXLMRobertaForSequenceClassification"),
        ("xlnet", "TFXLNetForSequenceClassification"),
    ]
)
# Model-type key -> TF extractive question-answering class for `TFAutoModelForQuestionAnswering`,
# resolved through `TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING`. "flaubert", "xlm" and "xlnet" map to
# their "...ForQuestionAnsweringSimple" variants.
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Question Answering mapping
        ("albert", "TFAlbertForQuestionAnswering"),
        ("bert", "TFBertForQuestionAnswering"),
        ("camembert", "TFCamembertForQuestionAnswering"),
        ("convbert", "TFConvBertForQuestionAnswering"),
        ("deberta", "TFDebertaForQuestionAnswering"),
        ("deberta-v2", "TFDebertaV2ForQuestionAnswering"),
        ("distilbert", "TFDistilBertForQuestionAnswering"),
        ("electra", "TFElectraForQuestionAnswering"),
        ("flaubert", "TFFlaubertForQuestionAnsweringSimple"),
        ("funnel", "TFFunnelForQuestionAnswering"),
        ("gptj", "TFGPTJForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
        ("longformer", "TFLongformerForQuestionAnswering"),
        ("mobilebert", "TFMobileBertForQuestionAnswering"),
        ("mpnet", "TFMPNetForQuestionAnswering"),
        ("rembert", "TFRemBertForQuestionAnswering"),
        ("roberta", "TFRobertaForQuestionAnswering"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForQuestionAnswering"),
        ("roformer", "TFRoFormerForQuestionAnswering"),
        ("xlm", "TFXLMForQuestionAnsweringSimple"),
        ("xlm-roberta", "TFXLMRobertaForQuestionAnswering"),
        ("xlnet", "TFXLNetForQuestionAnsweringSimple"),
    ]
)
# Model-type key -> TF audio-classification class for `TFAutoModelForAudioClassification`,
# resolved through `TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING` (wav2vec2 uses its sequence-classification head).
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = OrderedDict([("wav2vec2", "TFWav2Vec2ForSequenceClassification")])
# Model-type key -> TF document (layout-aware) question-answering class for
# `TFAutoModelForDocumentQuestionAnswering`, resolved through `TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING`.
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Document Question Answering mapping
        ("layoutlm", "TFLayoutLMForQuestionAnswering"),
        ("layoutlmv3", "TFLayoutLMv3ForQuestionAnswering"),
    ]
)
# Model-type key -> TF table question-answering class for `TFAutoModelForTableQuestionAnswering`,
# resolved through `TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING`.
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Table Question Answering mapping
        ("tapas", "TFTapasForQuestionAnswering"),
    ]
)
# Model-type key -> TF token-classification class for `TFAutoModelForTokenClassification`,
# resolved through `TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING`.
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Token Classification mapping
        ("albert", "TFAlbertForTokenClassification"),
        ("bert", "TFBertForTokenClassification"),
        ("camembert", "TFCamembertForTokenClassification"),
        ("convbert", "TFConvBertForTokenClassification"),
        ("deberta", "TFDebertaForTokenClassification"),
        ("deberta-v2", "TFDebertaV2ForTokenClassification"),
        ("distilbert", "TFDistilBertForTokenClassification"),
        ("electra", "TFElectraForTokenClassification"),
        ("esm", "TFEsmForTokenClassification"),
        ("flaubert", "TFFlaubertForTokenClassification"),
        ("funnel", "TFFunnelForTokenClassification"),
        ("layoutlm", "TFLayoutLMForTokenClassification"),
        ("layoutlmv3", "TFLayoutLMv3ForTokenClassification"),
        ("longformer", "TFLongformerForTokenClassification"),
        ("mobilebert", "TFMobileBertForTokenClassification"),
        ("mpnet", "TFMPNetForTokenClassification"),
        ("rembert", "TFRemBertForTokenClassification"),
        ("roberta", "TFRobertaForTokenClassification"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForTokenClassification"),
        ("roformer", "TFRoFormerForTokenClassification"),
        ("xlm", "TFXLMForTokenClassification"),
        ("xlm-roberta", "TFXLMRobertaForTokenClassification"),
        ("xlnet", "TFXLNetForTokenClassification"),
    ]
)
# Model-type key -> TF multiple-choice class for `TFAutoModelForMultipleChoice`,
# resolved through `TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING`.
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    [
        # Model for Multiple Choice mapping
        ("albert", "TFAlbertForMultipleChoice"),
        ("bert", "TFBertForMultipleChoice"),
        ("camembert", "TFCamembertForMultipleChoice"),
        ("convbert", "TFConvBertForMultipleChoice"),
        ("deberta-v2", "TFDebertaV2ForMultipleChoice"),
        ("distilbert", "TFDistilBertForMultipleChoice"),
        ("electra", "TFElectraForMultipleChoice"),
        ("flaubert", "TFFlaubertForMultipleChoice"),
        ("funnel", "TFFunnelForMultipleChoice"),
        ("longformer", "TFLongformerForMultipleChoice"),
        ("mobilebert", "TFMobileBertForMultipleChoice"),
        ("mpnet", "TFMPNetForMultipleChoice"),
        ("rembert", "TFRemBertForMultipleChoice"),
        ("roberta", "TFRobertaForMultipleChoice"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormForMultipleChoice"),
        ("roformer", "TFRoFormerForMultipleChoice"),
        ("xlm", "TFXLMForMultipleChoice"),
        ("xlm-roberta", "TFXLMRobertaForMultipleChoice"),
        ("xlnet", "TFXLNetForMultipleChoice"),
    ]
)
# Model-type key -> TF next-sentence-prediction class for `TFAutoModelForNextSentencePrediction`,
# resolved through `TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING`.
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Next Sentence Prediction mapping
        ("bert", "TFBertForNextSentencePrediction"),
        ("mobilebert", "TFMobileBertForNextSentencePrediction"),
    ]
)
# Model-type key -> TF class for `TFAutoModelForMaskGeneration`, resolved through
# `TF_MODEL_FOR_MASK_GENERATION_MAPPING`. "sam" reuses the same class as in TF_MODEL_MAPPING_NAMES.
TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = OrderedDict(
    [
        # Model for Mask Generation mapping
        ("sam", "TFSamModel"),
    ]
)
# Model-type key -> TF class for `TFAutoModelForTextEncoding`, resolved through
# `TF_MODEL_FOR_TEXT_ENCODING_MAPPING`. Encoder-only model types map to their bare base model;
# the encoder-decoder types "mt5" and "t5" map to their encoder-only classes.
TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES = OrderedDict(
    [
        # Model for Text Encoding mapping
        ("albert", "TFAlbertModel"),
        ("bert", "TFBertModel"),
        ("convbert", "TFConvBertModel"),
        ("deberta", "TFDebertaModel"),
        ("deberta-v2", "TFDebertaV2Model"),
        ("distilbert", "TFDistilBertModel"),
        ("electra", "TFElectraModel"),
        ("flaubert", "TFFlaubertModel"),
        ("longformer", "TFLongformerModel"),
        ("mobilebert", "TFMobileBertModel"),
        ("mt5", "TFMT5EncoderModel"),
        ("rembert", "TFRemBertModel"),
        ("roberta", "TFRobertaModel"),
        ("roberta-prelayernorm", "TFRobertaPreLayerNormModel"),
        ("roformer", "TFRoFormerModel"),
        ("t5", "TFT5EncoderModel"),
        ("xlm", "TFXLMModel"),
        ("xlm-roberta", "TFXLMRobertaModel"),
    ]
)
# Wrap every *_MAPPING_NAMES table above into a `_LazyAutoMapping` keyed by CONFIG_MAPPING_NAMES.
# Each resulting object becomes the `_model_mapping` of the matching TF auto class below.


def _lazy_tf_mapping(model_names):
    """Pair CONFIG_MAPPING_NAMES with one model-type -> TF class-name table."""
    return _LazyAutoMapping(CONFIG_MAPPING_NAMES, model_names)


# Generic / language-modeling mappings.
TF_MODEL_MAPPING = _lazy_tf_mapping(TF_MODEL_MAPPING_NAMES)
TF_MODEL_FOR_PRETRAINING_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_PRETRAINING_MAPPING_NAMES)
TF_MODEL_WITH_LM_HEAD_MAPPING = _lazy_tf_mapping(TF_MODEL_WITH_LM_HEAD_MAPPING_NAMES)
TF_MODEL_FOR_CAUSAL_LM_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)

# Vision mappings.
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES)
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES)
TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING = _lazy_tf_mapping(
    TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING_NAMES)
TF_MODEL_FOR_VISION_2_SEQ_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES)

# Text mappings.
TF_MODEL_FOR_MASKED_LM_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_MASKED_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES)

# Speech mapping.
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES)

# Question-answering and classification heads.
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES)
TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _lazy_tf_mapping(
    TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES
)
TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES)
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES)
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES)
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES)
TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES)
TF_MODEL_FOR_MASK_GENERATION_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_MASK_GENERATION_MAPPING_NAMES)
TF_MODEL_FOR_TEXT_ENCODING_MAPPING = _lazy_tf_mapping(TF_MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES)

# The helper is construction-only; drop it so the module namespace matches the public API.
del _lazy_tf_mapping
class TFAutoModelForMaskGeneration(_BaseAutoModelClass):
    """TF auto class that resolves a config to its mask-generation model.

    Candidate classes come from ``TF_MODEL_FOR_MASK_GENERATION_MAPPING``. Unlike most auto classes in
    this module, it is not passed through ``auto_class_update``.
    """

    _model_mapping = TF_MODEL_FOR_MASK_GENERATION_MAPPING
class TFAutoModelForTextEncoding(_BaseAutoModelClass):
    """TF auto class that resolves a config to a text-encoder model.

    Candidate classes come from ``TF_MODEL_FOR_TEXT_ENCODING_MAPPING``. Unlike most auto classes in
    this module, it is not passed through ``auto_class_update``.
    """

    _model_mapping = TF_MODEL_FOR_TEXT_ENCODING_MAPPING
class TFAutoModel(_BaseAutoModelClass):
    """Generic TF auto class returning the bare base model for a config.

    Candidate classes come from ``TF_MODEL_MAPPING``.
    """

    _model_mapping = TF_MODEL_MAPPING


# Post-process the class with the shared auto-class factory (no head-specific docs for the base model).
TFAutoModel = auto_class_update(TFAutoModel)
class TFAutoModelForAudioClassification(_BaseAutoModelClass):
    """TF auto class for models with an audio-classification head.

    Candidate classes come from ``TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING


TFAutoModelForAudioClassification = auto_class_update(
    TFAutoModelForAudioClassification,
    head_doc="audio classification",
)
class TFAutoModelForPreTraining(_BaseAutoModelClass):
    """TF auto class for models with their pretraining head(s).

    Candidate classes come from ``TF_MODEL_FOR_PRETRAINING_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_PRETRAINING_MAPPING


TFAutoModelForPreTraining = auto_class_update(
    TFAutoModelForPreTraining,
    head_doc="pretraining",
)
class _TFAutoModelWithLMHead(_BaseAutoModelClass):
    """Private generic LM-head auto class backed by ``TF_MODEL_WITH_LM_HEAD_MAPPING``.

    Kept private deliberately: the public ``TFAutoModelWithLMHead`` subclasses it and layers the
    deprecation warnings on top.
    """

    _model_mapping = TF_MODEL_WITH_LM_HEAD_MAPPING


_TFAutoModelWithLMHead = auto_class_update(
    _TFAutoModelWithLMHead,
    head_doc="language modeling",
)
class TFAutoModelForCausalLM(_BaseAutoModelClass):
    """TF auto class for causal (left-to-right) language models.

    Candidate classes come from ``TF_MODEL_FOR_CAUSAL_LM_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_CAUSAL_LM_MAPPING


TFAutoModelForCausalLM = auto_class_update(
    TFAutoModelForCausalLM,
    head_doc="causal language modeling",
)
class TFAutoModelForMaskedImageModeling(_BaseAutoModelClass):
    """TF auto class for masked-image-modeling models.

    Candidate classes come from ``TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING


TFAutoModelForMaskedImageModeling = auto_class_update(
    TFAutoModelForMaskedImageModeling,
    head_doc="masked image modeling",
)
class TFAutoModelForImageClassification(_BaseAutoModelClass):
    """TF auto class for models with an image-classification head.

    Candidate classes come from ``TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


TFAutoModelForImageClassification = auto_class_update(
    TFAutoModelForImageClassification,
    head_doc="image classification",
)
class TFAutoModelForZeroShotImageClassification(_BaseAutoModelClass):
    """TF auto class for zero-shot image classification (image-text models).

    Candidate classes come from ``TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING


TFAutoModelForZeroShotImageClassification = auto_class_update(
    TFAutoModelForZeroShotImageClassification,
    head_doc="zero-shot image classification",
)
class TFAutoModelForSemanticSegmentation(_BaseAutoModelClass):
    """TF auto class for models with a semantic-segmentation head.

    Candidate classes come from ``TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING


TFAutoModelForSemanticSegmentation = auto_class_update(
    TFAutoModelForSemanticSegmentation,
    head_doc="semantic segmentation",
)
class TFAutoModelForVision2Seq(_BaseAutoModelClass):
    """TF auto class for image-to-text generation models.

    Candidate classes come from ``TF_MODEL_FOR_VISION_2_SEQ_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_VISION_2_SEQ_MAPPING


TFAutoModelForVision2Seq = auto_class_update(
    TFAutoModelForVision2Seq,
    head_doc="vision-to-text modeling",
)
class TFAutoModelForMaskedLM(_BaseAutoModelClass):
    """TF auto class for masked (fill-in-the-blank) language models.

    Candidate classes come from ``TF_MODEL_FOR_MASKED_LM_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_MASKED_LM_MAPPING


TFAutoModelForMaskedLM = auto_class_update(
    TFAutoModelForMaskedLM,
    head_doc="masked language modeling",
)
class TFAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    """TF auto class for encoder-decoder (text-to-text) language models.

    Candidate classes come from ``TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# The default example checkpoint is not a seq2seq model, so a T5 checkpoint is supplied for the docs.
TFAutoModelForSeq2SeqLM = auto_class_update(
    TFAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="t5-base",
)
class TFAutoModelForSequenceClassification(_BaseAutoModelClass):
    """TF auto class for models with a sequence-classification head.

    Candidate classes come from ``TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


TFAutoModelForSequenceClassification = auto_class_update(
    TFAutoModelForSequenceClassification,
    head_doc="sequence classification",
)
class TFAutoModelForQuestionAnswering(_BaseAutoModelClass):
    """TF auto class for extractive question-answering models.

    Candidate classes come from ``TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING


TFAutoModelForQuestionAnswering = auto_class_update(
    TFAutoModelForQuestionAnswering,
    head_doc="question answering",
)
class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass):
    """TF auto class for document (layout-aware) question-answering models.

    Candidate classes come from ``TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING


# NOTE(review): the odd quoting below is deliberate. The value is presumably substituted inside a
# double-quoted string in the generated docstring example, so the embedded `", revision="` injects
# a pinned `revision` argument. Confirm against auto_factory before "fixing" it.
TFAutoModelForDocumentQuestionAnswering = auto_class_update(
    TFAutoModelForDocumentQuestionAnswering,
    checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3',
    head_doc="document question answering",
)
class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    """TF auto class for table question-answering models.

    Candidate classes come from ``TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# A TAPAS checkpoint fine-tuned on WTQ is used as the documentation example.
TFAutoModelForTableQuestionAnswering = auto_class_update(
    TFAutoModelForTableQuestionAnswering,
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
    head_doc="table question answering",
)
class TFAutoModelForTokenClassification(_BaseAutoModelClass):
    """TF auto class for models with a token-classification head.

    Candidate classes come from ``TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


TFAutoModelForTokenClassification = auto_class_update(
    TFAutoModelForTokenClassification,
    head_doc="token classification",
)
class TFAutoModelForMultipleChoice(_BaseAutoModelClass):
    """TF auto class for models with a multiple-choice head.

    Candidate classes come from ``TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


TFAutoModelForMultipleChoice = auto_class_update(
    TFAutoModelForMultipleChoice,
    head_doc="multiple choice",
)
class TFAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    """TF auto class for models with a next-sentence-prediction head.

    Candidate classes come from ``TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


TFAutoModelForNextSentencePrediction = auto_class_update(
    TFAutoModelForNextSentencePrediction,
    head_doc="next sentence prediction",
)
class TFAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    """TF auto class for speech-to-text encoder-decoder models.

    Candidate classes come from ``TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING``.
    """

    _model_mapping = TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


TFAutoModelForSpeechSeq2Seq = auto_class_update(
    TFAutoModelForSpeechSeq2Seq,
    head_doc="sequence-to-sequence speech-to-text modeling",
)
class TFAutoModelWithLMHead(_TFAutoModelWithLMHead):
    """Deprecated generic TF auto class for models with a language-modeling head.

    Delegates to ``_TFAutoModelWithLMHead`` (backed by ``TF_MODEL_WITH_LM_HEAD_MAPPING``), but every
    ``from_config`` / ``from_pretrained`` call first emits a ``FutureWarning``. The warning points users
    to ``TFAutoModelForCausalLM``, ``TFAutoModelForMaskedLM`` or ``TFAutoModelForSeq2SeqLM``.
    """

    # Single source of truth for the deprecation text so both entry points always agree.
    _DEPRECATION_MESSAGE = (
        "The class `TFAutoModelWithLMHead` is deprecated and will be removed in a future version. Please use"
        " `TFAutoModelForCausalLM` for causal language models, `TFAutoModelForMaskedLM` for masked language models"
        " and `TFAutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def _warn_deprecated(cls):
        """Emit the deprecation ``FutureWarning``.

        ``stacklevel=3`` skips this helper and the public classmethod that called it. The warning is
        therefore reported at the user's call site instead of inside this module.
        """
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=3)

    @classmethod
    def from_config(cls, config, **kwargs):
        """Warn about the deprecation, then build a model from ``config`` via the parent auto class.

        Args:
            config: Model configuration forwarded unchanged to the parent ``from_config``.
            **kwargs: Extra keyword arguments forwarded to the parent ``from_config``. Previously they
                were rejected with a ``TypeError``.

        Returns:
            Whatever the parent ``from_config`` returns for ``config``.
        """
        cls._warn_deprecated()
        return super().from_config(config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Warn about the deprecation, then load a pretrained model via the parent auto class.

        All arguments are forwarded unchanged to the parent ``from_pretrained``.
        """
        cls._warn_deprecated()
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)