# Source code for transformers.models.auto.modeling_flax_auto

# coding=utf-8
# Copyright 2018 The Google Flax Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """


from collections import OrderedDict

from ...utils import logging
from ..bart.modeling_flax_bart import (
    FlaxBartForConditionalGeneration,
    FlaxBartForQuestionAnswering,
    FlaxBartForSequenceClassification,
    FlaxBartModel,
)
from ..bert.modeling_flax_bert import (
    FlaxBertForMaskedLM,
    FlaxBertForMultipleChoice,
    FlaxBertForNextSentencePrediction,
    FlaxBertForPreTraining,
    FlaxBertForQuestionAnswering,
    FlaxBertForSequenceClassification,
    FlaxBertForTokenClassification,
    FlaxBertModel,
)
from ..big_bird.modeling_flax_big_bird import (
    FlaxBigBirdForMaskedLM,
    FlaxBigBirdForMultipleChoice,
    FlaxBigBirdForPreTraining,
    FlaxBigBirdForQuestionAnswering,
    FlaxBigBirdForSequenceClassification,
    FlaxBigBirdForTokenClassification,
    FlaxBigBirdModel,
)
from ..clip.modeling_flax_clip import FlaxCLIPModel
from ..electra.modeling_flax_electra import (
    FlaxElectraForMaskedLM,
    FlaxElectraForMultipleChoice,
    FlaxElectraForPreTraining,
    FlaxElectraForQuestionAnswering,
    FlaxElectraForSequenceClassification,
    FlaxElectraForTokenClassification,
    FlaxElectraModel,
)
from ..gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
from ..gpt_neo.modeling_flax_gpt_neo import FlaxGPTNeoForCausalLM, FlaxGPTNeoModel
from ..marian.modeling_flax_marian import FlaxMarianModel, FlaxMarianMTModel
from ..mbart.modeling_flax_mbart import (
    FlaxMBartForConditionalGeneration,
    FlaxMBartForQuestionAnswering,
    FlaxMBartForSequenceClassification,
    FlaxMBartModel,
)
from ..roberta.modeling_flax_roberta import (
    FlaxRobertaForMaskedLM,
    FlaxRobertaForMultipleChoice,
    FlaxRobertaForQuestionAnswering,
    FlaxRobertaForSequenceClassification,
    FlaxRobertaForTokenClassification,
    FlaxRobertaModel,
)
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from ..vit.modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
from ..wav2vec2.modeling_flax_wav2vec2 import FlaxWav2Vec2ForPreTraining, FlaxWav2Vec2Model
from .auto_factory import _BaseAutoModelClass, auto_class_update
from .configuration_auto import (
    BartConfig,
    BertConfig,
    BigBirdConfig,
    CLIPConfig,
    ElectraConfig,
    GPT2Config,
    GPTNeoConfig,
    MarianConfig,
    MBartConfig,
    MT5Config,
    RobertaConfig,
    T5Config,
    ViTConfig,
    Wav2Vec2Config,
)


# Module-level logger named after this module, obtained through the project's
# `logging` helper imported from `...utils`.
# NOTE(review): not referenced anywhere in the visible part of this file.
logger = logging.get_logger(__name__)


# Base model mapping: config class -> Flax base model class.
# FlaxAutoModel exposes this table through its `_model_mapping` attribute.
# Entries are inserted one at a time; insertion order is the lookup-table order.
FLAX_MODEL_MAPPING = OrderedDict()
FLAX_MODEL_MAPPING[RobertaConfig] = FlaxRobertaModel
FLAX_MODEL_MAPPING[BertConfig] = FlaxBertModel
FLAX_MODEL_MAPPING[BigBirdConfig] = FlaxBigBirdModel
FLAX_MODEL_MAPPING[BartConfig] = FlaxBartModel
FLAX_MODEL_MAPPING[GPT2Config] = FlaxGPT2Model
FLAX_MODEL_MAPPING[GPTNeoConfig] = FlaxGPTNeoModel
FLAX_MODEL_MAPPING[ElectraConfig] = FlaxElectraModel
FLAX_MODEL_MAPPING[CLIPConfig] = FlaxCLIPModel
FLAX_MODEL_MAPPING[ViTConfig] = FlaxViTModel
FLAX_MODEL_MAPPING[MBartConfig] = FlaxMBartModel
FLAX_MODEL_MAPPING[T5Config] = FlaxT5Model
# MT5 configs are served by the T5 implementation.
FLAX_MODEL_MAPPING[MT5Config] = FlaxT5Model
FLAX_MODEL_MAPPING[Wav2Vec2Config] = FlaxWav2Vec2Model
FLAX_MODEL_MAPPING[MarianConfig] = FlaxMarianModel

