|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" Auto Model class.""" |
|
|
|
import warnings |
|
from collections import OrderedDict |
|
|
|
from transformers.utils import logging |
|
from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update |
|
from .configuration_auto import CONFIG_MAPPING_NAMES |
|
|
|
|
|
# Module-level logger obtained from transformers' logging utilities, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
# Each table below maps a model-type key (e.g. "roformer") to the *name* of the model class
# that serves one task. Insertion order is preserved via OrderedDict.

# Bare models without a task-specific head.
MODEL_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerModel"),
        ("longformer", "LongformerModel"),
    )
)

# Models used for pretraining.
MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
    (
        ("longformer", "LongformerForMaskedLM"),
    )
)

# Models with a language-modeling head (backs the deprecated AutoModelWithLMHead).
MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
    )
)

# Causal language modeling.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForCausalLM"),
    )
)

# Masked language modeling.
MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForMaskedLM"),
        ("longformer", "LongformerForMaskedLM"),
    )
)

# Sequence-to-sequence language modeling.
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
    (
        ("t5", "T5ForConditionalGeneration"),
    )
)

# Sequence-to-sequence speech models.
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
    (
        ("speech-encoder-decoder", "SpeechEncoderDecoderModel"),
        ("speech_to_text", "Speech2TextForConditionalGeneration"),
    )
)

# Sequence classification.
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForSequenceClassification"),
        ("longformer", "LongformerForSequenceClassification"),
    )
)

# Extractive question answering.
MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForQuestionAnswering"),
        ("longformer", "LongformerForQuestionAnswering"),
    )
)

# Question answering over tables.
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
    (
        ("tapas", "TapasForQuestionAnswering"),
    )
)

# Token classification.
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForTokenClassification"),
        ("longformer", "LongformerForTokenClassification"),
    )
)

# Multiple choice.
MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
    (
        ("roformer", "RoFormerForMultipleChoice"),
        ("longformer", "LongformerForMultipleChoice"),
    )
)
|
|
|
|
|
|
|
|
|
# Each task mapping pairs the shared CONFIG_MAPPING_NAMES table with that task's
# *_MAPPING_NAMES table via _LazyAutoMapping (name suggests entries are resolved lazily).
MODEL_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)

MODEL_FOR_PRETRAINING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_PRETRAINING_MAPPING_NAMES,
)

MODEL_WITH_LM_HEAD_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_WITH_LM_HEAD_MAPPING_NAMES,
)

MODEL_FOR_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
)

MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MASKED_LM_MAPPING_NAMES,
)

MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)

MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
)

MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)

MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
)

MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
)

MODEL_FOR_MULTIPLE_CHOICE_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES,
)

MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping(
    CONFIG_MAPPING_NAMES,
    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
)
|
|
|
|
|
|
|
class AutoModel(_BaseAutoModelClass):
    # Task-agnostic auto class: dispatches through MODEL_MAPPING (bare models such as RoFormerModel).
    _model_mapping = MODEL_MAPPING


# Post-process with auto_class_update; no head_doc is supplied for the bare-model class.
AutoModel = auto_class_update(
    AutoModel,
)
|
|
|
|
|
class AutoModelForPreTraining(_BaseAutoModelClass):
    # Dispatches through the pretraining mapping.
    _model_mapping = MODEL_FOR_PRETRAINING_MAPPING


# head_doc names the task this auto class covers.
AutoModelForPreTraining = auto_class_update(
    AutoModelForPreTraining,
    head_doc="pretraining",
)
|
|
|
|
|
|
|
class _AutoModelWithLMHead(_BaseAutoModelClass):
    # Private base for the deprecated public AutoModelWithLMHead; uses the LM-head mapping.
    _model_mapping = MODEL_WITH_LM_HEAD_MAPPING


# head_doc names the task this auto class covers.
_AutoModelWithLMHead = auto_class_update(
    _AutoModelWithLMHead,
    head_doc="language modeling",
)
|
|
|
|
|
class AutoModelForCausalLM(_BaseAutoModelClass):
    # Dispatches through the causal language-modeling mapping.
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING


# head_doc names the task this auto class covers.
AutoModelForCausalLM = auto_class_update(
    AutoModelForCausalLM,
    head_doc="causal language modeling",
)
|
|
|
|
|
class AutoModelForMaskedLM(_BaseAutoModelClass):
    # Dispatches through the masked language-modeling mapping.
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING


# head_doc names the task this auto class covers.
AutoModelForMaskedLM = auto_class_update(
    AutoModelForMaskedLM,
    head_doc="masked language modeling",
)
|
|
|
|
|
class AutoModelForSeq2SeqLM(_BaseAutoModelClass):
    # Dispatches through the sequence-to-sequence LM mapping (currently T5 only).
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# A T5 checkpoint is passed as the example because the mapping above contains only "t5".
AutoModelForSeq2SeqLM = auto_class_update(
    AutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="t5-base",
)
|
|
|
|
|
class AutoModelForSequenceClassification(_BaseAutoModelClass):
    # Dispatches through the sequence-classification mapping.
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# head_doc names the task this auto class covers.
AutoModelForSequenceClassification = auto_class_update(
    AutoModelForSequenceClassification,
    head_doc="sequence classification",
)
|
|
|
|
|
class AutoModelForQuestionAnswering(_BaseAutoModelClass):
    # Dispatches through the question-answering mapping.
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING


# head_doc names the task this auto class covers.
AutoModelForQuestionAnswering = auto_class_update(
    AutoModelForQuestionAnswering,
    head_doc="question answering",
)
|
|
|
|
|
class AutoModelForTableQuestionAnswering(_BaseAutoModelClass):
    # Dispatches through the table question-answering mapping (currently TAPAS only).
    _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING


# A TAPAS checkpoint is passed as the example because the mapping contains only "tapas".
AutoModelForTableQuestionAnswering = auto_class_update(
    AutoModelForTableQuestionAnswering,
    head_doc="table question answering",
    checkpoint_for_example="google/tapas-base-finetuned-wtq",
)
|
|
|
|
|
class AutoModelForTokenClassification(_BaseAutoModelClass):
    # Dispatches through the token-classification mapping.
    _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# head_doc names the task this auto class covers.
AutoModelForTokenClassification = auto_class_update(
    AutoModelForTokenClassification,
    head_doc="token classification",
)
|
|
|
|
|
class AutoModelForMultipleChoice(_BaseAutoModelClass):
    # Dispatches through the multiple-choice mapping.
    _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# head_doc names the task this auto class covers.
AutoModelForMultipleChoice = auto_class_update(
    AutoModelForMultipleChoice,
    head_doc="multiple choice",
)
|
|
|
|
|
|
|
class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    # Dispatches through the sequence-to-sequence speech mapping
    # (speech-encoder-decoder and speech_to_text models).
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING


# head_doc names the task this auto class covers; the typo "modeing" is fixed to "modeling".
AutoModelForSpeechSeq2Seq = auto_class_update(
    AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling"
)
|
|
|
|
|
|
|
class AutoModelWithLMHead(_AutoModelWithLMHead):
    # Deprecated public auto class for LM-head models. Loading is delegated unchanged to
    # `_AutoModelWithLMHead`; each entry point first emits a FutureWarning that points users
    # to the task-specific replacements.

    # One shared message, so the two entry points cannot drift apart.
    _DEPRECATION_MESSAGE = (
        "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
        "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
        "`AutoModelForSeq2SeqLM` for encoder-decoder models."
    )

    @classmethod
    def from_config(cls, config):
        # Warn, then forward `config` unchanged to the parent `from_config` and return its result.
        # stacklevel=2 attributes the warning to the caller's line rather than to this library file.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_config(config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Warn, then forward every argument unchanged to the parent `from_pretrained`
        # and return its result.
        # stacklevel=2 attributes the warning to the caller's line rather than to this library file.
        warnings.warn(cls._DEPRECATION_MESSAGE, FutureWarning, stacklevel=2)
        return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
|