# Model for pre-training mapping: config class -> Flax class used for pre-training.
# FlaxAutoModelForPreTraining exposes this table through `_model_mapping`.
FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict()
FLAX_MODEL_FOR_PRETRAINING_MAPPING[RobertaConfig] = FlaxRobertaForMaskedLM
FLAX_MODEL_FOR_PRETRAINING_MAPPING[BertConfig] = FlaxBertForPreTraining
FLAX_MODEL_FOR_PRETRAINING_MAPPING[BigBirdConfig] = FlaxBigBirdForPreTraining
FLAX_MODEL_FOR_PRETRAINING_MAPPING[BartConfig] = FlaxBartForConditionalGeneration
FLAX_MODEL_FOR_PRETRAINING_MAPPING[ElectraConfig] = FlaxElectraForPreTraining
FLAX_MODEL_FOR_PRETRAINING_MAPPING[MBartConfig] = FlaxMBartForConditionalGeneration
FLAX_MODEL_FOR_PRETRAINING_MAPPING[T5Config] = FlaxT5ForConditionalGeneration
# MT5 configs are served by the T5 implementation.
FLAX_MODEL_FOR_PRETRAINING_MAPPING[MT5Config] = FlaxT5ForConditionalGeneration
FLAX_MODEL_FOR_PRETRAINING_MAPPING[Wav2Vec2Config] = FlaxWav2Vec2ForPreTraining

# Model for Masked LM mapping: config class -> Flax masked-language-modeling class.
# FlaxAutoModelForMaskedLM exposes this table through `_model_mapping`.
FLAX_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict()
FLAX_MODEL_FOR_MASKED_LM_MAPPING[RobertaConfig] = FlaxRobertaForMaskedLM
FLAX_MODEL_FOR_MASKED_LM_MAPPING[BertConfig] = FlaxBertForMaskedLM
FLAX_MODEL_FOR_MASKED_LM_MAPPING[BigBirdConfig] = FlaxBigBirdForMaskedLM
FLAX_MODEL_FOR_MASKED_LM_MAPPING[BartConfig] = FlaxBartForConditionalGeneration
FLAX_MODEL_FOR_MASKED_LM_MAPPING[ElectraConfig] = FlaxElectraForMaskedLM
FLAX_MODEL_FOR_MASKED_LM_MAPPING[MBartConfig] = FlaxMBartForConditionalGeneration

# Model for Seq2Seq Causal LM mapping: config class -> Flax sequence-to-sequence
# language-modeling class. FlaxAutoModelForSeq2SeqLM exposes this table through
# `_model_mapping`.
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
    [
        (BartConfig, FlaxBartForConditionalGeneration),
        # MBart's conditional-generation head is already registered alongside
        # Bart's in the pre-training and masked-LM tables; register it here too
        # so MBart configs are covered by the seq2seq table as well.
        (MBartConfig, FlaxMBartForConditionalGeneration),
        (T5Config, FlaxT5ForConditionalGeneration),
        # MT5 configs are served by the T5 implementation.
        (MT5Config, FlaxT5ForConditionalGeneration),
        (MarianConfig, FlaxMarianMTModel),
    ]
)

# Model for image classification mapping: config class -> Flax image-classification class.
# FlaxAutoModelForImageClassification exposes this table through `_model_mapping`.
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = OrderedDict()
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING[ViTConfig] = FlaxViTForImageClassification

# Model for Causal LM mapping: config class -> Flax causal-language-modeling class.
# FlaxAutoModelForCausalLM exposes this table through `_model_mapping`.
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict()
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[GPT2Config] = FlaxGPT2LMHeadModel
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING[GPTNeoConfig] = FlaxGPTNeoForCausalLM

# Model for Sequence Classification mapping: config class -> Flax
# sequence-classification class. FlaxAutoModelForSequenceClassification exposes
# this table through `_model_mapping`.
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict()
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[RobertaConfig] = FlaxRobertaForSequenceClassification
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[BertConfig] = FlaxBertForSequenceClassification
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[BigBirdConfig] = FlaxBigBirdForSequenceClassification
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[BartConfig] = FlaxBartForSequenceClassification
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[ElectraConfig] = FlaxElectraForSequenceClassification
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[MBartConfig] = FlaxMBartForSequenceClassification

# Model for Question Answering mapping: config class -> Flax question-answering class.
# FlaxAutoModelForQuestionAnswering exposes this table through `_model_mapping`.
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict()
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING[RobertaConfig] = FlaxRobertaForQuestionAnswering
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING[BertConfig] = FlaxBertForQuestionAnswering
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING[BigBirdConfig] = FlaxBigBirdForQuestionAnswering
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING[BartConfig] = FlaxBartForQuestionAnswering
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING[ElectraConfig] = FlaxElectraForQuestionAnswering
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING[MBartConfig] = FlaxMBartForQuestionAnswering

# Model for Token Classification mapping: config class -> Flax token-classification class.
# FlaxAutoModelForTokenClassification exposes this table through `_model_mapping`.
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict()
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[RobertaConfig] = FlaxRobertaForTokenClassification
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[BertConfig] = FlaxBertForTokenClassification
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[BigBirdConfig] = FlaxBigBirdForTokenClassification
FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[ElectraConfig] = FlaxElectraForTokenClassification

# Model for Multiple Choice mapping: config class -> Flax multiple-choice class.
# FlaxAutoModelForMultipleChoice exposes this table through `_model_mapping`.
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict()
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING[RobertaConfig] = FlaxRobertaForMultipleChoice
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING[BertConfig] = FlaxBertForMultipleChoice
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING[BigBirdConfig] = FlaxBigBirdForMultipleChoice
FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING[ElectraConfig] = FlaxElectraForMultipleChoice

# Model for Next Sentence Prediction mapping: config class -> Flax
# next-sentence-prediction class. FlaxAutoModelForNextSentencePrediction exposes
# this table through `_model_mapping`.
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING = OrderedDict()
FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING[BertConfig] = FlaxBertForNextSentencePrediction


# Auto class for base Flax models; its candidate table is FLAX_MODEL_MAPPING.
class FlaxAutoModel(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_MAPPING


# Rebind the name to the result of `auto_class_update` (imported from
# `.auto_factory`). NOTE(review): presumably this fills in the class's
# docstrings/loader methods -- confirm against auto_factory.
FlaxAutoModel = auto_class_update(FlaxAutoModel)
# Auto class for pre-training models; its candidate table is
# FLAX_MODEL_FOR_PRETRAINING_MAPPING.
class FlaxAutoModelForPreTraining(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_PRETRAINING_MAPPING


# Rebind the name to the result of `auto_class_update`; `head_doc` is the
# head description handed to that helper.
# NOTE(review): presumably used in generated documentation -- confirm against auto_factory.
FlaxAutoModelForPreTraining = auto_class_update(FlaxAutoModelForPreTraining, head_doc="pretraining")
# Auto class for causal language models; its candidate table is
# FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.
class FlaxAutoModelForCausalLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_CAUSAL_LM_MAPPING


# Rebind the name to the result of `auto_class_update`; `head_doc` is the
# head description handed to that helper.
# NOTE(review): presumably used in generated documentation -- confirm against auto_factory.
FlaxAutoModelForCausalLM = auto_class_update(FlaxAutoModelForCausalLM, head_doc="causal language modeling")
# Auto class for masked language models; its candidate table is
# FLAX_MODEL_FOR_MASKED_LM_MAPPING.
class FlaxAutoModelForMaskedLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_MASKED_LM_MAPPING


# Rebind the name to the result of `auto_class_update`; `head_doc` is the
# head description handed to that helper.
# NOTE(review): presumably used in generated documentation -- confirm against auto_factory.
FlaxAutoModelForMaskedLM = auto_class_update(FlaxAutoModelForMaskedLM, head_doc="masked language modeling")
# Auto class for sequence-to-sequence language models; its candidate table is
# FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.
class FlaxAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


# Rebind the name to the result of `auto_class_update`. Unlike the sibling
# classes, an explicit `checkpoint_for_example` ("t5-base") is passed here.
# NOTE(review): presumably the checkpoint name shown in generated doc
# examples -- confirm against auto_factory.
FlaxAutoModelForSeq2SeqLM = auto_class_update(
    FlaxAutoModelForSeq2SeqLM,
    head_doc="sequence-to-sequence language modeling",
    checkpoint_for_example="t5-base",
)
# Auto class for sequence-classification models; its candidate table is
# FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.
class FlaxAutoModelForSequenceClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`; `head_doc` is the
# head description handed to that helper.
# NOTE(review): presumably used in generated documentation -- confirm against auto_factory.
FlaxAutoModelForSequenceClassification = auto_class_update(
    FlaxAutoModelForSequenceClassification, head_doc="sequence classification"
)
# Auto class for question-answering models; its candidate table is
# FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING.
class FlaxAutoModelForQuestionAnswering(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING


# Rebind the name to the result of `auto_class_update`; `head_doc` is the
# head description handed to that helper.
# NOTE(review): presumably used in generated documentation -- confirm against auto_factory.
FlaxAutoModelForQuestionAnswering = auto_class_update(FlaxAutoModelForQuestionAnswering, head_doc="question answering")
# Auto class for token-classification models; its candidate table is
# FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.
class FlaxAutoModelForTokenClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`; `head_doc` is the
# head description handed to that helper.
# NOTE(review): presumably used in generated documentation -- confirm against auto_factory.
FlaxAutoModelForTokenClassification = auto_class_update(
    FlaxAutoModelForTokenClassification, head_doc="token classification"
)
# Auto class for multiple-choice models; its candidate table is
# FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.
class FlaxAutoModelForMultipleChoice(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING


# Rebind the name to the result of `auto_class_update`; `head_doc` is the
# head description handed to that helper.
# NOTE(review): presumably used in generated documentation -- confirm against auto_factory.
FlaxAutoModelForMultipleChoice = auto_class_update(FlaxAutoModelForMultipleChoice, head_doc="multiple choice")
# Auto class for next-sentence-prediction models; its candidate table is
# FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.
class FlaxAutoModelForNextSentencePrediction(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING


# Rebind the name to the result of `auto_class_update`; `head_doc` is the
# head description handed to that helper.
# NOTE(review): presumably used in generated documentation -- confirm against auto_factory.
FlaxAutoModelForNextSentencePrediction = auto_class_update(
    FlaxAutoModelForNextSentencePrediction, head_doc="next sentence prediction"
)
# Auto class for image-classification models; its candidate table is
# FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.
class FlaxAutoModelForImageClassification(_BaseAutoModelClass):
    _model_mapping = FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING


# Rebind the name to the result of `auto_class_update`; `head_doc` is the
# head description handed to that helper.
# NOTE(review): presumably used in generated documentation -- confirm against auto_factory.
FlaxAutoModelForImageClassification = auto_class_update(
    FlaxAutoModelForImageClassification, head_doc="image classification"
)