Auto Classes
In many cases, the architecture you want to use can be guessed from the name or the path of the pretrained model you are supplying to the from_pretrained() method. The auto classes do this job for you: given the name or path of the pretrained weights, configuration, or vocabulary, they automatically retrieve the relevant model.
Instantiating one of AutoConfig, AutoModel, or AutoTokenizer directly creates an instance of the class for the relevant architecture. For instance,
model = AutoModel.from_pretrained("bert-base-cased")
will create a model that is an instance of BertModel.
There is one class of AutoModel for each task, and for each backend (PyTorch, TensorFlow, or Flax).
Extending the Auto Classes
Each of the auto classes has a register() method that lets you extend it with your custom classes. For instance, if you have defined a custom model class NewModel, make sure you also have a matching NewModelConfig; you can then add both to the auto classes like this:
from transformers import AutoConfig, AutoModel
AutoConfig.register("new-model", NewModelConfig)
AutoModel.register(NewModelConfig, NewModel)
You will then be able to use the auto classes like you would usually do!
If your NewModelConfig is a subclass of PretrainedConfig, make sure its model_type attribute is set to the same key you use when registering the config (here "new-model").
Likewise, if your NewModel is a subclass of PreTrainedModel, make sure its config_class attribute is set to the same class you use when registering the model (here NewModelConfig).
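To make the two naming rules above concrete, the following is a minimal, library-free sketch of the bookkeeping they protect: one table from model_type to configuration class and one from configuration class to model class. It is an illustration, not the transformers implementation; BaseConfig, BaseModel, register_config, register_model and model_for_type are hypothetical names, and the ValueError checks simply encode the rules stated above.

```python
# Hypothetical, library-free model of the two registries behind the auto
# classes: model_type -> config class, and config class -> model class.
# None of these names are part of the transformers API.

CONFIG_REGISTRY = {}  # model_type (str) -> config class
MODEL_REGISTRY = {}   # config class -> model class


class BaseConfig:
    model_type = ""


class BaseModel:
    config_class = None

    def __init__(self, config):
        self.config = config


def register_config(model_type, config_cls):
    # Encodes the rule above: the registration key must equal config_cls.model_type.
    if config_cls.model_type != model_type:
        raise ValueError(
            f"{config_cls.__name__}.model_type is {config_cls.model_type!r}, "
            f"but it is being registered as {model_type!r}"
        )
    CONFIG_REGISTRY[model_type] = config_cls


def register_model(config_cls, model_cls):
    # Encodes the rule above: model_cls.config_class must be the registered config.
    if model_cls.config_class is not config_cls:
        raise ValueError(f"{model_cls.__name__}.config_class must be {config_cls.__name__}")
    MODEL_REGISTRY[config_cls] = model_cls


def model_for_type(model_type):
    # Two lookups: model_type -> config class -> model class.
    config = CONFIG_REGISTRY[model_type]()
    return MODEL_REGISTRY[type(config)](config)


class NewModelConfig(BaseConfig):
    model_type = "new-model"


class NewModel(BaseModel):
    config_class = NewModelConfig


register_config("new-model", NewModelConfig)
register_model(NewModelConfig, NewModel)
model = model_for_type("new-model")
print(type(model).__name__)  # NewModel
```

Registering NewModelConfig under any key other than "new-model" fails in this sketch, which is exactly the mismatch the two rules above are meant to prevent.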
AutoConfig
This is a generic configuration class that will be instantiated as one of the configuration classes of the library when created with the from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model configuration hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing a configuration file saved using either the PretrainedConfig.save_pretrained() method or the PreTrainedModel.save_pretrained() method, e.g., ./my_model_directory/.
  - A path or url to a saved configuration JSON file, e.g., ./my_model_directory/configuration.json.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files and override the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume an interrupted download. If True, the download resumes from an incompletely received file when one exists instead of starting over.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- return_unused_kwargs (bool, optional, defaults to False) — If False, then this function returns just the final configuration object. If True, then this function returns a Tuple(config, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the part of kwargs which has not been used to update config and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- kwargs (additional keyword arguments, optional) — The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not configuration attributes is controlled by the return_unused_kwargs keyword parameter.
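The interplay between kwargs and return_unused_kwargs can be sketched with plain dictionaries. This is a simplified model, assuming a configuration is a dict of attribute names to values; the function apply_kwargs is hypothetical, and the real method works on configuration objects rather than dicts.

```python
# Hypothetical sketch: split kwargs into overrides of known configuration
# attributes and "unused" keys, as the kwargs / return_unused_kwargs
# parameters describe. A config is modelled as a plain dict here.

def apply_kwargs(loaded, return_unused_kwargs=False, **kwargs):
    config = dict(loaded)  # copy so the loaded values are not mutated
    unused = {}
    for key, value in kwargs.items():
        if key in config:
            config[key] = value   # known attribute: override the loaded value
        else:
            unused[key] = value   # not a configuration attribute
    if return_unused_kwargs:
        return config, unused
    return config                 # unused keys are simply dropped


loaded = {"output_attentions": False, "hidden_size": 768}
config, unused = apply_kwargs(
    loaded, output_attentions=True, foo=False, return_unused_kwargs=True
)
print(config["output_attentions"], unused)  # True {'foo': False}
```

This mirrors the doctest further down, where foo=False ends up in unused_kwargs because foo is not an attribute of the loaded configuration.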
Instantiate one of the configuration classes of the library from a pretrained model configuration.
The configuration class to instantiate is selected based on the model_type property of the config object that is loaded, or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert → AlbertConfig (ALBERT model)
- align → AlignConfig (ALIGN model)
- altclip → AltCLIPConfig (AltCLIP model)
- audio-spectrogram-transformer → ASTConfig (Audio Spectrogram Transformer model)
- autoformer → AutoformerConfig (Autoformer model)
- bark → BarkConfig (Bark model)
- bart → BartConfig (BART model)
- beit → BeitConfig (BEiT model)
- bert → BertConfig (BERT model)
- bert-generation → BertGenerationConfig (Bert Generation model)
- big_bird → BigBirdConfig (BigBird model)
- bigbird_pegasus → BigBirdPegasusConfig (BigBird-Pegasus model)
- biogpt → BioGptConfig (BioGpt model)
- bit → BitConfig (BiT model)
- blenderbot → BlenderbotConfig (Blenderbot model)
- blenderbot-small → BlenderbotSmallConfig (BlenderbotSmall model)
- blip → BlipConfig (BLIP model)
- blip-2 → Blip2Config (BLIP-2 model)
- bloom → BloomConfig (BLOOM model)
- bridgetower → BridgeTowerConfig (BridgeTower model)
- camembert → CamembertConfig (CamemBERT model)
- canine → CanineConfig (CANINE model)
- chinese_clip → ChineseCLIPConfig (Chinese-CLIP model)
- clap → ClapConfig (CLAP model)
- clip → CLIPConfig (CLIP model)
- clipseg → CLIPSegConfig (CLIPSeg model)
- code_llama → LlamaConfig (CodeLlama model)
- codegen → CodeGenConfig (CodeGen model)
- conditional_detr → ConditionalDetrConfig (Conditional DETR model)
- convbert → ConvBertConfig (ConvBERT model)
- convnext → ConvNextConfig (ConvNeXT model)
- convnextv2 → ConvNextV2Config (ConvNeXTV2 model)
- cpmant → CpmAntConfig (CPM-Ant model)
- ctrl → CTRLConfig (CTRL model)
- cvt → CvtConfig (CvT model)
- data2vec-audio → Data2VecAudioConfig (Data2VecAudio model)
- data2vec-text → Data2VecTextConfig (Data2VecText model)
- data2vec-vision → Data2VecVisionConfig (Data2VecVision model)
- deberta → DebertaConfig (DeBERTa model)
- deberta-v2 → DebertaV2Config (DeBERTa-v2 model)
- decision_transformer → DecisionTransformerConfig (Decision Transformer model)
- deformable_detr → DeformableDetrConfig (Deformable DETR model)
- deit → DeiTConfig (DeiT model)
- deta → DetaConfig (DETA model)
- detr → DetrConfig (DETR model)
- dinat → DinatConfig (DiNAT model)
- dinov2 → Dinov2Config (DINOv2 model)
- distilbert → DistilBertConfig (DistilBERT model)
- donut-swin → DonutSwinConfig (DonutSwin model)
- dpr → DPRConfig (DPR model)
- dpt → DPTConfig (DPT model)
- efficientformer → EfficientFormerConfig (EfficientFormer model)
- efficientnet → EfficientNetConfig (EfficientNet model)
- electra → ElectraConfig (ELECTRA model)
- encodec → EncodecConfig (EnCodec model)
- encoder-decoder → EncoderDecoderConfig (Encoder decoder model)
- ernie → ErnieConfig (ERNIE model)
- ernie_m → ErnieMConfig (ErnieM model)
- esm → EsmConfig (ESM model)
- falcon → FalconConfig (Falcon model)
- flaubert → FlaubertConfig (FlauBERT model)
- flava → FlavaConfig (FLAVA model)
- fnet → FNetConfig (FNet model)
- focalnet → FocalNetConfig (FocalNet model)
- fsmt → FSMTConfig (FairSeq Machine-Translation model)
- funnel → FunnelConfig (Funnel Transformer model)
- git → GitConfig (GIT model)
- glpn → GLPNConfig (GLPN model)
- gpt-sw3 → GPT2Config (GPT-Sw3 model)
- gpt2 → GPT2Config (OpenAI GPT-2 model)
- gpt_bigcode → GPTBigCodeConfig (GPTBigCode model)
- gpt_neo → GPTNeoConfig (GPT Neo model)
- gpt_neox → GPTNeoXConfig (GPT NeoX model)
- gpt_neox_japanese → GPTNeoXJapaneseConfig (GPT NeoX Japanese model)
- gptj → GPTJConfig (GPT-J model)
- gptsan-japanese → GPTSanJapaneseConfig (GPTSAN-japanese model)
- graphormer → GraphormerConfig (Graphormer model)
- groupvit → GroupViTConfig (GroupViT model)
- hubert → HubertConfig (Hubert model)
- ibert → IBertConfig (I-BERT model)
- idefics → IdeficsConfig (IDEFICS model)
- imagegpt → ImageGPTConfig (ImageGPT model)
- informer → InformerConfig (Informer model)
- instructblip → InstructBlipConfig (InstructBLIP model)
- jukebox → JukeboxConfig (Jukebox model)
- layoutlm → LayoutLMConfig (LayoutLM model)
- layoutlmv2 → LayoutLMv2Config (LayoutLMv2 model)
- layoutlmv3 → LayoutLMv3Config (LayoutLMv3 model)
- led → LEDConfig (LED model)
- levit → LevitConfig (LeViT model)
- lilt → LiltConfig (LiLT model)
- llama → LlamaConfig (LLaMA model)
- longformer → LongformerConfig (Longformer model)
- longt5 → LongT5Config (LongT5 model)
- luke → LukeConfig (LUKE model)
- lxmert → LxmertConfig (LXMERT model)
- m2m_100 → M2M100Config (M2M100 model)
- marian → MarianConfig (Marian model)
- markuplm → MarkupLMConfig (MarkupLM model)
- mask2former → Mask2FormerConfig (Mask2Former model)
- maskformer → MaskFormerConfig (MaskFormer model)
- maskformer-swin → MaskFormerSwinConfig (MaskFormerSwin model)
- mbart → MBartConfig (mBART model)
- mctct → MCTCTConfig (M-CTC-T model)
- mega → MegaConfig (MEGA model)
- megatron-bert → MegatronBertConfig (Megatron-BERT model)
- mgp-str → MgpstrConfig (MGP-STR model)
- mobilebert → MobileBertConfig (MobileBERT model)
- mobilenet_v1 → MobileNetV1Config (MobileNetV1 model)
- mobilenet_v2 → MobileNetV2Config (MobileNetV2 model)
- mobilevit → MobileViTConfig (MobileViT model)
- mobilevitv2 → MobileViTV2Config (MobileViTV2 model)
- mpnet → MPNetConfig (MPNet model)
- mpt → MptConfig (MPT model)
- mra → MraConfig (MRA model)
- mt5 → MT5Config (MT5 model)
- musicgen → MusicgenConfig (MusicGen model)
- mvp → MvpConfig (MVP model)
- nat → NatConfig (NAT model)
- nezha → NezhaConfig (Nezha model)
- nllb-moe → NllbMoeConfig (NLLB-MOE model)
- nystromformer → NystromformerConfig (Nyströmformer model)
- oneformer → OneFormerConfig (OneFormer model)
- open-llama → OpenLlamaConfig (OpenLlama model)
- openai-gpt → OpenAIGPTConfig (OpenAI GPT model)
- opt → OPTConfig (OPT model)
- owlvit → OwlViTConfig (OWL-ViT model)
- pegasus → PegasusConfig (Pegasus model)
- pegasus_x → PegasusXConfig (PEGASUS-X model)
- perceiver → PerceiverConfig (Perceiver model)
- pix2struct → Pix2StructConfig (Pix2Struct model)
- plbart → PLBartConfig (PLBart model)
- poolformer → PoolFormerConfig (PoolFormer model)
- pop2piano → Pop2PianoConfig (Pop2Piano model)
- prophetnet → ProphetNetConfig (ProphetNet model)
- pvt → PvtConfig (PVT model)
- qdqbert → QDQBertConfig (QDQBert model)
- rag → RagConfig (RAG model)
- realm → RealmConfig (REALM model)
- reformer → ReformerConfig (Reformer model)
- regnet → RegNetConfig (RegNet model)
- rembert → RemBertConfig (RemBERT model)
- resnet → ResNetConfig (ResNet model)
- retribert → RetriBertConfig (RetriBERT model)
- roberta → RobertaConfig (RoBERTa model)
- roberta-prelayernorm → RobertaPreLayerNormConfig (RoBERTa-PreLayerNorm model)
- roc_bert → RoCBertConfig (RoCBert model)
- roformer → RoFormerConfig (RoFormer model)
- rwkv → RwkvConfig (RWKV model)
- sam → SamConfig (SAM model)
- segformer → SegformerConfig (SegFormer model)
- sew → SEWConfig (SEW model)
- sew-d → SEWDConfig (SEW-D model)
- speech-encoder-decoder → SpeechEncoderDecoderConfig (Speech Encoder decoder model)
- speech_to_text → Speech2TextConfig (Speech2Text model)
- speech_to_text_2 → Speech2Text2Config (Speech2Text2 model)
- speecht5 → SpeechT5Config (SpeechT5 model)
- splinter → SplinterConfig (Splinter model)
- squeezebert → SqueezeBertConfig (SqueezeBERT model)
- swiftformer → SwiftFormerConfig (SwiftFormer model)
- swin → SwinConfig (Swin Transformer model)
- swin2sr → Swin2SRConfig (Swin2SR model)
- swinv2 → Swinv2Config (Swin Transformer V2 model)
- switch_transformers → SwitchTransformersConfig (SwitchTransformers model)
- t5 → T5Config (T5 model)
- table-transformer → TableTransformerConfig (Table Transformer model)
- tapas → TapasConfig (TAPAS model)
- time_series_transformer → TimeSeriesTransformerConfig (Time Series Transformer model)
- timesformer → TimesformerConfig (TimeSformer model)
- timm_backbone → TimmBackboneConfig (TimmBackbone model)
- trajectory_transformer → TrajectoryTransformerConfig (Trajectory Transformer model)
- transfo-xl → TransfoXLConfig (Transformer-XL model)
- trocr → TrOCRConfig (TrOCR model)
- tvlt → TvltConfig (TVLT model)
- umt5 → UMT5Config (UMT5 model)
- unispeech → UniSpeechConfig (UniSpeech model)
- unispeech-sat → UniSpeechSatConfig (UniSpeechSat model)
- upernet → UperNetConfig (UPerNet model)
- van → VanConfig (VAN model)
- videomae → VideoMAEConfig (VideoMAE model)
- vilt → ViltConfig (ViLT model)
- vision-encoder-decoder → VisionEncoderDecoderConfig (Vision Encoder decoder model)
- vision-text-dual-encoder → VisionTextDualEncoderConfig (VisionTextDualEncoder model)
- visual_bert → VisualBertConfig (VisualBERT model)
- vit → ViTConfig (ViT model)
- vit_hybrid → ViTHybridConfig (ViT Hybrid model)
- vit_mae → ViTMAEConfig (ViTMAE model)
- vit_msn → ViTMSNConfig (ViTMSN model)
- vitdet → VitDetConfig (VitDet model)
- vits → VitsConfig (VITS model)
- vivit → VivitConfig (ViViT model)
- wav2vec2 → Wav2Vec2Config (Wav2Vec2 model)
- wav2vec2-conformer → Wav2Vec2ConformerConfig (Wav2Vec2-Conformer model)
- wavlm → WavLMConfig (WavLM model)
- whisper → WhisperConfig (Whisper model)
- xclip → XCLIPConfig (X-CLIP model)
- xglm → XGLMConfig (XGLM model)
- xlm → XLMConfig (XLM model)
- xlm-prophetnet → XLMProphetNetConfig (XLM-ProphetNet model)
- xlm-roberta → XLMRobertaConfig (XLM-RoBERTa model)
- xlm-roberta-xl → XLMRobertaXLConfig (XLM-RoBERTa-XL model)
- xlnet → XLNetConfig (XLNet model)
- xmod → XmodConfig (X-MOD model)
- yolos → YolosConfig (YOLOS model)
- yoso → YosoConfig (YOSO model)
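The selection rule above can be sketched as follows. The mapping is a tiny hypothetical excerpt of the list, class names are plain strings, and the fallback assumes substring matching that tries longer keys first, so that a name like xlm-roberta-base is not captured by the shorter bert or roberta keys. Treat it as an illustration of the idea, not as the library's code.

```python
# Hypothetical sketch of choosing a configuration class: use model_type from
# the loaded config when present, otherwise pattern-match on the name/path.

CONFIG_MAPPING = {
    "bert": "BertConfig",
    "bert-generation": "BertGenerationConfig",
    "roberta": "RobertaConfig",
    "xlm-roberta": "XLMRobertaConfig",
}


def select_config_class(config_dict, pretrained_model_name_or_path):
    model_type = config_dict.get("model_type")
    if model_type is not None:
        # Preferred path: the loaded config states its own architecture.
        return CONFIG_MAPPING[model_type]
    # Fallback: longest keys first, so "xlm-roberta" is tried before "bert".
    for pattern in sorted(CONFIG_MAPPING, key=len, reverse=True):
        if pattern in str(pretrained_model_name_or_path):
            return CONFIG_MAPPING[pattern]
    raise ValueError(f"Unrecognized model in {pretrained_model_name_or_path}")


print(select_config_class({"model_type": "roberta"}, "my-checkpoint"))  # RobertaConfig
print(select_config_class({}, "xlm-roberta-base"))                      # XLMRobertaConfig
```

Note that "xlm-roberta-base" also contains the substring "bert", which is why the order in which patterns are tried matters in a scheme like this.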
Examples:
>>> from transformers import AutoConfig
>>> # Download configuration from huggingface.co and cache.
>>> config = AutoConfig.from_pretrained("bert-base-uncased")
>>> # Download configuration from huggingface.co (user-uploaded) and cache.
>>> config = AutoConfig.from_pretrained("dbmdz/bert-base-german-cased")
>>> # If configuration file is in a directory (e.g., was saved using *save_pretrained('./test/bert_saved_model/')*).
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/")
>>> # Load a specific configuration file.
>>> config = AutoConfig.from_pretrained("./test/bert_saved_model/my_configuration.json")
>>> # Change some config attributes when loading a pretrained config.
>>> config = AutoConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
>>> config.output_attentions
True
>>> config, unused_kwargs = AutoConfig.from_pretrained(
... "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
... )
>>> config.output_attentions
True
>>> unused_kwargs
{'foo': False}
register
AutoConfig.register(model_type, config, exist_ok=False)
Parameters
- model_type (str) — The model type like “bert” or “gpt”.
- config (PretrainedConfig) — The config to register.
Register a new configuration for this class.
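The signature above also accepts an exist_ok flag. Assuming it follows the usual convention of refusing to overwrite an existing key unless the caller opts in, its effect can be sketched with a small hypothetical registry class. This is not transformers code; ConfigRegistry and its contents are illustrative only.

```python
# Hypothetical registry illustrating the assumed exist_ok convention:
# registering a key that already exists raises unless exist_ok=True.

class ConfigRegistry:
    def __init__(self, builtin):
        self._mapping = dict(builtin)

    def register(self, model_type, config, exist_ok=False):
        if model_type in self._mapping and not exist_ok:
            raise ValueError(f"'{model_type}' is already registered; pick another name.")
        self._mapping[model_type] = config

    def __getitem__(self, model_type):
        return self._mapping[model_type]


registry = ConfigRegistry({"bert": "BertConfig"})
registry.register("new-model", "NewModelConfig")   # new key: accepted
try:
    registry.register("bert", "MyBertConfig")      # existing key: rejected
except ValueError as err:
    print(err)
registry.register("bert", "MyBertConfig", exist_ok=True)
print(registry["bert"])  # MyBertConfig
```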
AutoTokenizer
This is a generic tokenizer class that will be instantiated as one of the tokenizer classes of the library when created with the AutoTokenizer.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a predefined tokenizer hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing vocabulary files required by the tokenizer, for instance saved using the save_pretrained() method, e.g., ./my_model_directory/.
  - A path or url to a single saved vocabulary file if and only if the tokenizer only requires a single vocabulary file (like Bert or XLNet), e.g., ./my_model_directory/vocab.txt. (Not applicable to all derived classes)
- inputs (additional positional arguments, optional) — Will be passed along to the Tokenizer __init__() method.
- config (PretrainedConfig, optional) — The configuration object used to determine the tokenizer class to instantiate.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files and override the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume an interrupted download. If True, the download resumes from an incompletely received file when one exists instead of starting over.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- subfolder (str, optional) — In case the relevant files are located inside a subfolder of the model repo on huggingface.co (e.g. for facebook/rag-token-base), specify it here.
- use_fast (bool, optional, defaults to True) — Use a fast Rust-based tokenizer if it is supported for a given model. If a fast tokenizer is not available for a given model, a normal Python-based tokenizer is returned instead.
- tokenizer_type (str, optional) — Tokenizer type to be loaded.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- kwargs (additional keyword arguments, optional) — Will be passed to the Tokenizer __init__() method. Can be used to set special tokens like bos_token, eos_token, unk_token, sep_token, pad_token, cls_token, mask_token, additional_special_tokens. See parameters in the __init__() for more details.
Instantiate one of the tokenizer classes of the library from a pretrained model vocabulary.
The tokenizer class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert → AlbertTokenizer or AlbertTokenizerFast (ALBERT model)
- align → BertTokenizer or BertTokenizerFast (ALIGN model)
- bark → BertTokenizer or BertTokenizerFast (Bark model)
- bart → BartTokenizer or BartTokenizerFast (BART model)
- barthez → BarthezTokenizer or BarthezTokenizerFast (BARThez model)
- bartpho → BartphoTokenizer (BARTpho model)
- bert → BertTokenizer or BertTokenizerFast (BERT model)
- bert-generation → BertGenerationTokenizer (Bert Generation model)
- bert-japanese → BertJapaneseTokenizer (BertJapanese model)
- bertweet → BertweetTokenizer (BERTweet model)
- big_bird → BigBirdTokenizer or BigBirdTokenizerFast (BigBird model)
- bigbird_pegasus → PegasusTokenizer or PegasusTokenizerFast (BigBird-Pegasus model)
- biogpt → BioGptTokenizer (BioGpt model)
- blenderbot → BlenderbotTokenizer or BlenderbotTokenizerFast (Blenderbot model)
- blenderbot-small → BlenderbotSmallTokenizer (BlenderbotSmall model)
- blip → BertTokenizer or BertTokenizerFast (BLIP model)
- blip-2 → GPT2Tokenizer or GPT2TokenizerFast (BLIP-2 model)
- bloom → BloomTokenizerFast (BLOOM model)
- bridgetower → RobertaTokenizer or RobertaTokenizerFast (BridgeTower model)
- byt5 → ByT5Tokenizer (ByT5 model)
- camembert → CamembertTokenizer or CamembertTokenizerFast (CamemBERT model)
- canine → CanineTokenizer (CANINE model)
- chinese_clip → BertTokenizer or BertTokenizerFast (Chinese-CLIP model)
- clap → RobertaTokenizer or RobertaTokenizerFast (CLAP model)
- clip → CLIPTokenizer or CLIPTokenizerFast (CLIP model)
- clipseg → CLIPTokenizer or CLIPTokenizerFast (CLIPSeg model)
- code_llama → CodeLlamaTokenizer or CodeLlamaTokenizerFast (CodeLlama model)
- codegen → CodeGenTokenizer or CodeGenTokenizerFast (CodeGen model)
- convbert → ConvBertTokenizer or ConvBertTokenizerFast (ConvBERT model)
- cpm → CpmTokenizer or CpmTokenizerFast (CPM model)
- cpmant → CpmAntTokenizer (CPM-Ant model)
- ctrl → CTRLTokenizer (CTRL model)
- data2vec-audio → Wav2Vec2CTCTokenizer (Data2VecAudio model)
- data2vec-text → RobertaTokenizer or RobertaTokenizerFast (Data2VecText model)
- deberta → DebertaTokenizer or DebertaTokenizerFast (DeBERTa model)
- deberta-v2 → DebertaV2Tokenizer or DebertaV2TokenizerFast (DeBERTa-v2 model)
- distilbert → DistilBertTokenizer or DistilBertTokenizerFast (DistilBERT model)
- dpr → DPRQuestionEncoderTokenizer or DPRQuestionEncoderTokenizerFast (DPR model)
- electra → ElectraTokenizer or ElectraTokenizerFast (ELECTRA model)
- ernie → BertTokenizer or BertTokenizerFast (ERNIE model)
- ernie_m → ErnieMTokenizer (ErnieM model)
- esm → EsmTokenizer (ESM model)
- flaubert → FlaubertTokenizer (FlauBERT model)
- fnet → FNetTokenizer or FNetTokenizerFast (FNet model)
- fsmt → FSMTTokenizer (FairSeq Machine-Translation model)
- funnel → FunnelTokenizer or FunnelTokenizerFast (Funnel Transformer model)
- git → BertTokenizer or BertTokenizerFast (GIT model)
- gpt-sw3 → GPTSw3Tokenizer (GPT-Sw3 model)
- gpt2 → GPT2Tokenizer or GPT2TokenizerFast (OpenAI GPT-2 model)
- gpt_bigcode → GPT2Tokenizer or GPT2TokenizerFast (GPTBigCode model)
- gpt_neo → GPT2Tokenizer or GPT2TokenizerFast (GPT Neo model)
- gpt_neox → GPTNeoXTokenizerFast (GPT NeoX model)
- gpt_neox_japanese → GPTNeoXJapaneseTokenizer (GPT NeoX Japanese model)
- gptj → GPT2Tokenizer or GPT2TokenizerFast (GPT-J model)
- gptsan-japanese → GPTSanJapaneseTokenizer (GPTSAN-japanese model)
- groupvit → CLIPTokenizer or CLIPTokenizerFast (GroupViT model)
- herbert → HerbertTokenizer or HerbertTokenizerFast (HerBERT model)
- hubert → Wav2Vec2CTCTokenizer (Hubert model)
- ibert → RobertaTokenizer or RobertaTokenizerFast (I-BERT model)
- idefics → LlamaTokenizerFast (IDEFICS model)
- instructblip → GPT2Tokenizer or GPT2TokenizerFast (InstructBLIP model)
- jukebox → JukeboxTokenizer (Jukebox model)
- layoutlm → LayoutLMTokenizer or LayoutLMTokenizerFast (LayoutLM model)
- layoutlmv2 → LayoutLMv2Tokenizer or LayoutLMv2TokenizerFast (LayoutLMv2 model)
- layoutlmv3 → LayoutLMv3Tokenizer or LayoutLMv3TokenizerFast (LayoutLMv3 model)
- layoutxlm → LayoutXLMTokenizer or LayoutXLMTokenizerFast (LayoutXLM model)
- led → LEDTokenizer or LEDTokenizerFast (LED model)
- lilt → LayoutLMv3Tokenizer or LayoutLMv3TokenizerFast (LiLT model)
- llama → LlamaTokenizer or LlamaTokenizerFast (LLaMA model)
- longformer → LongformerTokenizer or LongformerTokenizerFast (Longformer model)
- longt5 → T5Tokenizer or T5TokenizerFast (LongT5 model)
- luke → LukeTokenizer (LUKE model)
- lxmert → LxmertTokenizer or LxmertTokenizerFast (LXMERT model)
- m2m_100 → M2M100Tokenizer (M2M100 model)
- marian → MarianTokenizer (Marian model)
- mbart → MBartTokenizer or MBartTokenizerFast (mBART model)
- mbart50 → MBart50Tokenizer or MBart50TokenizerFast (mBART-50 model)
- mega → RobertaTokenizer or RobertaTokenizerFast (MEGA model)
- megatron-bert → BertTokenizer or BertTokenizerFast (Megatron-BERT model)
- mgp-str → MgpstrTokenizer (MGP-STR model)
- mluke → MLukeTokenizer (mLUKE model)
- mobilebert → MobileBertTokenizer or MobileBertTokenizerFast (MobileBERT model)
- mpnet → MPNetTokenizer or MPNetTokenizerFast (MPNet model)
- mpt → GPTNeoXTokenizerFast (MPT model)
- mra → RobertaTokenizer or RobertaTokenizerFast (MRA model)
- mt5 → MT5Tokenizer or MT5TokenizerFast (MT5 model)
- musicgen → T5Tokenizer or T5TokenizerFast (MusicGen model)
- mvp → MvpTokenizer or MvpTokenizerFast (MVP model)
- nezha → BertTokenizer or BertTokenizerFast (Nezha model)
- nllb → NllbTokenizer or NllbTokenizerFast (NLLB model)
- nllb-moe → NllbTokenizer or NllbTokenizerFast (NLLB-MOE model)
- nystromformer → AlbertTokenizer or AlbertTokenizerFast (Nyströmformer model)
- oneformer → CLIPTokenizer or CLIPTokenizerFast (OneFormer model)
- openai-gpt → OpenAIGPTTokenizer or OpenAIGPTTokenizerFast (OpenAI GPT model)
- opt → GPT2Tokenizer or GPT2TokenizerFast (OPT model)
- owlvit → CLIPTokenizer or CLIPTokenizerFast (OWL-ViT model)
- pegasus → PegasusTokenizer or PegasusTokenizerFast (Pegasus model)
- pegasus_x → PegasusTokenizer or PegasusTokenizerFast (PEGASUS-X model)
- perceiver → PerceiverTokenizer (Perceiver model)
- phobert → PhobertTokenizer (PhoBERT model)
- pix2struct → T5Tokenizer or T5TokenizerFast (Pix2Struct model)
- plbart → PLBartTokenizer (PLBart model)
- prophetnet → ProphetNetTokenizer (ProphetNet model)
- qdqbert → BertTokenizer or BertTokenizerFast (QDQBert model)
- rag → RagTokenizer (RAG model)
- realm → RealmTokenizer or RealmTokenizerFast (REALM model)
- reformer → ReformerTokenizer or ReformerTokenizerFast (Reformer model)
- rembert → RemBertTokenizer or RemBertTokenizerFast (RemBERT model)
- retribert → RetriBertTokenizer or RetriBertTokenizerFast (RetriBERT model)
- roberta → RobertaTokenizer or RobertaTokenizerFast (RoBERTa model)
- roberta-prelayernorm → RobertaTokenizer or RobertaTokenizerFast (RoBERTa-PreLayerNorm model)
- roc_bert → RoCBertTokenizer (RoCBert model)
- roformer → RoFormerTokenizer or RoFormerTokenizerFast (RoFormer model)
- rwkv → GPTNeoXTokenizerFast (RWKV model)
- speech_to_text → Speech2TextTokenizer (Speech2Text model)
- speech_to_text_2 → Speech2Text2Tokenizer (Speech2Text2 model)
- speecht5 → SpeechT5Tokenizer (SpeechT5 model)
- splinter → SplinterTokenizer or SplinterTokenizerFast (Splinter model)
- squeezebert → SqueezeBertTokenizer or SqueezeBertTokenizerFast (SqueezeBERT model)
- switch_transformers → T5Tokenizer or T5TokenizerFast (SwitchTransformers model)
- t5 → T5Tokenizer or T5TokenizerFast (T5 model)
- tapas → TapasTokenizer (TAPAS model)
- tapex → TapexTokenizer (TAPEX model)
- transfo-xl → TransfoXLTokenizer (Transformer-XL model)
- umt5 → T5Tokenizer or T5TokenizerFast (UMT5 model)
- vilt → BertTokenizer or BertTokenizerFast (ViLT model)
- visual_bert → BertTokenizer or BertTokenizerFast (VisualBERT model)
- vits → VitsTokenizer (VITS model)
- wav2vec2 → Wav2Vec2CTCTokenizer (Wav2Vec2 model)
- wav2vec2-conformer → Wav2Vec2CTCTokenizer (Wav2Vec2-Conformer model)
- wav2vec2_phoneme → Wav2Vec2PhonemeCTCTokenizer (Wav2Vec2Phoneme model)
- whisper → WhisperTokenizer or WhisperTokenizerFast (Whisper model)
- xclip → CLIPTokenizer or CLIPTokenizerFast (X-CLIP model)
- xglm → XGLMTokenizer or XGLMTokenizerFast (XGLM model)
- xlm → XLMTokenizer (XLM model)
- xlm-prophetnet → XLMProphetNetTokenizer (XLM-ProphetNet model)
- xlm-roberta → XLMRobertaTokenizer or XLMRobertaTokenizerFast (XLM-RoBERTa model)
- xlm-roberta-xl → XLMRobertaTokenizer or XLMRobertaTokenizerFast (XLM-RoBERTa-XL model)
- xlnet → XLNetTokenizer or XLNetTokenizerFast (XLNet model)
- xmod → XLMRobertaTokenizer or XLMRobertaTokenizerFast (X-MOD model)
- yoso → AlbertTokenizer or AlbertTokenizerFast (YOSO model)
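Several entries above offer both a slow and a fast class, while a few (for example bloom) list only a fast one and others (for example ctrl) only a slow one. The following is a simplified sketch of how the use_fast argument could pick between them, assuming the fast class is chosen whenever it exists and is either requested or the only option. The table and the function name are hypothetical, and class names are plain strings.

```python
# Hypothetical sketch of slow/fast tokenizer selection; None marks a
# variant that is not available for that model type.

TOKENIZER_MAPPING = {
    "bert": ("BertTokenizer", "BertTokenizerFast"),
    "bloom": (None, "BloomTokenizerFast"),
    "ctrl": ("CTRLTokenizer", None),
}


def select_tokenizer_class(model_type, use_fast=True):
    slow, fast = TOKENIZER_MAPPING[model_type]
    # Fast class if requested, or if it is the only one available.
    if fast is not None and (use_fast or slow is None):
        return fast
    # Otherwise fall back to the Python-based (slow) class.
    if slow is not None:
        return slow
    raise ValueError(f"No tokenizer class available for {model_type!r}")


print(select_tokenizer_class("bert"))                   # BertTokenizerFast
print(select_tokenizer_class("bert", use_fast=False))   # BertTokenizer
print(select_tokenizer_class("bloom", use_fast=False))  # BloomTokenizerFast
print(select_tokenizer_class("ctrl"))                   # CTRLTokenizer
```

Under this assumption, use_fast=True is a preference rather than a guarantee, which matches the use_fast description: when no fast tokenizer exists, the Python-based one is returned instead.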
Examples:
>>> from transformers import AutoTokenizer
>>> # Download vocabulary from huggingface.co and cache.
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> # Download vocabulary from huggingface.co (user-uploaded) and cache.
>>> tokenizer = AutoTokenizer.from_pretrained("dbmdz/bert-base-german-cased")
>>> # If vocabulary files are in a directory (e.g. tokenizer was saved using *save_pretrained('./test/bert_saved_model/')*)
>>> # tokenizer = AutoTokenizer.from_pretrained("./test/bert_saved_model/")
>>> # Download vocabulary from huggingface.co and define model-specific arguments
>>> tokenizer = AutoTokenizer.from_pretrained("roberta-base", add_prefix_space=True)
register
AutoTokenizer.register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False)
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- slow_tokenizer_class (PreTrainedTokenizer, optional) — The slow tokenizer to register.
- fast_tokenizer_class (PreTrainedTokenizerFast, optional) — The fast tokenizer to register.
Register a new tokenizer in this mapping.
AutoFeatureExtractor
This is a generic feature extractor class that will be instantiated as one of the feature extractor classes of the library when created with the AutoFeatureExtractor.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — This can be either:
  - a string, the model id of a pretrained feature_extractor hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - a path to a directory containing a feature extractor file saved using the save_pretrained() method, e.g., ./my_model_directory/.
  - a path or url to a saved feature extractor JSON file, e.g., ./my_model_directory/preprocessor_config.json.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model feature extractor should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the feature extractor files and override the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume an interrupted download. If True, the download resumes from an incompletely received file when one exists instead of starting over.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- token (str or bool, optional) — The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- return_unused_kwargs (bool, optional, defaults to False) — If False, then this function returns just the final feature extractor object. If True, then this function returns a Tuple(feature_extractor, unused_kwargs) where unused_kwargs is a dictionary consisting of the key/value pairs whose keys are not feature extractor attributes: i.e., the part of kwargs which has not been used to update feature_extractor and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- kwargs (Dict[str, Any], optional) — The values in kwargs of any keys which are feature extractor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not feature extractor attributes is controlled by the return_unused_kwargs keyword parameter.
Instantiate one of the feature extractor classes of the library from a pretrained model vocabulary.
The feature extractor class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to pattern matching on pretrained_model_name_or_path:
- audio-spectrogram-transformer → ASTFeatureExtractor (Audio Spectrogram Transformer model)
- beit → BeitFeatureExtractor (BEiT model)
- chinese_clip → ChineseCLIPFeatureExtractor (Chinese-CLIP model)
- clap → ClapFeatureExtractor (CLAP model)
- clip → CLIPFeatureExtractor (CLIP model)
- clipseg → ViTFeatureExtractor (CLIPSeg model)
- conditional_detr → ConditionalDetrFeatureExtractor (Conditional DETR model)
- convnext → ConvNextFeatureExtractor (ConvNeXT model)
- cvt → ConvNextFeatureExtractor (CvT model)
- data2vec-audio → Wav2Vec2FeatureExtractor (Data2VecAudio model)
- data2vec-vision → BeitFeatureExtractor (Data2VecVision model)
- deformable_detr → DeformableDetrFeatureExtractor (Deformable DETR model)
- deit → DeiTFeatureExtractor (DeiT model)
- detr → DetrFeatureExtractor (DETR model)
- dinat → ViTFeatureExtractor (DiNAT model)
- donut-swin → DonutFeatureExtractor (DonutSwin model)
- dpt → DPTFeatureExtractor (DPT model)
- encodec → EncodecFeatureExtractor (EnCodec model)
- flava → FlavaFeatureExtractor (FLAVA model)
- glpn → GLPNFeatureExtractor (GLPN model)
- groupvit → CLIPFeatureExtractor (GroupViT model)
- hubert → Wav2Vec2FeatureExtractor (Hubert model)
- imagegpt → ImageGPTFeatureExtractor (ImageGPT model)
- layoutlmv2 → LayoutLMv2FeatureExtractor (LayoutLMv2 model)
- layoutlmv3 → LayoutLMv3FeatureExtractor (LayoutLMv3 model)
- levit → LevitFeatureExtractor (LeViT model)
- maskformer → MaskFormerFeatureExtractor (MaskFormer model)
- mctct → MCTCTFeatureExtractor (M-CTC-T model)
- mobilenet_v1 → MobileNetV1FeatureExtractor (MobileNetV1 model)
- mobilenet_v2 → MobileNetV2FeatureExtractor (MobileNetV2 model)
- mobilevit → MobileViTFeatureExtractor (MobileViT model)
- nat → ViTFeatureExtractor (NAT model)
- owlvit → OwlViTFeatureExtractor (OWL-ViT model)
- perceiver → PerceiverFeatureExtractor (Perceiver model)
- poolformer → PoolFormerFeatureExtractor (PoolFormer model)
- pop2piano → Pop2PianoFeatureExtractor (Pop2Piano model)
- regnet → ConvNextFeatureExtractor (RegNet model)
- resnet → ConvNextFeatureExtractor (ResNet model)
- segformer → SegformerFeatureExtractor (SegFormer model)
- sew → Wav2Vec2FeatureExtractor (SEW model)
- sew-d → Wav2Vec2FeatureExtractor (SEW-D model)
- speech_to_text → Speech2TextFeatureExtractor (Speech2Text model)
- speecht5 → SpeechT5FeatureExtractor (SpeechT5 model)
- swiftformer → ViTFeatureExtractor (SwiftFormer model)
- swin → ViTFeatureExtractor (Swin Transformer model)
- swinv2 → ViTFeatureExtractor (Swin Transformer V2 model)
- table-transformer → DetrFeatureExtractor (Table Transformer model)
- timesformer → VideoMAEFeatureExtractor (TimeSformer model)
- tvlt → TvltFeatureExtractor (TVLT model)
- unispeech → Wav2Vec2FeatureExtractor (UniSpeech model)
- unispeech-sat → Wav2Vec2FeatureExtractor (UniSpeechSat model)
- van → ConvNextFeatureExtractor (VAN model)
- videomae → VideoMAEFeatureExtractor (VideoMAE model)
- vilt → ViltFeatureExtractor (ViLT model)
- vit → ViTFeatureExtractor (ViT model)
- vit_mae → ViTFeatureExtractor (ViTMAE model)
- vit_msn → ViTFeatureExtractor (ViTMSN model)
- wav2vec2 → Wav2Vec2FeatureExtractor (Wav2Vec2 model)
- wav2vec2-conformer → Wav2Vec2FeatureExtractor (Wav2Vec2-Conformer model)
- wavlm → Wav2Vec2FeatureExtractor (WavLM model)
- whisper → WhisperFeatureExtractor (Whisper model)
- xclip → CLIPFeatureExtractor (X-CLIP model)
- yolos → YolosFeatureExtractor (YOLOS model)
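The fallback to the config's model_type can be tried offline. This is a minimal sketch that assumes a recent transformers release with numpy installed. Files written by save_pretrained() normally record a feature_extractor_type entry that is consulted first, so the sketch writes a hand-made preprocessor_config.json that deliberately omits it; the class is then picked from the wav2vec2 row of the table above via the saved Wav2Vec2Config. The temporary directory and the minimal JSON contents are illustrative choices only.

```python
import json
import os
import tempfile

from transformers import AutoFeatureExtractor, Wav2Vec2Config

with tempfile.TemporaryDirectory() as tmp_dir:
    # config.json records model_type="wav2vec2".
    Wav2Vec2Config().save_pretrained(tmp_dir)

    # Hand-written preprocessor config WITHOUT "feature_extractor_type",
    # so the class has to be inferred from the model config instead.
    with open(os.path.join(tmp_dir, "preprocessor_config.json"), "w") as f:
        json.dump({"sampling_rate": 16000}, f)

    feature_extractor = AutoFeatureExtractor.from_pretrained(tmp_dir)

print(type(feature_extractor).__name__)  # Wav2Vec2FeatureExtractor
```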
Passing token=True is required when you want to use a private model.
Examples:
>>> from transformers import AutoFeatureExtractor
>>> # Download feature extractor from huggingface.co and cache.
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")
>>> # If feature extractor files are in a directory (e.g. feature extractor was saved using *save_pretrained('./test/saved_model/')*)
>>> # feature_extractor = AutoFeatureExtractor.from_pretrained("./test/saved_model/")
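The kwargs and return_unused_kwargs parameters can also be exercised without network access. The sketch below assumes transformers and numpy are installed; it saves a default Wav2Vec2FeatureExtractor to a temporary directory and reloads it through AutoFeatureExtractor. The keyword not_an_attribute is a made-up name used only to show where unrecognized keys end up.

```python
import tempfile

from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor

with tempfile.TemporaryDirectory() as tmp_dir:
    # Writes preprocessor_config.json with the default settings.
    Wav2Vec2FeatureExtractor().save_pretrained(tmp_dir)

    # sampling_rate is a feature extractor attribute, so it overrides the saved value;
    # not_an_attribute (hypothetical) is not, so it is returned in unused_kwargs.
    feature_extractor, unused_kwargs = AutoFeatureExtractor.from_pretrained(
        tmp_dir,
        sampling_rate=8000,
        not_an_attribute=True,
        return_unused_kwargs=True,
    )

print(type(feature_extractor).__name__)  # Wav2Vec2FeatureExtractor
print(feature_extractor.sampling_rate)  # 8000
print(unused_kwargs)
```

Depending on the transformers version, unused_kwargs may also hold internal bookkeeping entries, so it is safer to look up the keys you passed than to assume the dictionary contains nothing else.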
register
(config_class, feature_extractor_class, exist_ok=False)
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- feature_extractor_class (FeatureExtractionMixin) — The feature extractor to register.
Register a new feature extractor for this class.
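Registering a feature extractor follows the same pattern as the AutoConfig/AutoModel registration shown at the top of this page. In this sketch NewModelConfig and NewFeatureExtractor are hypothetical classes and scale is a made-up attribute; it assumes a transformers version in which a registered class can be found again from the class name that save_pretrained() stores in preprocessor_config.json.

```python
import tempfile

from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    FeatureExtractionMixin,
    PretrainedConfig,
)


class NewModelConfig(PretrainedConfig):
    # Hypothetical config; model_type must match the key passed to AutoConfig.register.
    model_type = "new-model"


class NewFeatureExtractor(FeatureExtractionMixin):
    # Hypothetical feature extractor with a single made-up attribute.
    def __init__(self, scale=1.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale


AutoConfig.register("new-model", NewModelConfig)
AutoFeatureExtractor.register(NewModelConfig, NewFeatureExtractor)

with tempfile.TemporaryDirectory() as tmp_dir:
    NewModelConfig().save_pretrained(tmp_dir)
    NewFeatureExtractor(scale=0.5).save_pretrained(tmp_dir)
    # The auto class now resolves the custom class from the saved files.
    reloaded = AutoFeatureExtractor.from_pretrained(tmp_dir)

print(type(reloaded).__name__, reloaded.scale)  # NewFeatureExtractor 0.5
```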
AutoImageProcessor
This is a generic image processor class that will be instantiated as one of the image processor classes of the library when created with the AutoImageProcessor.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
(pretrained_model_name_or_path, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — This can be either:
- a string, the model id of a pretrained image_processor hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
- a path to a directory containing an image processor file saved using the save_pretrained() method, e.g., ./my_model_directory/.
- a path or URL to a saved image processor JSON file, e.g., ./my_model_directory/preprocessor_config.json.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model image processor should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the image processor files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume downloading an incompletely received file if one exists, instead of deleting it and starting over.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- token (str or bool, optional) — The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- return_unused_kwargs (bool, optional, defaults to False) — If False, this function returns just the final image processor object. If True, it returns a Tuple(image_processor, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not image processor attributes, i.e., the part of kwargs that was not used to update image_processor and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- kwargs (Dict[str, Any], optional) — The values in kwargs of any keys which are image processor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not image processor attributes is controlled by the return_unused_kwargs keyword parameter.
Instantiate one of the image processor classes of the library from a pretrained model vocabulary.
The image processor class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to pattern matching on pretrained_model_name_or_path:
- align → EfficientNetImageProcessor (ALIGN model)
- beit → BeitImageProcessor (BEiT model)
- bit → BitImageProcessor (BiT model)
- blip → BlipImageProcessor (BLIP model)
- blip-2 → BlipImageProcessor (BLIP-2 model)
- bridgetower → BridgeTowerImageProcessor (BridgeTower model)
- chinese_clip → ChineseCLIPImageProcessor (Chinese-CLIP model)
- clip → CLIPImageProcessor (CLIP model)
- clipseg → ViTImageProcessor (CLIPSeg model)
- conditional_detr → ConditionalDetrImageProcessor (Conditional DETR model)
- convnext → ConvNextImageProcessor (ConvNeXT model)
- convnextv2 → ConvNextImageProcessor (ConvNeXTV2 model)
- cvt → ConvNextImageProcessor (CvT model)
- data2vec-vision → BeitImageProcessor (Data2VecVision model)
- deformable_detr → DeformableDetrImageProcessor (Deformable DETR model)
- deit → DeiTImageProcessor (DeiT model)
- deta → DetaImageProcessor (DETA model)
- detr → DetrImageProcessor (DETR model)
- dinat → ViTImageProcessor (DiNAT model)
- dinov2 → BitImageProcessor (DINOv2 model)
- donut-swin → DonutImageProcessor (DonutSwin model)
- dpt → DPTImageProcessor (DPT model)
- efficientformer → EfficientFormerImageProcessor (EfficientFormer model)
- efficientnet → EfficientNetImageProcessor (EfficientNet model)
- flava → FlavaImageProcessor (FLAVA model)
- focalnet → BitImageProcessor (FocalNet model)
- git → CLIPImageProcessor (GIT model)
- glpn → GLPNImageProcessor (GLPN model)
- groupvit → CLIPImageProcessor (GroupViT model)
- idefics → IdeficsImageProcessor (IDEFICS model)
- imagegpt → ImageGPTImageProcessor (ImageGPT model)
- instructblip → BlipImageProcessor (InstructBLIP model)
- layoutlmv2 → LayoutLMv2ImageProcessor (LayoutLMv2 model)
- layoutlmv3 → LayoutLMv3ImageProcessor (LayoutLMv3 model)
- levit → LevitImageProcessor (LeViT model)
- mask2former → Mask2FormerImageProcessor (Mask2Former model)
- maskformer → MaskFormerImageProcessor (MaskFormer model)
- mgp-str → ViTImageProcessor (MGP-STR model)
- mobilenet_v1 → MobileNetV1ImageProcessor (MobileNetV1 model)
- mobilenet_v2 → MobileNetV2ImageProcessor (MobileNetV2 model)
- mobilevit → MobileViTImageProcessor (MobileViT model)
- mobilevitv2 → MobileViTImageProcessor (MobileViTV2 model)
- nat → ViTImageProcessor (NAT model)
- oneformer → OneFormerImageProcessor (OneFormer model)
- owlvit → OwlViTImageProcessor (OWL-ViT model)
- perceiver → PerceiverImageProcessor (Perceiver model)
- pix2struct → Pix2StructImageProcessor (Pix2Struct model)
- poolformer → PoolFormerImageProcessor (PoolFormer model)
- pvt → PvtImageProcessor (PVT model)
- regnet → ConvNextImageProcessor (RegNet model)
- resnet → ConvNextImageProcessor (ResNet model)
- sam → SamImageProcessor (SAM model)
- segformer → SegformerImageProcessor (SegFormer model)
- swiftformer → ViTImageProcessor (SwiftFormer model)
- swin → ViTImageProcessor (Swin Transformer model)
- swin2sr → Swin2SRImageProcessor (Swin2SR model)
- swinv2 → ViTImageProcessor (Swin Transformer V2 model)
- table-transformer → DetrImageProcessor (Table Transformer model)
- timesformer → VideoMAEImageProcessor (TimeSformer model)
- tvlt → TvltImageProcessor (TVLT model)
- upernet → SegformerImageProcessor (UPerNet model)
- van → ConvNextImageProcessor (VAN model)
- videomae → VideoMAEImageProcessor (VideoMAE model)
- vilt → ViltImageProcessor (ViLT model)
- vit → ViTImageProcessor (ViT model)
- vit_hybrid → ViTHybridImageProcessor (ViT Hybrid model)
- vit_mae → ViTImageProcessor (ViTMAE model)
- vit_msn → ViTImageProcessor (ViTMSN model)
- xclip → CLIPImageProcessor (X-CLIP model)
- yolos → YolosImageProcessor (YOLOS model)
Passing token=True is required when you want to use a private model.
Examples:
>>> from transformers import AutoImageProcessor
>>> # Download image processor from huggingface.co and cache.
>>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> # If image processor files are in a directory (e.g. image processor was saved using *save_pretrained('./test/saved_model/')*)
>>> # image_processor = AutoImageProcessor.from_pretrained("./test/saved_model/")
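Attribute overrides work the same way for image processors. This offline sketch assumes transformers with Pillow installed (the image processors depend on it) and uses ViTImageProcessor with an arbitrary 64x64 size. Depending on the installed version, AutoImageProcessor may hand back a fast variant of the class, so the sketch only checks the ViTImageProcessor prefix of the class name.

```python
import tempfile

from transformers import AutoImageProcessor, ViTImageProcessor

with tempfile.TemporaryDirectory() as tmp_dir:
    # Save an image processor with a non-default resize target.
    ViTImageProcessor(size={"height": 64, "width": 64}).save_pretrained(tmp_dir)

    # do_resize is an image processor attribute, so it overrides the saved value.
    image_processor = AutoImageProcessor.from_pretrained(tmp_dir, do_resize=False)

print(type(image_processor).__name__)
print(image_processor.size, image_processor.do_resize)
```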
register
(config_class, image_processor_class, exist_ok=False)
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- image_processor_class (ImageProcessingMixin) — The image processor to register.
Register a new image processor for this class.
AutoProcessor
This is a generic processor class that will be instantiated as one of the processor classes of the library when created with the AutoProcessor.from_pretrained() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_pretrained
(pretrained_model_name_or_path, **kwargs)
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — This can be either:
- a string, the model id of a pretrained processor hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
- a path to a directory containing processor files saved using the save_pretrained() method, e.g., ./my_model_directory/.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model feature extractor should be cached if the standard cache should not be used.
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the feature extractor files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume downloading an incompletely received file if one exists, instead of deleting it and starting over.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- token (str or bool, optional) — The token to use as HTTP bearer authorization for remote files. If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since models and other artifacts on huggingface.co are stored in a git-based system, revision can be any identifier allowed by git.
- return_unused_kwargs (bool, optional, defaults to False) — If False, this function returns just the final feature extractor object. If True, it returns a Tuple(feature_extractor, unused_kwargs), where unused_kwargs is a dictionary of the key/value pairs whose keys are not feature extractor attributes, i.e., the part of kwargs that was not used to update feature_extractor and is otherwise ignored.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- kwargs (Dict[str, Any], optional) — The values in kwargs of any keys which are feature extractor attributes will be used to override the loaded values. Behavior concerning key/value pairs whose keys are not feature extractor attributes is controlled by the return_unused_kwargs keyword parameter.
Instantiate one of the processor classes of the library from a pretrained model vocabulary.
The processor class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible):
- align → AlignProcessor (ALIGN model)
- altclip → AltCLIPProcessor (AltCLIP model)
- bark → BarkProcessor (Bark model)
- blip → BlipProcessor (BLIP model)
- blip-2 → Blip2Processor (BLIP-2 model)
- bridgetower → BridgeTowerProcessor (BridgeTower model)
- chinese_clip → ChineseCLIPProcessor (Chinese-CLIP model)
- clap → ClapProcessor (CLAP model)
- clip → CLIPProcessor (CLIP model)
- clipseg → CLIPSegProcessor (CLIPSeg model)
- flava → FlavaProcessor (FLAVA model)
- git → GitProcessor (GIT model)
- groupvit → CLIPProcessor (GroupViT model)
- hubert → Wav2Vec2Processor (Hubert model)
- idefics → IdeficsProcessor (IDEFICS model)
- instructblip → InstructBlipProcessor (InstructBLIP model)
- layoutlmv2 → LayoutLMv2Processor (LayoutLMv2 model)
- layoutlmv3 → LayoutLMv3Processor (LayoutLMv3 model)
- markuplm → MarkupLMProcessor (MarkupLM model)
- mctct → MCTCTProcessor (M-CTC-T model)
- mgp-str → MgpstrProcessor (MGP-STR model)
- oneformer → OneFormerProcessor (OneFormer model)
- owlvit → OwlViTProcessor (OWL-ViT model)
- pix2struct → Pix2StructProcessor (Pix2Struct model)
- pop2piano → Pop2PianoProcessor (Pop2Piano model)
- sam → SamProcessor (SAM model)
- sew → Wav2Vec2Processor (SEW model)
- sew-d → Wav2Vec2Processor (SEW-D model)
- speech_to_text → Speech2TextProcessor (Speech2Text model)
- speech_to_text_2 → Speech2Text2Processor (Speech2Text2 model)
- speecht5 → SpeechT5Processor (SpeechT5 model)
- trocr → TrOCRProcessor (TrOCR model)
- tvlt → TvltProcessor (TVLT model)
- unispeech → Wav2Vec2Processor (UniSpeech model)
- unispeech-sat → Wav2Vec2Processor (UniSpeechSat model)
- vilt → ViltProcessor (ViLT model)
- vision-text-dual-encoder → VisionTextDualEncoderProcessor (VisionTextDualEncoder model)
- wav2vec2 → Wav2Vec2Processor (Wav2Vec2 model)
- wav2vec2-conformer → Wav2Vec2Processor (Wav2Vec2-Conformer model)
- wavlm → Wav2Vec2Processor (WavLM model)
- whisper → WhisperProcessor (Whisper model)
- xclip → XCLIPProcessor (X-CLIP model)
Passing token=True is required when you want to use a private model.
Examples:
>>> from transformers import AutoProcessor
>>> # Download processor from huggingface.co and cache.
>>> processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
>>> # If processor files are in a directory (e.g. processor was saved using *save_pretrained('./test/saved_model/')*)
>>> # processor = AutoProcessor.from_pretrained("./test/saved_model/")
register
(config_class, processor_class, exist_ok=False)
Parameters
- config_class (PretrainedConfig) — The configuration corresponding to the model to register.
- processor_class (ProcessorMixin) — The processor to register.
Register a new processor for this class.
Generic model classes
The following auto classes are available for instantiating a base model class without a specific head.
AutoModel
This is a generic model class that will be instantiated as one of the base model classes of the library when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
(**kwargs)
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
- ASTConfig configuration class: ASTModel (Audio Spectrogram Transformer model)
- AlbertConfig configuration class: AlbertModel (ALBERT model)
- AlignConfig configuration class: AlignModel (ALIGN model)
- AltCLIPConfig configuration class: AltCLIPModel (AltCLIP model)
- AutoformerConfig configuration class: AutoformerModel (Autoformer model)
- BarkConfig configuration class: BarkModel (Bark model)
- BartConfig configuration class: BartModel (BART model)
- BeitConfig configuration class: BeitModel (BEiT model)
- BertConfig configuration class: BertModel (BERT model)
- BertGenerationConfig configuration class: BertGenerationEncoder (Bert Generation model)
- BigBirdConfig configuration class: BigBirdModel (BigBird model)
- BigBirdPegasusConfig configuration class: BigBirdPegasusModel (BigBird-Pegasus model)
- BioGptConfig configuration class: BioGptModel (BioGpt model)
- BitConfig configuration class: BitModel (BiT model)
- BlenderbotConfig configuration class: BlenderbotModel (Blenderbot model)
- BlenderbotSmallConfig configuration class: BlenderbotSmallModel (BlenderbotSmall model)
- Blip2Config configuration class: Blip2Model (BLIP-2 model)
- BlipConfig configuration class: BlipModel (BLIP model)
- BloomConfig configuration class: BloomModel (BLOOM model)
- BridgeTowerConfig configuration class: BridgeTowerModel (BridgeTower model)
- CLIPConfig configuration class: CLIPModel (CLIP model)
- CLIPSegConfig configuration class: CLIPSegModel (CLIPSeg model)
- CTRLConfig configuration class: CTRLModel (CTRL model)
- CamembertConfig configuration class: CamembertModel (CamemBERT model)
- CanineConfig configuration class: CanineModel (CANINE model)
- ChineseCLIPConfig configuration class: ChineseCLIPModel (Chinese-CLIP model)
- ClapConfig configuration class: ClapModel (CLAP model)
- CodeGenConfig configuration class: CodeGenModel (CodeGen model)
- ConditionalDetrConfig configuration class: ConditionalDetrModel (Conditional DETR model)
- ConvBertConfig configuration class: ConvBertModel (ConvBERT model)
- ConvNextConfig configuration class: ConvNextModel (ConvNeXT model)
- ConvNextV2Config configuration class: ConvNextV2Model (ConvNeXTV2 model)
- CpmAntConfig configuration class: CpmAntModel (CPM-Ant model)
- CvtConfig configuration class: CvtModel (CvT model)
- DPRConfig configuration class: DPRQuestionEncoder (DPR model)
- DPTConfig configuration class: DPTModel (DPT model)
- Data2VecAudioConfig configuration class: Data2VecAudioModel (Data2VecAudio model)
- Data2VecTextConfig configuration class: Data2VecTextModel (Data2VecText model)
- Data2VecVisionConfig configuration class: Data2VecVisionModel (Data2VecVision model)
- DebertaConfig configuration class: DebertaModel (DeBERTa model)
- DebertaV2Config configuration class: DebertaV2Model (DeBERTa-v2 model)
- DecisionTransformerConfig configuration class: DecisionTransformerModel (Decision Transformer model)
- DeformableDetrConfig configuration class: DeformableDetrModel (Deformable DETR model)
- DeiTConfig configuration class: DeiTModel (DeiT model)
- DetaConfig configuration class: DetaModel (DETA model)
- DetrConfig configuration class: DetrModel (DETR model)
- DinatConfig configuration class: DinatModel (DiNAT model)
- Dinov2Config configuration class: Dinov2Model (DINOv2 model)
- DistilBertConfig configuration class: DistilBertModel (DistilBERT model)
- DonutSwinConfig configuration class: DonutSwinModel (DonutSwin model)
- EfficientFormerConfig configuration class: EfficientFormerModel (EfficientFormer model)
- EfficientNetConfig configuration class: EfficientNetModel (EfficientNet model)
- ElectraConfig configuration class: ElectraModel (ELECTRA model)
- EncodecConfig configuration class: EncodecModel (EnCodec model)
- ErnieConfig configuration class: ErnieModel (ERNIE model)
- ErnieMConfig configuration class: ErnieMModel (ErnieM model)
- EsmConfig configuration class: EsmModel (ESM model)
- FNetConfig configuration class: FNetModel (FNet model)
- FSMTConfig configuration class: FSMTModel (FairSeq Machine-Translation model)
- FalconConfig configuration class: FalconModel (Falcon model)
- FlaubertConfig configuration class: FlaubertModel (FlauBERT model)
- FlavaConfig configuration class: FlavaModel (FLAVA model)
- FocalNetConfig configuration class: FocalNetModel (FocalNet model)
- FunnelConfig configuration class: FunnelModel or FunnelBaseModel (Funnel Transformer model)
- GLPNConfig configuration class: GLPNModel (GLPN model)
- GPT2Config configuration class: GPT2Model (OpenAI GPT-2 model)
- GPTBigCodeConfig configuration class: GPTBigCodeModel (GPTBigCode model)
- GPTJConfig configuration class: GPTJModel (GPT-J model)
- GPTNeoConfig configuration class: GPTNeoModel (GPT Neo model)
- GPTNeoXConfig configuration class: GPTNeoXModel (GPT NeoX model)
- GPTNeoXJapaneseConfig configuration class: GPTNeoXJapaneseModel (GPT NeoX Japanese model)
- GPTSanJapaneseConfig configuration class: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- GitConfig configuration class: GitModel (GIT model)
- GraphormerConfig configuration class: GraphormerModel (Graphormer model)
- GroupViTConfig configuration class: GroupViTModel (GroupViT model)
- HubertConfig configuration class: HubertModel (Hubert model)
- IBertConfig configuration class: IBertModel (I-BERT model)
- IdeficsConfig configuration class: IdeficsModel (IDEFICS model)
- ImageGPTConfig configuration class: ImageGPTModel (ImageGPT model)
- InformerConfig configuration class: InformerModel (Informer model)
- JukeboxConfig configuration class: JukeboxModel (Jukebox model)
- LEDConfig configuration class: LEDModel (LED model)
- LayoutLMConfig configuration class: LayoutLMModel (LayoutLM model)
- LayoutLMv2Config configuration class: LayoutLMv2Model (LayoutLMv2 model)
- LayoutLMv3Config configuration class: LayoutLMv3Model (LayoutLMv3 model)
- LevitConfig configuration class: LevitModel (LeViT model)
- LiltConfig configuration class: LiltModel (LiLT model)
- LlamaConfig configuration class: LlamaModel (LLaMA model)
- LongT5Config configuration class: LongT5Model (LongT5 model)
- LongformerConfig configuration class: LongformerModel (Longformer model)
- LukeConfig configuration class: LukeModel (LUKE model)
- LxmertConfig configuration class: LxmertModel (LXMERT model)
- M2M100Config configuration class: M2M100Model (M2M100 model)
- MBartConfig configuration class: MBartModel (mBART model)
- MCTCTConfig configuration class: MCTCTModel (M-CTC-T model)
- MPNetConfig configuration class: MPNetModel (MPNet model)
- MT5Config configuration class: MT5Model (MT5 model)
- MarianConfig configuration class: MarianModel (Marian model)
- MarkupLMConfig configuration class: MarkupLMModel (MarkupLM model)
- Mask2FormerConfig configuration class: Mask2FormerModel (Mask2Former model)
- MaskFormerConfig configuration class: MaskFormerModel (MaskFormer model)
- MaskFormerSwinConfig configuration class: MaskFormerSwinModel (MaskFormerSwin model)
- MegaConfig configuration class: MegaModel (MEGA model)
- MegatronBertConfig configuration class: MegatronBertModel (Megatron-BERT model)
- MgpstrConfig configuration class: MgpstrForSceneTextRecognition (MGP-STR model)
- MobileBertConfig configuration class: MobileBertModel (MobileBERT model)
- MobileNetV1Config configuration class: MobileNetV1Model (MobileNetV1 model)
- MobileNetV2Config configuration class: MobileNetV2Model (MobileNetV2 model)
- MobileViTConfig configuration class: MobileViTModel (MobileViT model)
- MobileViTV2Config configuration class: MobileViTV2Model (MobileViTV2 model)
- MptConfig configuration class: MptModel (MPT model)
- MraConfig configuration class: MraModel (MRA model)
- MvpConfig configuration class: MvpModel (MVP model)
- NatConfig configuration class: NatModel (NAT model)
- NezhaConfig configuration class: NezhaModel (Nezha model)
- NllbMoeConfig configuration class: NllbMoeModel (NLLB-MOE model)
- NystromformerConfig configuration class: NystromformerModel (Nyströmformer model)
- OPTConfig configuration class: OPTModel (OPT model)
- OneFormerConfig configuration class: OneFormerModel (OneFormer model)
- OpenAIGPTConfig configuration class: OpenAIGPTModel (OpenAI GPT model)
- OpenLlamaConfig configuration class: OpenLlamaModel (OpenLlama model)
- OwlViTConfig configuration class: OwlViTModel (OWL-ViT model)
- PLBartConfig configuration class: PLBartModel (PLBart model)
- PegasusConfig configuration class: PegasusModel (Pegasus model)
- PegasusXConfig configuration class: PegasusXModel (PEGASUS-X model)
- PerceiverConfig configuration class: PerceiverModel (Perceiver model)
- PoolFormerConfig configuration class: PoolFormerModel (PoolFormer model)
- ProphetNetConfig configuration class: ProphetNetModel (ProphetNet model)
- PvtConfig configuration class: PvtModel (PVT model)
- QDQBertConfig configuration class: QDQBertModel (QDQBert model)
- ReformerConfig configuration class: ReformerModel (Reformer model)
- RegNetConfig configuration class: RegNetModel (RegNet model)
- RemBertConfig configuration class: RemBertModel (RemBERT model)
- ResNetConfig configuration class: ResNetModel (ResNet model)
- RetriBertConfig configuration class: RetriBertModel (RetriBERT model)
- RoCBertConfig configuration class: RoCBertModel (RoCBert model)
- RoFormerConfig configuration class: RoFormerModel (RoFormer model)
- RobertaConfig configuration class: RobertaModel (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- RwkvConfig configuration class: RwkvModel (RWKV model)
- SEWConfig configuration class: SEWModel (SEW model)
- SEWDConfig configuration class: SEWDModel (SEW-D model)
- SamConfig configuration class: SamModel (SAM model)
- SegformerConfig configuration class: SegformerModel (SegFormer model)
- Speech2TextConfig configuration class: Speech2TextModel (Speech2Text model)
- SpeechT5Config configuration class: SpeechT5Model (SpeechT5 model)
- SplinterConfig configuration class: SplinterModel (Splinter model)
- SqueezeBertConfig configuration class: SqueezeBertModel (SqueezeBERT model)
- SwiftFormerConfig configuration class: SwiftFormerModel (SwiftFormer model)
- Swin2SRConfig configuration class: Swin2SRModel (Swin2SR model)
- SwinConfig configuration class: SwinModel (Swin Transformer model)
- Swinv2Config configuration class: Swinv2Model (Swin Transformer V2 model)
- SwitchTransformersConfig configuration class: SwitchTransformersModel (SwitchTransformers model)
- T5Config configuration class: T5Model (T5 model)
- TableTransformerConfig configuration class: TableTransformerModel (Table Transformer model)
- TapasConfig configuration class: TapasModel (TAPAS model)
- TimeSeriesTransformerConfig configuration class: TimeSeriesTransformerModel (Time Series Transformer model)
- TimesformerConfig configuration class: TimesformerModel (TimeSformer model)
- TimmBackboneConfig configuration class: TimmBackbone (TimmBackbone model)
- TrajectoryTransformerConfig configuration class: TrajectoryTransformerModel (Trajectory Transformer model)
- TransfoXLConfig configuration class: TransfoXLModel (Transformer-XL model)
- TvltConfig configuration class: TvltModel (TVLT model)
- UMT5Config configuration class: UMT5Model (UMT5 model)
- UniSpeechConfig configuration class: UniSpeechModel (UniSpeech model)
- UniSpeechSatConfig configuration class: UniSpeechSatModel (UniSpeechSat model)
- VanConfig configuration class: VanModel (VAN model)
- ViTConfig configuration class: ViTModel (ViT model)
- ViTHybridConfig configuration class: ViTHybridModel (ViT Hybrid model)
- ViTMAEConfig configuration class: ViTMAEModel (ViTMAE model)
- ViTMSNConfig configuration class: ViTMSNModel (ViTMSN model)
- VideoMAEConfig configuration class: VideoMAEModel (VideoMAE model)
- ViltConfig configuration class: ViltModel (ViLT model)
- VisionTextDualEncoderConfig configuration class: VisionTextDualEncoderModel (VisionTextDualEncoder model)
- VisualBertConfig configuration class: VisualBertModel (VisualBERT model)
- VitDetConfig configuration class: VitDetModel (VitDet model)
- VitsConfig configuration class: VitsModel (VITS model)
- VivitConfig configuration class: VivitModel (ViViT model)
- Wav2Vec2Config configuration class: Wav2Vec2Model (Wav2Vec2 model)
- Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerModel (Wav2Vec2-Conformer model)
- WavLMConfig configuration class: WavLMModel (WavLM model)
- WhisperConfig configuration class: WhisperModel (Whisper model)
- XCLIPConfig configuration class: XCLIPModel (X-CLIP model)
- XGLMConfig configuration class: XGLMModel (XGLM model)
- XLMConfig configuration class: XLMModel (XLM model)
- XLMProphetNetConfig configuration class: XLMProphetNetModel (XLM-ProphetNet model)
- XLMRobertaConfig configuration class: XLMRobertaModel (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLModel (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetModel (XLNet model)
- XmodConfig configuration class: XmodModel (X-MOD model)
- YolosConfig configuration class: YolosModel (YOLOS model)
- YosoConfig configuration class: YosoModel (YOSO model)
Instantiates one of the base model classes of the library from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
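The sketch below roughly illustrates how from_config() picks a model class from the configuration class alone, using a plain dictionary lookup. The Toy* classes and TOY_MODEL_MAPPING are hypothetical stand-ins and are not the library's internals. With the real library, the equivalent steps are config = AutoConfig.from_pretrained("bert-base-cased") followed by model = AutoModel.from_config(config). The resulting model has freshly initialized weights, not pretrained ones.

```python
# Hypothetical toy classes standing in for configuration/model pairs;
# they are NOT the transformers classes and hold no weights.
class ToyBertConfig:
    model_type = "bert"

    def __init__(self, hidden_size=768):
        self.hidden_size = hidden_size


class ToyT5Config:
    model_type = "t5"

    def __init__(self, d_model=512):
        self.d_model = d_model


class ToyBertModel:
    def __init__(self, config):
        # A real model would build its layers with freshly initialized weights here.
        self.config = config


class ToyT5Model:
    def __init__(self, config):
        self.config = config


# Configuration class -> model class, in the spirit of the list above.
TOY_MODEL_MAPPING = {ToyBertConfig: ToyBertModel, ToyT5Config: ToyT5Model}


class ToyAutoModel:
    def __init__(self, *args, **kwargs):
        # Mirrors the documented behavior: the auto class itself is never instantiated.
        raise EnvironmentError("ToyAutoModel must be created with ToyAutoModel.from_config(config).")

    @classmethod
    def from_config(cls, config, **kwargs):
        # Dispatch purely on the type of the configuration object.
        model_class = TOY_MODEL_MAPPING.get(type(config))
        if model_class is None:
            raise ValueError(f"Unrecognized configuration class {type(config).__name__}.")
        return model_class(config, **kwargs)


model = ToyAutoModel.from_config(ToyBertConfig(hidden_size=128))
print(type(model).__name__, model.config.hidden_size)  # ToyBertModel 128
```

A dictionary keyed by configuration class keeps the dispatch explicit. Supporting a new architecture only requires adding one entry, which mirrors the register() mechanism described at the top of this page.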
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() would be a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume an incompletely received download instead of deleting the partial file and starting over. If set to True, the download will attempt to resume from an existing partial file.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
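The two kwargs cases above can be pictured as a small routing function. This is a simplified sketch and not the library code. ToyConfig and route_kwargs are hypothetical, and add_pooling_layer only serves as an example of a keyword that a model's __init__ might accept but the configuration does not define. In the library itself, a comparable configuration-side split is exposed through the configuration's from_pretrained() with return_unused_kwargs=True.

```python
class ToyConfig:
    """Hypothetical configuration with two attributes (not a transformers class)."""

    def __init__(self, output_attentions=False, hidden_dropout_prob=0.1):
        self.output_attentions = output_attentions
        self.hidden_dropout_prob = hidden_dropout_prob


def route_kwargs(config, config_provided, **kwargs):
    """Return (config, model_init_kwargs) following the two documented cases."""
    if config_provided:
        # Case 1: config passed explicitly, so every kwarg goes to the model's __init__.
        return config, dict(kwargs)
    model_kwargs = {}
    for key, value in kwargs.items():
        if hasattr(config, key):
            # Case 2: key matches a configuration attribute, so it overrides it.
            setattr(config, key, value)
        else:
            # Anything the configuration does not define goes to the model's __init__.
            model_kwargs[key] = value
    return config, model_kwargs


config, model_kwargs = route_kwargs(ToyConfig(), False, output_attentions=True, add_pooling_layer=False)
print(config.output_attentions, model_kwargs)  # True {'add_pooling_layer': False}
```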
Instantiate one of the base model classes of the library from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert → AlbertModel (ALBERT model)
- align → AlignModel (ALIGN model)
- altclip → AltCLIPModel (AltCLIP model)
- audio-spectrogram-transformer → ASTModel (Audio Spectrogram Transformer model)
- autoformer → AutoformerModel (Autoformer model)
- bark → BarkModel (Bark model)
- bart → BartModel (BART model)
- beit → BeitModel (BEiT model)
- bert → BertModel (BERT model)
- bert-generation → BertGenerationEncoder (Bert Generation model)
- big_bird → BigBirdModel (BigBird model)
- bigbird_pegasus → BigBirdPegasusModel (BigBird-Pegasus model)
- biogpt → BioGptModel (BioGpt model)
- bit → BitModel (BiT model)
- blenderbot → BlenderbotModel (Blenderbot model)
- blenderbot-small → BlenderbotSmallModel (BlenderbotSmall model)
- blip → BlipModel (BLIP model)
- blip-2 → Blip2Model (BLIP-2 model)
- bloom → BloomModel (BLOOM model)
- bridgetower → BridgeTowerModel (BridgeTower model)
- camembert → CamembertModel (CamemBERT model)
- canine → CanineModel (CANINE model)
- chinese_clip → ChineseCLIPModel (Chinese-CLIP model)
- clap → ClapModel (CLAP model)
- clip → CLIPModel (CLIP model)
- clipseg → CLIPSegModel (CLIPSeg model)
- code_llama → LlamaModel (CodeLlama model)
- codegen → CodeGenModel (CodeGen model)
- conditional_detr → ConditionalDetrModel (Conditional DETR model)
- convbert → ConvBertModel (ConvBERT model)
- convnext → ConvNextModel (ConvNeXT model)
- convnextv2 → ConvNextV2Model (ConvNeXTV2 model)
- cpmant → CpmAntModel (CPM-Ant model)
- ctrl → CTRLModel (CTRL model)
- cvt → CvtModel (CvT model)
- data2vec-audio → Data2VecAudioModel (Data2VecAudio model)
- data2vec-text → Data2VecTextModel (Data2VecText model)
- data2vec-vision → Data2VecVisionModel (Data2VecVision model)
- deberta → DebertaModel (DeBERTa model)
- deberta-v2 → DebertaV2Model (DeBERTa-v2 model)
- decision_transformer → DecisionTransformerModel (Decision Transformer model)
- deformable_detr → DeformableDetrModel (Deformable DETR model)
- deit → DeiTModel (DeiT model)
- deta → DetaModel (DETA model)
- detr → DetrModel (DETR model)
- dinat → DinatModel (DiNAT model)
- dinov2 → Dinov2Model (DINOv2 model)
- distilbert → DistilBertModel (DistilBERT model)
- donut-swin → DonutSwinModel (DonutSwin model)
- dpr → DPRQuestionEncoder (DPR model)
- dpt → DPTModel (DPT model)
- efficientformer → EfficientFormerModel (EfficientFormer model)
- efficientnet → EfficientNetModel (EfficientNet model)
- electra → ElectraModel (ELECTRA model)
- encodec → EncodecModel (EnCodec model)
- ernie → ErnieModel (ERNIE model)
- ernie_m → ErnieMModel (ErnieM model)
- esm → EsmModel (ESM model)
- falcon → FalconModel (Falcon model)
- flaubert → FlaubertModel (FlauBERT model)
- flava → FlavaModel (FLAVA model)
- fnet → FNetModel (FNet model)
- focalnet → FocalNetModel (FocalNet model)
- fsmt → FSMTModel (FairSeq Machine-Translation model)
- funnel → FunnelModel or FunnelBaseModel (Funnel Transformer model)
- git → GitModel (GIT model)
- glpn → GLPNModel (GLPN model)
- gpt-sw3 → GPT2Model (GPT-Sw3 model)
- gpt2 → GPT2Model (OpenAI GPT-2 model)
- gpt_bigcode → GPTBigCodeModel (GPTBigCode model)
- gpt_neo → GPTNeoModel (GPT Neo model)
- gpt_neox → GPTNeoXModel (GPT NeoX model)
- gpt_neox_japanese → GPTNeoXJapaneseModel (GPT NeoX Japanese model)
- gptj → GPTJModel (GPT-J model)
- gptsan-japanese → GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- graphormer → GraphormerModel (Graphormer model)
- groupvit → GroupViTModel (GroupViT model)
- hubert → HubertModel (Hubert model)
- ibert → IBertModel (I-BERT model)
- idefics → IdeficsModel (IDEFICS model)
- imagegpt → ImageGPTModel (ImageGPT model)
- informer → InformerModel (Informer model)
- jukebox → JukeboxModel (Jukebox model)
- layoutlm → LayoutLMModel (LayoutLM model)
- layoutlmv2 → LayoutLMv2Model (LayoutLMv2 model)
- layoutlmv3 → LayoutLMv3Model (LayoutLMv3 model)
- led → LEDModel (LED model)
- levit → LevitModel (LeViT model)
- lilt → LiltModel (LiLT model)
- llama → LlamaModel (LLaMA model)
- longformer → LongformerModel (Longformer model)
- longt5 → LongT5Model (LongT5 model)
- luke → LukeModel (LUKE model)
- lxmert → LxmertModel (LXMERT model)
- m2m_100 → M2M100Model (M2M100 model)
- marian → MarianModel (Marian model)
- markuplm → MarkupLMModel (MarkupLM model)
- mask2former → Mask2FormerModel (Mask2Former model)
- maskformer → MaskFormerModel (MaskFormer model)
- maskformer-swin → MaskFormerSwinModel (MaskFormerSwin model)
- mbart → MBartModel (mBART model)
- mctct → MCTCTModel (M-CTC-T model)
- mega → MegaModel (MEGA model)
- megatron-bert → MegatronBertModel (Megatron-BERT model)
- mgp-str → MgpstrForSceneTextRecognition (MGP-STR model)
- mobilebert → MobileBertModel (MobileBERT model)
- mobilenet_v1 → MobileNetV1Model (MobileNetV1 model)
- mobilenet_v2 → MobileNetV2Model (MobileNetV2 model)
- mobilevit → MobileViTModel (MobileViT model)
- mobilevitv2 → MobileViTV2Model (MobileViTV2 model)
- mpnet → MPNetModel (MPNet model)
- mpt → MptModel (MPT model)
- mra → MraModel (MRA model)
- mt5 → MT5Model (MT5 model)
- mvp → MvpModel (MVP model)
- nat → NatModel (NAT model)
- nezha → NezhaModel (Nezha model)
- nllb-moe → NllbMoeModel (NLLB-MOE model)
- nystromformer → NystromformerModel (Nyströmformer model)
- oneformer → OneFormerModel (OneFormer model)
- open-llama → OpenLlamaModel (OpenLlama model)
- openai-gpt → OpenAIGPTModel (OpenAI GPT model)
- opt → OPTModel (OPT model)
- owlvit → OwlViTModel (OWL-ViT model)
- pegasus → PegasusModel (Pegasus model)
- pegasus_x → PegasusXModel (PEGASUS-X model)
- perceiver → PerceiverModel (Perceiver model)
- plbart → PLBartModel (PLBart model)
- poolformer → PoolFormerModel (PoolFormer model)
- prophetnet → ProphetNetModel (ProphetNet model)
- pvt → PvtModel (PVT model)
- qdqbert → QDQBertModel (QDQBert model)
- reformer → ReformerModel (Reformer model)
- regnet → RegNetModel (RegNet model)
- rembert → RemBertModel (RemBERT model)
- resnet → ResNetModel (ResNet model)
- retribert → RetriBertModel (RetriBERT model)
- roberta → RobertaModel (RoBERTa model)
- roberta-prelayernorm → RobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- roc_bert → RoCBertModel (RoCBert model)
- roformer → RoFormerModel (RoFormer model)
- rwkv → RwkvModel (RWKV model)
- sam → SamModel (SAM model)
- segformer → SegformerModel (SegFormer model)
- sew → SEWModel (SEW model)
- sew-d → SEWDModel (SEW-D model)
- speech_to_text → Speech2TextModel (Speech2Text model)
- speecht5 → SpeechT5Model (SpeechT5 model)
- splinter → SplinterModel (Splinter model)
- squeezebert → SqueezeBertModel (SqueezeBERT model)
- swiftformer → SwiftFormerModel (SwiftFormer model)
- swin → SwinModel (Swin Transformer model)
- swin2sr → Swin2SRModel (Swin2SR model)
- swinv2 → Swinv2Model (Swin Transformer V2 model)
- switch_transformers → SwitchTransformersModel (SwitchTransformers model)
- t5 → T5Model (T5 model)
- table-transformer → TableTransformerModel (Table Transformer model)
- tapas → TapasModel (TAPAS model)
- time_series_transformer → TimeSeriesTransformerModel (Time Series Transformer model)
- timesformer → TimesformerModel (TimeSformer model)
- timm_backbone → TimmBackbone (TimmBackbone model)
- trajectory_transformer → TrajectoryTransformerModel (Trajectory Transformer model)
- transfo-xl → TransfoXLModel (Transformer-XL model)
- tvlt → TvltModel (TVLT model)
- umt5 → UMT5Model (UMT5 model)
- unispeech → UniSpeechModel (UniSpeech model)
- unispeech-sat → UniSpeechSatModel (UniSpeechSat model)
- van → VanModel (VAN model)
- videomae → VideoMAEModel (VideoMAE model)
- vilt → ViltModel (ViLT model)
- vision-text-dual-encoder → VisionTextDualEncoderModel (VisionTextDualEncoder model)
- visual_bert → VisualBertModel (VisualBERT model)
- vit → ViTModel (ViT model)
- vit_hybrid → ViTHybridModel (ViT Hybrid model)
- vit_mae → ViTMAEModel (ViTMAE model)
- vit_msn → ViTMSNModel (ViTMSN model)
- vitdet → VitDetModel (VitDet model)
- vits → VitsModel (VITS model)
- vivit → VivitModel (ViViT model)
- wav2vec2 → Wav2Vec2Model (Wav2Vec2 model)
- wav2vec2-conformer → Wav2Vec2ConformerModel (Wav2Vec2-Conformer model)
- wavlm → WavLMModel (WavLM model)
- whisper → WhisperModel (Whisper model)
- xclip → XCLIPModel (X-CLIP model)
- xglm → XGLMModel (XGLM model)
- xlm → XLMModel (XLM model)
- xlm-prophetnet → XLMProphetNetModel (XLM-ProphetNet model)
- xlm-roberta → XLMRobertaModel (XLM-RoBERTa model)
- xlm-roberta-xl → XLMRobertaXLModel (XLM-RoBERTa-XL model)
- xlnet → XLNetModel (XLNet model)
- xmod → XmodModel (X-MOD model)
- yolos → YolosModel (YOLOS model)
- yoso → YosoModel (YOSO model)
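The selection rule described above, which uses model_type when available and otherwise falls back to pattern matching on the name, can be sketched as a lookup with a substring fallback. This is an illustration under stated assumptions. MODEL_TYPE_TO_CLASS and resolve_model_class are hypothetical names that cover only five entries of the list. Trying longer patterns first is how this sketch prevents "roberta-base" from matching the shorter "bert" key. It shows why match order matters and does not reproduce the library's exact algorithm.

```python
# A hypothetical subset of the model_type -> class-name mapping listed above.
MODEL_TYPE_TO_CLASS = {
    "bert": "BertModel",
    "roberta": "RobertaModel",
    "xlm-roberta": "XLMRobertaModel",
    "xlm-roberta-xl": "XLMRobertaXLModel",
    "gpt2": "GPT2Model",
}


def resolve_model_class(name_or_path, model_type=None):
    """Pick a class name from model_type, else by substring matching on the name."""
    if model_type is not None:
        # Preferred path: the config (passed in or loaded) declares its model_type.
        return MODEL_TYPE_TO_CLASS[model_type]
    # Fallback: try longer patterns first, since "roberta" also contains "bert"
    # and "xlm-roberta-xl" contains "xlm-roberta".
    for pattern in sorted(MODEL_TYPE_TO_CLASS, key=len, reverse=True):
        if pattern in str(name_or_path):
            return MODEL_TYPE_TO_CLASS[pattern]
    raise ValueError(f"Could not infer a model type from {name_or_path!r}.")


print(resolve_model_class("roberta-base"))                     # RobertaModel
print(resolve_model_class("./checkpoint", model_type="gpt2"))  # GPT2Model
```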
The model is set in evaluation mode by default using model.eval() (so for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
Examples:
>>> from transformers import AutoConfig, AutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModel.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModel.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModel.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModel
This is a generic model class that will be instantiated as one of the base model classes of the library when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__()
(throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertModel (ALBERT model)
- BartConfig configuration class: TFBartModel (BART model)
- BertConfig configuration class: TFBertModel (BERT model)
- BlenderbotConfig configuration class: TFBlenderbotModel (Blenderbot model)
- BlenderbotSmallConfig configuration class: TFBlenderbotSmallModel (BlenderbotSmall model)
- BlipConfig configuration class: TFBlipModel (BLIP model)
- CLIPConfig configuration class: TFCLIPModel (CLIP model)
- CTRLConfig configuration class: TFCTRLModel (CTRL model)
- CamembertConfig configuration class: TFCamembertModel (CamemBERT model)
- ConvBertConfig configuration class: TFConvBertModel (ConvBERT model)
- ConvNextConfig configuration class: TFConvNextModel (ConvNeXT model)
- CvtConfig configuration class: TFCvtModel (CvT model)
- DPRConfig configuration class: TFDPRQuestionEncoder (DPR model)
- Data2VecVisionConfig configuration class: TFData2VecVisionModel (Data2VecVision model)
- DebertaConfig configuration class: TFDebertaModel (DeBERTa model)
- DebertaV2Config configuration class: TFDebertaV2Model (DeBERTa-v2 model)
- DeiTConfig configuration class: TFDeiTModel (DeiT model)
- DistilBertConfig configuration class: TFDistilBertModel (DistilBERT model)
- EfficientFormerConfig configuration class: TFEfficientFormerModel (EfficientFormer model)
- ElectraConfig configuration class: TFElectraModel (ELECTRA model)
- EsmConfig configuration class: TFEsmModel (ESM model)
- FlaubertConfig configuration class: TFFlaubertModel (FlauBERT model)
- FunnelConfig configuration class: TFFunnelModel or TFFunnelBaseModel (Funnel Transformer model)
- GPT2Config configuration class: TFGPT2Model (OpenAI GPT-2 model)
- GPTJConfig configuration class: TFGPTJModel (GPT-J model)
- GroupViTConfig configuration class: TFGroupViTModel (GroupViT model)
- HubertConfig configuration class: TFHubertModel (Hubert model)
- LEDConfig configuration class: TFLEDModel (LED model)
- LayoutLMConfig configuration class: TFLayoutLMModel (LayoutLM model)
- LayoutLMv3Config configuration class: TFLayoutLMv3Model (LayoutLMv3 model)
- LongformerConfig configuration class: TFLongformerModel (Longformer model)
- LxmertConfig configuration class: TFLxmertModel (LXMERT model)
- MBartConfig configuration class: TFMBartModel (mBART model)
- MPNetConfig configuration class: TFMPNetModel (MPNet model)
- MT5Config configuration class: TFMT5Model (MT5 model)
- MarianConfig configuration class: TFMarianModel (Marian model)
- MobileBertConfig configuration class: TFMobileBertModel (MobileBERT model)
- MobileViTConfig configuration class: TFMobileViTModel (MobileViT model)
- OPTConfig configuration class: TFOPTModel (OPT model)
- OpenAIGPTConfig configuration class: TFOpenAIGPTModel (OpenAI GPT model)
- PegasusConfig configuration class: TFPegasusModel (Pegasus model)
- RegNetConfig configuration class: TFRegNetModel (RegNet model)
- RemBertConfig configuration class: TFRemBertModel (RemBERT model)
- ResNetConfig configuration class: TFResNetModel (ResNet model)
- RoFormerConfig configuration class: TFRoFormerModel (RoFormer model)
- RobertaConfig configuration class: TFRobertaModel (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- SamConfig configuration class: TFSamModel (SAM model)
- SegformerConfig configuration class: TFSegformerModel (SegFormer model)
- Speech2TextConfig configuration class: TFSpeech2TextModel (Speech2Text model)
- SwinConfig configuration class: TFSwinModel (Swin Transformer model)
- T5Config configuration class: TFT5Model (T5 model)
- TapasConfig configuration class: TFTapasModel (TAPAS model)
- TransfoXLConfig configuration class: TFTransfoXLModel (Transformer-XL model)
- ViTConfig configuration class: TFViTModel (ViT model)
- ViTMAEConfig configuration class: TFViTMAEModel (ViTMAE model)
- VisionTextDualEncoderConfig configuration class: TFVisionTextDualEncoderModel (VisionTextDualEncoder model)
- Wav2Vec2Config configuration class: TFWav2Vec2Model (Wav2Vec2 model)
- WhisperConfig configuration class: TFWhisperModel (Whisper model)
- XGLMConfig configuration class: TFXGLMModel (XGLM model)
- XLMConfig configuration class: TFXLMModel (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaModel (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetModel (XLNet model)
Instantiates one of the base model classes of the library from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
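In the configuration list above, Funnel maps one configuration class to two model classes (TFFunnelModel or TFFunnelBaseModel). The sketch below shows one plausible way to resolve that choice, offered as an assumption rather than the library's code. It consults the architectures list stored on the configuration and falls back to the first candidate when that list gives no hint. pick_model_class and FUNNEL_CANDIDATES are hypothetical, and the two classes are empty stand-ins.

```python
class TFFunnelModel:
    """Hypothetical stand-in for the first candidate class."""


class TFFunnelBaseModel:
    """Hypothetical stand-in for the second candidate class."""


def pick_model_class(candidates, architectures=None):
    """Choose among several classes registered for one configuration."""
    if not isinstance(candidates, tuple):
        return candidates  # only one class registered: nothing to choose
    by_name = {cls.__name__: cls for cls in candidates}
    for arch in architectures or []:
        # Saved architecture names may lack the framework prefix, so try with and without it.
        for name in (arch, f"TF{arch}", f"Flax{arch}"):
            if name in by_name:
                return by_name[name]
    return candidates[0]  # no usable hint in the config: default to the first candidate


FUNNEL_CANDIDATES = (TFFunnelModel, TFFunnelBaseModel)
print(pick_model_class(FUNNEL_CANDIDATES, ["FunnelBaseModel"]).__name__)  # TFFunnelBaseModel
print(pick_model_class(FUNNEL_CANDIDATES).__name__)  # TFFunnelModel
```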
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see docstring of pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume an incompletely received download instead of deleting the partial file and starting over. If set to True, the download will attempt to resume from an existing partial file.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
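With output_loading_info=True, the model is returned together with a dictionary describing how the checkpoint matched the model. The set comparison below is a simplified sketch of what "missing" and "unexpected" keys mean. loading_info and the weight names are hypothetical, this toy performs no shape checks, and the exact dictionary keys returned by the library may differ from the ones used here.

```python
def loading_info(expected_keys, checkpoint_keys):
    """Compare the weight names a model expects with those found in a checkpoint."""
    expected, found = set(expected_keys), set(checkpoint_keys)
    return {
        # Weights the model needs but the checkpoint lacks (left freshly initialized).
        "missing_keys": sorted(expected - found),
        # Weights in the checkpoint that the model has no slot for (ignored).
        "unexpected_keys": sorted(found - expected),
        # This toy performs no shape checks, so it never records errors.
        "error_msgs": [],
    }


info = loading_info(
    ["embeddings.weight", "pooler.dense.weight", "pooler.dense.bias"],
    ["embeddings.weight", "cls.predictions.bias"],
)
print(info["missing_keys"])     # ['pooler.dense.bias', 'pooler.dense.weight']
print(info["unexpected_keys"])  # ['cls.predictions.bias']
```

Non-empty missing keys are common when a base checkpoint is loaded into a model with an extra head. Those layers start from random values and usually need fine-tuning.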
Instantiate one of the base model classes of the library from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert → TFAlbertModel (ALBERT model)
- bart → TFBartModel (BART model)
- bert → TFBertModel (BERT model)
- blenderbot → TFBlenderbotModel (Blenderbot model)
- blenderbot-small → TFBlenderbotSmallModel (BlenderbotSmall model)
- blip → TFBlipModel (BLIP model)
- camembert → TFCamembertModel (CamemBERT model)
- clip → TFCLIPModel (CLIP model)
- convbert → TFConvBertModel (ConvBERT model)
- convnext → TFConvNextModel (ConvNeXT model)
- ctrl → TFCTRLModel (CTRL model)
- cvt → TFCvtModel (CvT model)
- data2vec-vision → TFData2VecVisionModel (Data2VecVision model)
- deberta → TFDebertaModel (DeBERTa model)
- deberta-v2 → TFDebertaV2Model (DeBERTa-v2 model)
- deit → TFDeiTModel (DeiT model)
- distilbert → TFDistilBertModel (DistilBERT model)
- dpr → TFDPRQuestionEncoder (DPR model)
- efficientformer → TFEfficientFormerModel (EfficientFormer model)
- electra → TFElectraModel (ELECTRA model)
- esm → TFEsmModel (ESM model)
- flaubert → TFFlaubertModel (FlauBERT model)
- funnel → TFFunnelModel or TFFunnelBaseModel (Funnel Transformer model)
- gpt-sw3 → TFGPT2Model (GPT-Sw3 model)
- gpt2 → TFGPT2Model (OpenAI GPT-2 model)
- gptj → TFGPTJModel (GPT-J model)
- groupvit → TFGroupViTModel (GroupViT model)
- hubert → TFHubertModel (Hubert model)
- layoutlm → TFLayoutLMModel (LayoutLM model)
- layoutlmv3 → TFLayoutLMv3Model (LayoutLMv3 model)
- led → TFLEDModel (LED model)
- longformer → TFLongformerModel (Longformer model)
- lxmert → TFLxmertModel (LXMERT model)
- marian → TFMarianModel (Marian model)
- mbart → TFMBartModel (mBART model)
- mobilebert → TFMobileBertModel (MobileBERT model)
- mobilevit → TFMobileViTModel (MobileViT model)
- mpnet → TFMPNetModel (MPNet model)
- mt5 → TFMT5Model (MT5 model)
- openai-gpt → TFOpenAIGPTModel (OpenAI GPT model)
- opt → TFOPTModel (OPT model)
- pegasus → TFPegasusModel (Pegasus model)
- regnet → TFRegNetModel (RegNet model)
- rembert → TFRemBertModel (RemBERT model)
- resnet → TFResNetModel (ResNet model)
- roberta → TFRobertaModel (RoBERTa model)
- roberta-prelayernorm → TFRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- roformer → TFRoFormerModel (RoFormer model)
- sam → TFSamModel (SAM model)
- segformer → TFSegformerModel (SegFormer model)
- speech_to_text → TFSpeech2TextModel (Speech2Text model)
- swin → TFSwinModel (Swin Transformer model)
- t5 → TFT5Model (T5 model)
- tapas → TFTapasModel (TAPAS model)
- transfo-xl → TFTransfoXLModel (Transformer-XL model)
- vision-text-dual-encoder → TFVisionTextDualEncoderModel (VisionTextDualEncoder model)
- vit → TFViTModel (ViT model)
- vit_mae → TFViTMAEModel (ViTMAE model)
- wav2vec2 → TFWav2Vec2Model (Wav2Vec2 model)
- whisper → TFWhisperModel (Whisper model)
- xglm → TFXGLMModel (XGLM model)
- xlm → TFXLMModel (XLM model)
- xlm-roberta → TFXLMRobertaModel (XLM-RoBERTa model)
- xlnet → TFXLNetModel (XLNet model)
Examples:
>>> from transformers import AutoConfig, TFAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModel.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModel.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModel.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModel
This is a generic model class that will be instantiated as one of the base model classes of the library when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__()
(throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertModel (ALBERT model)
- BartConfig configuration class: FlaxBartModel (BART model)
- BeitConfig configuration class: FlaxBeitModel (BEiT model)
- BertConfig configuration class: FlaxBertModel (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdModel (BigBird model)
- BlenderbotConfig configuration class: FlaxBlenderbotModel (Blenderbot model)
- BlenderbotSmallConfig configuration class: FlaxBlenderbotSmallModel (BlenderbotSmall model)
- BloomConfig configuration class: FlaxBloomModel (BLOOM model)
- CLIPConfig configuration class: FlaxCLIPModel (CLIP model)
- DistilBertConfig configuration class: FlaxDistilBertModel (DistilBERT model)
- ElectraConfig configuration class: FlaxElectraModel (ELECTRA model)
- GPT2Config configuration class: FlaxGPT2Model (OpenAI GPT-2 model)
- GPTJConfig configuration class: FlaxGPTJModel (GPT-J model)
- GPTNeoConfig configuration class: FlaxGPTNeoModel (GPT Neo model)
- LongT5Config configuration class: FlaxLongT5Model (LongT5 model)
- MBartConfig configuration class: FlaxMBartModel (mBART model)
- MT5Config configuration class: FlaxMT5Model (MT5 model)
- MarianConfig configuration class: FlaxMarianModel (Marian model)
- OPTConfig configuration class: FlaxOPTModel (OPT model)
- PegasusConfig configuration class: FlaxPegasusModel (Pegasus model)
- RegNetConfig configuration class: FlaxRegNetModel (RegNet model)
- ResNetConfig configuration class: FlaxResNetModel (ResNet model)
- RoFormerConfig configuration class: FlaxRoFormerModel (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaModel (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- T5Config configuration class: FlaxT5Model (T5 model)
- ViTConfig configuration class: FlaxViTModel (ViT model)
- VisionTextDualEncoderConfig configuration class: FlaxVisionTextDualEncoderModel (VisionTextDualEncoder model)
- Wav2Vec2Config configuration class: FlaxWav2Vec2Model (Wav2Vec2 model)
- WhisperConfig configuration class: FlaxWhisperModel (Whisper model)
- XGLMConfig configuration class: FlaxXGLMModel (XGLM model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaModel (XLM-RoBERTa model)
Instantiates one of the base model classes of the library from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
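The configuration-class dispatch described above can be pictured as a lookup keyed on the exact type of the config object. The sketch below is a hypothetical, pure-Python illustration: ToyBertConfig, ToyGPT2Config, ToyBertModel and ToyGPT2Model are invented stand-ins rather than transformers classes, and the sketch assumes an exact type match (no subclass lookup). The model is built from the configuration alone, so no pretrained weights are involved.

```python
class ToyBertConfig:
    """Hypothetical stand-in for a configuration class."""
    model_type = "bert"

    def __init__(self, hidden_size=8):
        self.hidden_size = hidden_size


class ToyGPT2Config:
    """Hypothetical stand-in for a second configuration class."""
    model_type = "gpt2"

    def __init__(self, hidden_size=8):
        self.hidden_size = hidden_size


class ToyBertModel:
    def __init__(self, config):
        self.config = config
        # Built from the configuration only: nothing is read from a checkpoint.
        self.pretrained_weights_loaded = False


class ToyGPT2Model(ToyBertModel):
    """Same toy behaviour, distinct class so the dispatch is visible."""


# Keyed on the configuration *class*, like the list above.
CONFIG_TO_MODEL = {ToyBertConfig: ToyBertModel, ToyGPT2Config: ToyGPT2Model}


def from_config(config):
    """Pick the model class from type(config) and instantiate it."""
    try:
        model_class = CONFIG_TO_MODEL[type(config)]
    except KeyError:
        raise ValueError(
            f"Unrecognized configuration class {type(config).__name__}"
        ) from None
    return model_class(config)
```

For example, from_config(ToyBertConfig(hidden_size=16)) returns a freshly initialized ToyBertModel, while passing an unrelated object raises ValueError.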
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch checkpoint to a Flax checkpoint once and loading the Flax checkpoint afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume a download from an incompletely received file, if one exists, instead of discarding it and starting over.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
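The two kwargs cases above can be sketched in a few lines. This is a hypothetical helper, not the library's implementation: a SimpleNamespace stands in for a loaded configuration object, and the function returns the (possibly updated) configuration together with the keyword arguments that would be forwarded to the model's __init__.

```python
from types import SimpleNamespace


def route_kwargs(config, config_was_passed, **kwargs):
    """Split kwargs between a configuration object and the model's __init__.

    Hypothetical helper mirroring the documented behaviour of from_pretrained.
    """
    if config_was_passed:
        # The caller already prepared the config: forward everything untouched.
        return config, dict(kwargs)
    model_kwargs = {}
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)  # override the loaded attribute
        else:
            model_kwargs[key] = value  # left for the model's __init__
    return config, model_kwargs


loaded = SimpleNamespace(output_attentions=False, hidden_size=768)
loaded, extra = route_kwargs(loaded, False, output_attentions=True, my_flag=1)
```

Here output_attentions overrides the loaded attribute, while my_flag (a hypothetical name that matches no configuration attribute) ends up in extra for the model's __init__.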
Instantiate one of the base model classes of the library from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert → FlaxAlbertModel (ALBERT model)
- bart → FlaxBartModel (BART model)
- beit → FlaxBeitModel (BEiT model)
- bert → FlaxBertModel (BERT model)
- big_bird → FlaxBigBirdModel (BigBird model)
- blenderbot → FlaxBlenderbotModel (Blenderbot model)
- blenderbot-small → FlaxBlenderbotSmallModel (BlenderbotSmall model)
- bloom → FlaxBloomModel (BLOOM model)
- clip → FlaxCLIPModel (CLIP model)
- distilbert → FlaxDistilBertModel (DistilBERT model)
- electra → FlaxElectraModel (ELECTRA model)
- gpt-sw3 → FlaxGPT2Model (GPT-Sw3 model)
- gpt2 → FlaxGPT2Model (OpenAI GPT-2 model)
- gpt_neo → FlaxGPTNeoModel (GPT Neo model)
- gptj → FlaxGPTJModel (GPT-J model)
- longt5 → FlaxLongT5Model (LongT5 model)
- marian → FlaxMarianModel (Marian model)
- mbart → FlaxMBartModel (mBART model)
- mt5 → FlaxMT5Model (MT5 model)
- opt → FlaxOPTModel (OPT model)
- pegasus → FlaxPegasusModel (Pegasus model)
- regnet → FlaxRegNetModel (RegNet model)
- resnet → FlaxResNetModel (ResNet model)
- roberta → FlaxRobertaModel (RoBERTa model)
- roberta-prelayernorm → FlaxRobertaPreLayerNormModel (RoBERTa-PreLayerNorm model)
- roformer → FlaxRoFormerModel (RoFormer model)
- t5 → FlaxT5Model (T5 model)
- vision-text-dual-encoder → FlaxVisionTextDualEncoderModel (VisionTextDualEncoder model)
- vit → FlaxViTModel (ViT model)
- wav2vec2 → FlaxWav2Vec2Model (Wav2Vec2 model)
- whisper → FlaxWhisperModel (Whisper model)
- xglm → FlaxXGLMModel (XGLM model)
- xlm-roberta → FlaxXLMRobertaModel (XLM-RoBERTa model)
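The selection rule above (use model_type when the configuration provides it, otherwise pattern-match on the name or path) can be sketched as follows. Only a handful of entries from the list are copied in, and resolve_model_class is a hypothetical helper. The longest-key-first matching order is an assumption, chosen so that roberta-prelayernorm wins over roberta, and roberta wins over bert (which it contains as a substring); check your installed version for the exact rule.

```python
# A few entries copied from the mapping above; the real table is much longer.
MODEL_TYPE_TO_CLASS = {
    "bert": "FlaxBertModel",
    "roberta": "FlaxRobertaModel",
    "roberta-prelayernorm": "FlaxRobertaPreLayerNormModel",
    "xlm-roberta": "FlaxXLMRobertaModel",
    "gpt2": "FlaxGPT2Model",
}


def resolve_model_class(name_or_path, model_type=None):
    """Return a class name from model_type, else by substring matching."""
    if model_type is not None:
        return MODEL_TYPE_TO_CLASS[model_type]
    # Assumed fallback: try longer keys first so "xlm-roberta" is checked
    # before "roberta", and "roberta" before "bert".
    for pattern in sorted(MODEL_TYPE_TO_CLASS, key=len, reverse=True):
        if pattern in str(name_or_path):
            return MODEL_TYPE_TO_CLASS[pattern]
    raise ValueError(f"Cannot infer a model class from {name_or_path!r}")
```

Without the ordering, a path such as distilroberta-base could be matched by bert first, which is why the sketch sorts the keys.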
Examples:
>>> from transformers import AutoConfig, FlaxAutoModel
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModel.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModel.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModel.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
Generic pretraining classes
The following auto classes are available for instantiating a model with a pretraining head.
AutoModelForPreTraining
This is a generic model class that will be instantiated as one of the model classes of the library (with a pretraining head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: AlbertForPreTraining (ALBERT model)
- BartConfig configuration class: BartForConditionalGeneration (BART model)
- BertConfig configuration class: BertForPreTraining (BERT model)
- BigBirdConfig configuration class: BigBirdForPreTraining (BigBird model)
- BloomConfig configuration class: BloomForCausalLM (BLOOM model)
- CTRLConfig configuration class: CTRLLMHeadModel (CTRL model)
- CamembertConfig configuration class: CamembertForMaskedLM (CamemBERT model)
- Data2VecTextConfig configuration class: Data2VecTextForMaskedLM (Data2VecText model)
- DebertaConfig configuration class: DebertaForMaskedLM (DeBERTa model)
- DebertaV2Config configuration class: DebertaV2ForMaskedLM (DeBERTa-v2 model)
- DistilBertConfig configuration class: DistilBertForMaskedLM (DistilBERT model)
- ElectraConfig configuration class: ElectraForPreTraining (ELECTRA model)
- ErnieConfig configuration class: ErnieForPreTraining (ERNIE model)
- FNetConfig configuration class: FNetForPreTraining (FNet model)
- FSMTConfig configuration class: FSMTForConditionalGeneration (FairSeq Machine-Translation model)
- FlaubertConfig configuration class: FlaubertWithLMHeadModel (FlauBERT model)
- FlavaConfig configuration class: FlavaForPreTraining (FLAVA model)
- FunnelConfig configuration class: FunnelForPreTraining (Funnel Transformer model)
- GPT2Config configuration class: GPT2LMHeadModel (OpenAI GPT-2 model)
- GPTBigCodeConfig configuration class: GPTBigCodeForCausalLM (GPTBigCode model)
- GPTSanJapaneseConfig configuration class: GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- IBertConfig configuration class: IBertForMaskedLM (I-BERT model)
- IdeficsConfig configuration class: IdeficsForVisionText2Text (IDEFICS model)
- LayoutLMConfig configuration class: LayoutLMForMaskedLM (LayoutLM model)
- LongformerConfig configuration class: LongformerForMaskedLM (Longformer model)
- LukeConfig configuration class: LukeForMaskedLM (LUKE model)
- LxmertConfig configuration class: LxmertForPreTraining (LXMERT model)
- MPNetConfig configuration class: MPNetForMaskedLM (MPNet model)
- MegaConfig configuration class: MegaForMaskedLM (MEGA model)
- MegatronBertConfig configuration class: MegatronBertForPreTraining (Megatron-BERT model)
- MobileBertConfig configuration class: MobileBertForPreTraining (MobileBERT model)
- MptConfig configuration class: MptForCausalLM (MPT model)
- MraConfig configuration class: MraForMaskedLM (MRA model)
- MvpConfig configuration class: MvpForConditionalGeneration (MVP model)
- NezhaConfig configuration class: NezhaForPreTraining (Nezha model)
- NllbMoeConfig configuration class: NllbMoeForConditionalGeneration (NLLB-MOE model)
- OpenAIGPTConfig configuration class: OpenAIGPTLMHeadModel (OpenAI GPT model)
- RetriBertConfig configuration class: RetriBertModel (RetriBERT model)
- RoCBertConfig configuration class: RoCBertForPreTraining (RoCBert model)
- RobertaConfig configuration class: RobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- RwkvConfig configuration class: RwkvForCausalLM (RWKV model)
- SplinterConfig configuration class: SplinterForPreTraining (Splinter model)
- SqueezeBertConfig configuration class: SqueezeBertForMaskedLM (SqueezeBERT model)
- SwitchTransformersConfig configuration class: SwitchTransformersForConditionalGeneration (SwitchTransformers model)
- T5Config configuration class: T5ForConditionalGeneration (T5 model)
- TapasConfig configuration class: TapasForMaskedLM (TAPAS model)
- TransfoXLConfig configuration class: TransfoXLLMHeadModel (Transformer-XL model)
- TvltConfig configuration class: TvltForPreTraining (TVLT model)
- UniSpeechConfig configuration class: UniSpeechForPreTraining (UniSpeech model)
- UniSpeechSatConfig configuration class: UniSpeechSatForPreTraining (UniSpeechSat model)
- ViTMAEConfig configuration class: ViTMAEForPreTraining (ViTMAE model)
- VideoMAEConfig configuration class: VideoMAEForPreTraining (VideoMAE model)
- VisualBertConfig configuration class: VisualBertForPreTraining (VisualBERT model)
- Wav2Vec2Config configuration class: Wav2Vec2ForPreTraining (Wav2Vec2 model)
- Wav2Vec2ConformerConfig configuration class: Wav2Vec2ConformerForPreTraining (Wav2Vec2-Conformer model)
- XLMConfig configuration class: XLMWithLMHeadModel (XLM model)
- XLMRobertaConfig configuration class: XLMRobertaForMaskedLM (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetLMHeadModel (XLNet model)
- XmodConfig configuration class: XmodForMaskedLM (X-MOD model)
Instantiates one of the model classes of the library (with a pretraining head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint to a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional) — A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() is a simpler option.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False) — Load the model weights from a TensorFlow checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume a download from an incompletely received file, if one exists, instead of discarding it and starting over.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
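With output_loading_info=True, from_pretrained returns the model together with a dictionary describing how the checkpoint matched the model. The hypothetical helper below sketches how missing and unexpected keys relate to two sets of parameter names. It is not the library's code, it does not model error_msgs (shape mismatches and similar problems), and the real dictionary may contain additional entries depending on the version. The parameter names used with it are invented for illustration.

```python
def compare_state_keys(model_keys, checkpoint_keys):
    """Hypothetical sketch of the missing/unexpected key bookkeeping."""
    expected, found = set(model_keys), set(checkpoint_keys)
    return {
        # In the model but not in the checkpoint: these weights stay newly initialized.
        "missing_keys": sorted(expected - found),
        # In the checkpoint but not in the model: these weights are ignored.
        "unexpected_keys": sorted(found - expected),
        # Not modelled here; real loading would report e.g. shape mismatches.
        "error_msgs": [],
    }
```

For instance, loading a checkpoint saved from a base model into a class with a pretraining head would typically list the head's parameters under missing_keys.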
Instantiate one of the model classes of the library (with a pretraining head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it's missing, by falling back to using pattern matching on pretrained_model_name_or_path:
- albert → AlbertForPreTraining (ALBERT model)
- bart → BartForConditionalGeneration (BART model)
- bert → BertForPreTraining (BERT model)
- big_bird → BigBirdForPreTraining (BigBird model)
- bloom → BloomForCausalLM (BLOOM model)
- camembert → CamembertForMaskedLM (CamemBERT model)
- ctrl → CTRLLMHeadModel (CTRL model)
- data2vec-text → Data2VecTextForMaskedLM (Data2VecText model)
- deberta → DebertaForMaskedLM (DeBERTa model)
- deberta-v2 → DebertaV2ForMaskedLM (DeBERTa-v2 model)
- distilbert → DistilBertForMaskedLM (DistilBERT model)
- electra → ElectraForPreTraining (ELECTRA model)
- ernie → ErnieForPreTraining (ERNIE model)
- flaubert → FlaubertWithLMHeadModel (FlauBERT model)
- flava → FlavaForPreTraining (FLAVA model)
- fnet → FNetForPreTraining (FNet model)
- fsmt → FSMTForConditionalGeneration (FairSeq Machine-Translation model)
- funnel → FunnelForPreTraining (Funnel Transformer model)
- gpt-sw3 → GPT2LMHeadModel (GPT-Sw3 model)
- gpt2 → GPT2LMHeadModel (OpenAI GPT-2 model)
- gpt_bigcode → GPTBigCodeForCausalLM (GPTBigCode model)
- gptsan-japanese → GPTSanJapaneseForConditionalGeneration (GPTSAN-japanese model)
- ibert → IBertForMaskedLM (I-BERT model)
- idefics → IdeficsForVisionText2Text (IDEFICS model)
- layoutlm → LayoutLMForMaskedLM (LayoutLM model)
- longformer → LongformerForMaskedLM (Longformer model)
- luke → LukeForMaskedLM (LUKE model)
- lxmert → LxmertForPreTraining (LXMERT model)
- mega → MegaForMaskedLM (MEGA model)
- megatron-bert → MegatronBertForPreTraining (Megatron-BERT model)
- mobilebert → MobileBertForPreTraining (MobileBERT model)
- mpnet → MPNetForMaskedLM (MPNet model)
- mpt → MptForCausalLM (MPT model)
- mra → MraForMaskedLM (MRA model)
- mvp → MvpForConditionalGeneration (MVP model)
- nezha → NezhaForPreTraining (Nezha model)
- nllb-moe → NllbMoeForConditionalGeneration (NLLB-MOE model)
- openai-gpt → OpenAIGPTLMHeadModel (OpenAI GPT model)
- retribert → RetriBertModel (RetriBERT model)
- roberta → RobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm → RobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- roc_bert → RoCBertForPreTraining (RoCBert model)
- rwkv → RwkvForCausalLM (RWKV model)
- splinter → SplinterForPreTraining (Splinter model)
- squeezebert → SqueezeBertForMaskedLM (SqueezeBERT model)
- switch_transformers → SwitchTransformersForConditionalGeneration (SwitchTransformers model)
- t5 → T5ForConditionalGeneration (T5 model)
- tapas → TapasForMaskedLM (TAPAS model)
- transfo-xl → TransfoXLLMHeadModel (Transformer-XL model)
- tvlt → TvltForPreTraining (TVLT model)
- unispeech → UniSpeechForPreTraining (UniSpeech model)
- unispeech-sat → UniSpeechSatForPreTraining (UniSpeechSat model)
- videomae → VideoMAEForPreTraining (VideoMAE model)
- visual_bert → VisualBertForPreTraining (VisualBERT model)
- vit_mae → ViTMAEForPreTraining (ViTMAE model)
- wav2vec2 → Wav2Vec2ForPreTraining (Wav2Vec2 model)
- wav2vec2-conformer → Wav2Vec2ConformerForPreTraining (Wav2Vec2-Conformer model)
- xlm → XLMWithLMHeadModel (XLM model)
- xlm-roberta → XLMRobertaForMaskedLM (XLM-RoBERTa model)
- xlm-roberta-xl → XLMRobertaXLForMaskedLM (XLM-RoBERTa-XL model)
- xlnet → XLNetLMHeadModel (XLNet model)
- xmod → XmodForMaskedLM (X-MOD model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
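Why evaluation mode matters can be seen on dropout alone. The class below is a toy, pure-Python re-implementation of inverted dropout with a training flag, written to mirror the semantics of PyTorch's nn.Dropout together with nn.Module.train() and eval(). It is illustrative, not PyTorch's code, and ToyDropout is a hypothetical name; it also restricts p to [0, 1) to avoid dividing by zero.

```python
import random


class ToyDropout:
    """Inverted dropout with a training flag (illustrative, not PyTorch's code)."""

    def __init__(self, p=0.5, seed=None):
        if not 0.0 <= p < 1.0:
            raise ValueError("p must be in [0, 1)")
        self.p = p
        self.training = True  # modules start in training mode
        self._rng = random.Random(seed)

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def __call__(self, xs):
        if not self.training or self.p == 0.0:
            return list(xs)  # identity at inference time
        # Zero each element with probability p and rescale the survivors so
        # the expected value of every element is unchanged.
        scale = 1.0 / (1.0 - self.p)
        return [0.0 if self._rng.random() < self.p else x * scale for x in xs]
```

After eval() the layer returns its input unchanged, so repeated forward passes give identical outputs; after train() roughly a fraction p of the values is zeroed and the rest are scaled by 1 / (1 - p).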
Examples:
>>> from transformers import AutoConfig, AutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForPreTraining.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForPreTraining.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForPreTraining.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForPreTraining
This is a generic model class that will be instantiated as one of the model classes of the library (with a pretraining head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig) — The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: TFAlbertForPreTraining (ALBERT model)
- BartConfig configuration class: TFBartForConditionalGeneration (BART model)
- BertConfig configuration class: TFBertForPreTraining (BERT model)
- CTRLConfig configuration class: TFCTRLLMHeadModel (CTRL model)
- CamembertConfig configuration class: TFCamembertForMaskedLM (CamemBERT model)
- DistilBertConfig configuration class: TFDistilBertForMaskedLM (DistilBERT model)
- ElectraConfig configuration class: TFElectraForPreTraining (ELECTRA model)
- FlaubertConfig configuration class: TFFlaubertWithLMHeadModel (FlauBERT model)
- FunnelConfig configuration class: TFFunnelForPreTraining (Funnel Transformer model)
- GPT2Config configuration class: TFGPT2LMHeadModel (OpenAI GPT-2 model)
- LayoutLMConfig configuration class: TFLayoutLMForMaskedLM (LayoutLM model)
- LxmertConfig configuration class: TFLxmertForPreTraining (LXMERT model)
- MPNetConfig configuration class: TFMPNetForMaskedLM (MPNet model)
- MobileBertConfig configuration class: TFMobileBertForPreTraining (MobileBERT model)
- OpenAIGPTConfig configuration class: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
- RobertaConfig configuration class: TFRobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- T5Config configuration class: TFT5ForConditionalGeneration (T5 model)
- TapasConfig configuration class: TFTapasForMaskedLM (TAPAS model)
- TransfoXLConfig configuration class: TFTransfoXLLMHeadModel (Transformer-XL model)
- ViTMAEConfig configuration class: TFViTMAEForPreTraining (ViTMAE model)
- XLMConfig configuration class: TFXLMWithLMHeadModel (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForMaskedLM (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetLMHeadModel (XLNet model)
Instantiates one of the model classes of the library (with a pretraining head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike) — Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model to a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional) — Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional) — Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional) — Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False) — Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False) — Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False) — Whether or not to resume a download from an incompletely received file, if one exists, instead of discarding it and starting over.
- proxies (Dict[str, str], optional) — A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False) — Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False) — Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main") — The specific model version to use. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False) — Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main") — The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional) — Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will be first passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a pretraining head) from a pretrained model.
The model class to instantiate is selected based on the model_type
property of the config object (either
passed as an argument or loaded from pretrained_model_name_or_path
if possible), or when itβs missing, by
falling back to using pattern matching on pretrained_model_name_or_path
:
- albert → TFAlbertForPreTraining (ALBERT model)
- bart → TFBartForConditionalGeneration (BART model)
- bert → TFBertForPreTraining (BERT model)
- camembert → TFCamembertForMaskedLM (CamemBERT model)
- ctrl → TFCTRLLMHeadModel (CTRL model)
- distilbert → TFDistilBertForMaskedLM (DistilBERT model)
- electra → TFElectraForPreTraining (ELECTRA model)
- flaubert → TFFlaubertWithLMHeadModel (FlauBERT model)
- funnel → TFFunnelForPreTraining (Funnel Transformer model)
- gpt-sw3 → TFGPT2LMHeadModel (GPT-Sw3 model)
- gpt2 → TFGPT2LMHeadModel (OpenAI GPT-2 model)
- layoutlm → TFLayoutLMForMaskedLM (LayoutLM model)
- lxmert → TFLxmertForPreTraining (LXMERT model)
- mobilebert → TFMobileBertForPreTraining (MobileBERT model)
- mpnet → TFMPNetForMaskedLM (MPNet model)
- openai-gpt → TFOpenAIGPTLMHeadModel (OpenAI GPT model)
- roberta → TFRobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm → TFRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- t5 → TFT5ForConditionalGeneration (T5 model)
- tapas → TFTapasForMaskedLM (TAPAS model)
- transfo-xl → TFTransfoXLLMHeadModel (Transformer-XL model)
- vit_mae → TFViTMAEForPreTraining (ViTMAE model)
- xlm → TFXLMWithLMHeadModel (XLM model)
- xlm-roberta → TFXLMRobertaForMaskedLM (XLM-RoBERTa model)
- xlnet → TFXLNetLMHeadModel (XLNet model)
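The selection rule above (use model_type when available, otherwise pattern-match the name or path) can be sketched as follows. The mapping is a small excerpt of the list above, stored as class-name strings; resolve_class_name is a hypothetical helper. We assume longer patterns are tried first, which is one way to keep a name such as "xlm-roberta-base" from being caught by the shorter "xlm" or "roberta" keys (note that "roberta" itself contains "bert").

```python
# Excerpt of the mapping listed above (model_type -> TF class name).
TF_PRETRAINING_MAPPING = {
    "bert": "TFBertForPreTraining",
    "roberta": "TFRobertaForMaskedLM",
    "roberta-prelayernorm": "TFRobertaPreLayerNormForMaskedLM",
    "xlm": "TFXLMWithLMHeadModel",
    "xlm-roberta": "TFXLMRobertaForMaskedLM",
}


def resolve_class_name(name_or_path, model_type=None, mapping=TF_PRETRAINING_MAPPING):
    """Hypothetical helper: pick a model class name for a checkpoint."""
    # 1) Prefer the config's model_type when it is known.
    if model_type is not None:
        return mapping[model_type]
    # 2) Otherwise fall back to substring matching on the name/path,
    #    trying longer patterns first (an assumption of this sketch).
    for pattern in sorted(mapping, key=len, reverse=True):
        if pattern in name_or_path:
            return mapping[pattern]
    raise ValueError(f"Unrecognized model in {name_or_path!r}")


print(resolve_class_name("xlm-roberta-base"))  # TFXLMRobertaForMaskedLM
```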
Examples:
>>> from transformers import AutoConfig, TFAutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForPreTraining.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForPreTraining.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForPreTraining.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
FlaxAutoModelForPreTraining
This is a generic model class that will be instantiated as one of the model classes of the library (with a pretraining head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig): The model class to instantiate is selected based on the configuration class:
- AlbertConfig configuration class: FlaxAlbertForPreTraining (ALBERT model)
- BartConfig configuration class: FlaxBartForConditionalGeneration (BART model)
- BertConfig configuration class: FlaxBertForPreTraining (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForPreTraining (BigBird model)
- ElectraConfig configuration class: FlaxElectraForPreTraining (ELECTRA model)
- LongT5Config configuration class: FlaxLongT5ForConditionalGeneration (LongT5 model)
- MBartConfig configuration class: FlaxMBartForConditionalGeneration (mBART model)
- MT5Config configuration class: FlaxMT5ForConditionalGeneration (MT5 model)
- RoFormerConfig configuration class: FlaxRoFormerForMaskedLM (RoFormer model)
- RobertaConfig configuration class: FlaxRobertaForMaskedLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- T5Config configuration class: FlaxT5ForConditionalGeneration (T5 model)
- Wav2Vec2Config configuration class: FlaxWav2Vec2ForPreTraining (Wav2Vec2 model)
- WhisperConfig configuration class: FlaxWhisperForConditionalGeneration (Whisper model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForMaskedLM (XLM-RoBERTa model)
Instantiates one of the model classes of the library (with a pretraining head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
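The from_config() behaviour described above (dispatch on the configuration class, then build a model with freshly initialized weights) can be sketched in plain Python. All Toy* classes and this from_config function are hypothetical stand-ins; the sketch assumes an exact-class lookup on type(config), which may differ from how the library treats subclasses.

```python
import random
from dataclasses import dataclass


@dataclass
class ToyBertConfig:
    hidden_size: int = 4


@dataclass
class ToyT5Config:
    d_model: int = 4


class ToyBertForPreTraining:
    def __init__(self, config):
        self.config = config
        # Random initialization: nothing is read from a checkpoint.
        self.weights = [random.gauss(0.0, 0.02) for _ in range(config.hidden_size)]


class ToyT5ForConditionalGeneration:
    def __init__(self, config):
        self.config = config
        self.weights = [random.gauss(0.0, 0.02) for _ in range(config.d_model)]


# Keyed on the configuration class, like the from_config() lists above.
MODEL_MAPPING = {
    ToyBertConfig: ToyBertForPreTraining,
    ToyT5Config: ToyT5ForConditionalGeneration,
}


def from_config(config):
    """Hypothetical dispatcher: build the model class registered for type(config)."""
    try:
        model_cls = MODEL_MAPPING[type(config)]
    except KeyError:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}") from None
    return model_cls(config)


model = from_config(ToyBertConfig())
print(type(model).__name__)  # ToyBertForPreTraining
```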
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike): Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a Flax model using the provided conversion scripts and loading the Flax model afterwards.
- model_args (additional positional arguments, optional): Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional): Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False): Load the model weights from a PyTorch checkpoint save file (see the description of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False): Whether or not to resume the download if an incompletely received file exists, instead of starting the download over.
- proxies (Dict[str, str], optional): A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False): Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main"): The specific model version to use. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False): Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main"): The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional): Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been made already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
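The three kinds of pretrained_model_name_or_path described above can be told apart roughly as in this sketch. classify_source is a hypothetical helper, the returned labels are made up for illustration, URLs are not handled, and the library's real resolution logic is more involved.

```python
import os


def classify_source(name_or_path, from_pt=False, config=None):
    """Hypothetical helper: decide which documented loading path applies."""
    if os.path.isdir(name_or_path):
        # Directory written by save_pretrained().
        return "local directory"
    if os.path.isfile(name_or_path):
        # A single PyTorch checkpoint file needs from_pt=True and an explicit config.
        if not from_pt or config is None:
            raise ValueError("a checkpoint file requires from_pt=True and a config")
        return "PyTorch checkpoint file"
    # Anything else is treated as a model id on huggingface.co.
    return "Hub model id"


print(classify_source("dbmdz/bert-base-german-cased"))  # Hub model id (unless such a local path exists)
```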
Instantiate one of the model classes of the library (with a pretraining head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- albert → FlaxAlbertForPreTraining (ALBERT model)
- bart → FlaxBartForConditionalGeneration (BART model)
- bert → FlaxBertForPreTraining (BERT model)
- big_bird → FlaxBigBirdForPreTraining (BigBird model)
- electra → FlaxElectraForPreTraining (ELECTRA model)
- longt5 → FlaxLongT5ForConditionalGeneration (LongT5 model)
- mbart → FlaxMBartForConditionalGeneration (mBART model)
- mt5 → FlaxMT5ForConditionalGeneration (MT5 model)
- roberta → FlaxRobertaForMaskedLM (RoBERTa model)
- roberta-prelayernorm → FlaxRobertaPreLayerNormForMaskedLM (RoBERTa-PreLayerNorm model)
- roformer → FlaxRoFormerForMaskedLM (RoFormer model)
- t5 → FlaxT5ForConditionalGeneration (T5 model)
- wav2vec2 → FlaxWav2Vec2ForPreTraining (Wav2Vec2 model)
- whisper → FlaxWhisperForConditionalGeneration (Whisper model)
- xlm-roberta → FlaxXLMRobertaForMaskedLM (XLM-RoBERTa model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForPreTraining
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForPreTraining.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForPreTraining.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForPreTraining.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
Natural Language Processing
The following auto classes are available for natural language processing tasks.
AutoModelForCausalLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a causal language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig): The model class to instantiate is selected based on the configuration class:
- BartConfig configuration class: BartForCausalLM (BART model)
- BertConfig configuration class: BertLMHeadModel (BERT model)
- BertGenerationConfig configuration class: BertGenerationDecoder (Bert Generation model)
- BigBirdConfig configuration class: BigBirdForCausalLM (BigBird model)
- BigBirdPegasusConfig configuration class: BigBirdPegasusForCausalLM (BigBird-Pegasus model)
- BioGptConfig configuration class: BioGptForCausalLM (BioGpt model)
- BlenderbotConfig configuration class: BlenderbotForCausalLM (Blenderbot model)
- BlenderbotSmallConfig configuration class: BlenderbotSmallForCausalLM (BlenderbotSmall model)
- BloomConfig configuration class: BloomForCausalLM (BLOOM model)
- CTRLConfig configuration class: CTRLLMHeadModel (CTRL model)
- CamembertConfig configuration class: CamembertForCausalLM (CamemBERT model)
- CodeGenConfig configuration class: CodeGenForCausalLM (CodeGen model)
- CpmAntConfig configuration class: CpmAntForCausalLM (CPM-Ant model)
- Data2VecTextConfig configuration class: Data2VecTextForCausalLM (Data2VecText model)
- ElectraConfig configuration class: ElectraForCausalLM (ELECTRA model)
- ErnieConfig configuration class: ErnieForCausalLM (ERNIE model)
- FalconConfig configuration class: FalconForCausalLM (Falcon model)
- GPT2Config configuration class: GPT2LMHeadModel (OpenAI GPT-2 model)
- GPTBigCodeConfig configuration class: GPTBigCodeForCausalLM (GPTBigCode model)
- GPTJConfig configuration class: GPTJForCausalLM (GPT-J model)
- GPTNeoConfig configuration class: GPTNeoForCausalLM (GPT Neo model)
- GPTNeoXConfig configuration class: GPTNeoXForCausalLM (GPT NeoX model)
- GPTNeoXJapaneseConfig configuration class: GPTNeoXJapaneseForCausalLM (GPT NeoX Japanese model)
- GitConfig configuration class: GitForCausalLM (GIT model)
- LlamaConfig configuration class: LlamaForCausalLM (LLaMA model)
- MBartConfig configuration class: MBartForCausalLM (mBART model)
- MarianConfig configuration class: MarianForCausalLM (Marian model)
- MegaConfig configuration class: MegaForCausalLM (MEGA model)
- MegatronBertConfig configuration class: MegatronBertForCausalLM (Megatron-BERT model)
- MptConfig configuration class: MptForCausalLM (MPT model)
- MusicgenConfig configuration class: MusicgenForCausalLM (MusicGen model)
- MvpConfig configuration class: MvpForCausalLM (MVP model)
- OPTConfig configuration class: OPTForCausalLM (OPT model)
- OpenAIGPTConfig configuration class: OpenAIGPTLMHeadModel (OpenAI GPT model)
- OpenLlamaConfig configuration class: OpenLlamaForCausalLM (OpenLlama model)
- PLBartConfig configuration class: PLBartForCausalLM (PLBart model)
- PegasusConfig configuration class: PegasusForCausalLM (Pegasus model)
- ProphetNetConfig configuration class: ProphetNetForCausalLM (ProphetNet model)
- QDQBertConfig configuration class: QDQBertLMHeadModel (QDQBert model)
- ReformerConfig configuration class: ReformerModelWithLMHead (Reformer model)
- RemBertConfig configuration class: RemBertForCausalLM (RemBERT model)
- RoCBertConfig configuration class: RoCBertForCausalLM (RoCBert model)
- RoFormerConfig configuration class: RoFormerForCausalLM (RoFormer model)
- RobertaConfig configuration class: RobertaForCausalLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: RobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- RwkvConfig configuration class: RwkvForCausalLM (RWKV model)
- Speech2Text2Config configuration class: Speech2Text2ForCausalLM (Speech2Text2 model)
- TrOCRConfig configuration class: TrOCRForCausalLM (TrOCR model)
- TransfoXLConfig configuration class: TransfoXLLMHeadModel (Transformer-XL model)
- XGLMConfig configuration class: XGLMForCausalLM (XGLM model)
- XLMConfig configuration class: XLMWithLMHeadModel (XLM model)
- XLMProphetNetConfig configuration class: XLMProphetNetForCausalLM (XLM-ProphetNet model)
- XLMRobertaConfig configuration class: XLMRobertaForCausalLM (XLM-RoBERTa model)
- XLMRobertaXLConfig configuration class: XLMRobertaXLForCausalLM (XLM-RoBERTa-XL model)
- XLNetConfig configuration class: XLNetLMHeadModel (XLNet model)
- XmodConfig configuration class: XmodForCausalLM (X-MOD model)
Instantiates one of the model classes of the library (with a causal language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike): Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a TensorFlow index checkpoint file (e.g., ./tf_model/model.ckpt.index). In this case, from_tf should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- model_args (additional positional arguments, optional): Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional): Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- state_dict (Dict[str, torch.Tensor], optional): A state dictionary to use instead of a state dictionary loaded from the saved weights file. This option can be used if you want to create a model from a pretrained configuration but load your own weights. In this case, though, you should check whether using save_pretrained() and from_pretrained() is a simpler option.
- cache_dir (str or os.PathLike, optional): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_tf (bool, optional, defaults to False): Load the model weights from a TensorFlow checkpoint save file (see the description of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False): Whether or not to resume the download if an incompletely received file exists, instead of starting the download over.
- proxies (Dict[str, str], optional): A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False): Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main"): The specific model version to use. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False): Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main"): The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional): Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been made already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
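When weights come from a state_dict (or a checkpoint file), the missing and unexpected keys reported with output_loading_info=True boil down to a set comparison between the model's parameter names and the provided weight names. The following is a minimal sketch of that idea; load_state and the dictionary layout are hypothetical, and the library's own dictionary may carry additional entries.

```python
def load_state(expected_keys, state_dict):
    """Hypothetical helper: match provided weights against the model's parameter names."""
    expected = set(expected_keys)
    provided = set(state_dict)
    # Only weights whose names the model knows are actually loaded.
    loaded = {k: state_dict[k] for k in expected & provided}
    loading_info = {
        "missing_keys": sorted(expected - provided),     # parameters left at their initial values
        "unexpected_keys": sorted(provided - expected),  # weights the model has no slot for
        "error_msgs": [],
    }
    return loaded, loading_info


loaded, info = load_state(["embed.weight", "lm_head.weight"], {"embed.weight": 1.0, "pooler.weight": 2.0})
print(info["missing_keys"], info["unexpected_keys"])  # ['lm_head.weight'] ['pooler.weight']
```

A non-empty missing_keys list is a hint that part of the head was randomly initialized and may need fine-tuning.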
Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- bart → BartForCausalLM (BART model)
- bert → BertLMHeadModel (BERT model)
- bert-generation → BertGenerationDecoder (Bert Generation model)
- big_bird → BigBirdForCausalLM (BigBird model)
- bigbird_pegasus → BigBirdPegasusForCausalLM (BigBird-Pegasus model)
- biogpt → BioGptForCausalLM (BioGpt model)
- blenderbot → BlenderbotForCausalLM (Blenderbot model)
- blenderbot-small → BlenderbotSmallForCausalLM (BlenderbotSmall model)
- bloom → BloomForCausalLM (BLOOM model)
- camembert → CamembertForCausalLM (CamemBERT model)
- code_llama → LlamaForCausalLM (CodeLlama model)
- codegen → CodeGenForCausalLM (CodeGen model)
- cpmant → CpmAntForCausalLM (CPM-Ant model)
- ctrl → CTRLLMHeadModel (CTRL model)
- data2vec-text → Data2VecTextForCausalLM (Data2VecText model)
- electra → ElectraForCausalLM (ELECTRA model)
- ernie → ErnieForCausalLM (ERNIE model)
- falcon → FalconForCausalLM (Falcon model)
- git → GitForCausalLM (GIT model)
- gpt-sw3 → GPT2LMHeadModel (GPT-Sw3 model)
- gpt2 → GPT2LMHeadModel (OpenAI GPT-2 model)
- gpt_bigcode → GPTBigCodeForCausalLM (GPTBigCode model)
- gpt_neo → GPTNeoForCausalLM (GPT Neo model)
- gpt_neox → GPTNeoXForCausalLM (GPT NeoX model)
- gpt_neox_japanese → GPTNeoXJapaneseForCausalLM (GPT NeoX Japanese model)
- gptj → GPTJForCausalLM (GPT-J model)
- llama → LlamaForCausalLM (LLaMA model)
- marian → MarianForCausalLM (Marian model)
- mbart → MBartForCausalLM (mBART model)
- mega → MegaForCausalLM (MEGA model)
- megatron-bert → MegatronBertForCausalLM (Megatron-BERT model)
- mpt → MptForCausalLM (MPT model)
- musicgen → MusicgenForCausalLM (MusicGen model)
- mvp → MvpForCausalLM (MVP model)
- open-llama → OpenLlamaForCausalLM (OpenLlama model)
- openai-gpt → OpenAIGPTLMHeadModel (OpenAI GPT model)
- opt → OPTForCausalLM (OPT model)
- pegasus → PegasusForCausalLM (Pegasus model)
- plbart → PLBartForCausalLM (PLBart model)
- prophetnet → ProphetNetForCausalLM (ProphetNet model)
- qdqbert → QDQBertLMHeadModel (QDQBert model)
- reformer → ReformerModelWithLMHead (Reformer model)
- rembert → RemBertForCausalLM (RemBERT model)
- roberta → RobertaForCausalLM (RoBERTa model)
- roberta-prelayernorm → RobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- roc_bert → RoCBertForCausalLM (RoCBert model)
- roformer → RoFormerForCausalLM (RoFormer model)
- rwkv → RwkvForCausalLM (RWKV model)
- speech_to_text_2 → Speech2Text2ForCausalLM (Speech2Text2 model)
- transfo-xl → TransfoXLLMHeadModel (Transformer-XL model)
- trocr → TrOCRForCausalLM (TrOCR model)
- xglm → XGLMForCausalLM (XGLM model)
- xlm → XLMWithLMHeadModel (XLM model)
- xlm-prophetnet → XLMProphetNetForCausalLM (XLM-ProphetNet model)
- xlm-roberta → XLMRobertaForCausalLM (XLM-RoBERTa model)
- xlm-roberta-xl → XLMRobertaXLForCausalLM (XLM-RoBERTa-XL model)
- xlnet → XLNetLMHeadModel (XLNet model)
- xmod → XmodForCausalLM (X-MOD model)
The model is set in evaluation mode by default using model.eval() (so, for instance, dropout modules are deactivated). To train the model, you should first set it back in training mode with model.train().
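To make the evaluation/training distinction concrete, the following pure-Python sketch mimics a dropout layer whose behaviour depends on a training flag, in the spirit of torch.nn.Module.train() and eval(). ToyDropout is a hypothetical stand-in, not a PyTorch class; it uses inverted dropout, rescaling kept values by 1/(1 - p) during training.

```python
import random


class ToyDropout:
    """Hypothetical dropout layer mimicking train/eval semantics (illustrative only)."""

    def __init__(self, p=0.5):
        self.p = p
        self.training = True  # a freshly built module starts in training mode

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        # Same convention as torch.nn.Module.eval(): equivalent to train(False).
        return self.train(False)

    def __call__(self, xs):
        if not self.training:
            return list(xs)  # identity at inference time
        keep = 1.0 - self.p
        # Inverted dropout: zero some entries and rescale the survivors by 1/keep.
        return [x / keep if random.random() < keep else 0.0 for x in xs]


layer = ToyDropout(p=0.5).eval()
print(layer([1.0, 2.0]))  # [1.0, 2.0]
```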
Examples:
>>> from transformers import AutoConfig, AutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = AutoModelForCausalLM.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = AutoModelForCausalLM.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a TF checkpoint file instead of a PyTorch model (slower)
>>> config = AutoConfig.from_pretrained("./tf_model/bert_tf_model_config.json")
>>> model = AutoModelForCausalLM.from_pretrained(
... "./tf_model/bert_tf_checkpoint.ckpt.index", from_tf=True, config=config
... )
TFAutoModelForCausalLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a causal language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__() (throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig): The model class to instantiate is selected based on the configuration class:
- BertConfig configuration class: TFBertLMHeadModel (BERT model)
- CTRLConfig configuration class: TFCTRLLMHeadModel (CTRL model)
- CamembertConfig configuration class: TFCamembertForCausalLM (CamemBERT model)
- GPT2Config configuration class: TFGPT2LMHeadModel (OpenAI GPT-2 model)
- GPTJConfig configuration class: TFGPTJForCausalLM (GPT-J model)
- OPTConfig configuration class: TFOPTForCausalLM (OPT model)
- OpenAIGPTConfig configuration class: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
- RemBertConfig configuration class: TFRemBertForCausalLM (RemBERT model)
- RoFormerConfig configuration class: TFRoFormerForCausalLM (RoFormer model)
- RobertaConfig configuration class: TFRobertaForCausalLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: TFRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- TransfoXLConfig configuration class: TFTransfoXLLMHeadModel (Transformer-XL model)
- XGLMConfig configuration class: TFXGLMForCausalLM (XGLM model)
- XLMConfig configuration class: TFXLMWithLMHeadModel (XLM model)
- XLMRobertaConfig configuration class: TFXLMRobertaForCausalLM (XLM-RoBERTa model)
- XLNetConfig configuration class: TFXLNetLMHeadModel (XLNet model)
Instantiates one of the model classes of the library (with a causal language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model’s configuration. Use from_pretrained() to load the model weights.
from_pretrained
( *model_args **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike): Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or URL to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
- model_args (additional positional arguments, optional): Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional): Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False): Load the model weights from a PyTorch checkpoint save file (see the description of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False): Whether or not to resume the download if an incompletely received file exists, instead of starting the download over.
- proxies (Dict[str, str], optional): A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False): Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main"): The specific model version to use. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False): Whether or not to allow custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and whose code you have read, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main"): The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since a git-based system is used for storing models and other artifacts on huggingface.co, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional): Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be passed directly to the underlying model’s __init__ method (all relevant updates to the configuration are assumed to have been made already).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override that attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model’s __init__ function.
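The proxies dictionary above mixes a per-protocol key ('http') with a per-endpoint key ('http://hostname'). A minimal sketch of how such a dictionary could be resolved for a given request URL follows; select_proxy here is a hypothetical helper, and its most-specific-key-first precedence is modelled on the behaviour commonly used by the requests library, so treat the exact order as an assumption.

```python
from urllib.parse import urlparse


def select_proxy(url, proxies):
    """Hypothetical helper: pick the most specific proxy entry for a URL."""
    parts = urlparse(url)
    if not parts.hostname:
        return proxies.get(parts.scheme, proxies.get("all"))
    candidates = (
        f"{parts.scheme}://{parts.hostname}",  # endpoint-specific, e.g. 'http://hostname'
        parts.scheme,                          # protocol-wide, e.g. 'http'
        f"all://{parts.hostname}",
        "all",
    )
    for key in candidates:
        if key in proxies:
            return proxies[key]
    return None  # no proxy configured for this URL


proxies = {"http": "foo.bar:3128", "http://hostname": "foo.bar:4012"}
print(select_proxy("http://hostname/file", proxies))  # foo.bar:4012
```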
Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type property of the config object (either passed as an argument or loaded from pretrained_model_name_or_path if possible), or, when it’s missing, by falling back to pattern matching on pretrained_model_name_or_path:
- bert → TFBertLMHeadModel (BERT model)
- camembert → TFCamembertForCausalLM (CamemBERT model)
- ctrl → TFCTRLLMHeadModel (CTRL model)
- gpt-sw3 → TFGPT2LMHeadModel (GPT-Sw3 model)
- gpt2 → TFGPT2LMHeadModel (OpenAI GPT-2 model)
- gptj → TFGPTJForCausalLM (GPT-J model)
- openai-gpt → TFOpenAIGPTLMHeadModel (OpenAI GPT model)
- opt → TFOPTForCausalLM (OPT model)
- rembert → TFRemBertForCausalLM (RemBERT model)
- roberta → TFRobertaForCausalLM (RoBERTa model)
- roberta-prelayernorm → TFRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- roformer → TFRoFormerForCausalLM (RoFormer model)
- transfo-xl → TFTransfoXLLMHeadModel (Transformer-XL model)
- xglm → TFXGLMForCausalLM (XGLM model)
- xlm → TFXLMWithLMHeadModel (XLM model)
- xlm-roberta → TFXLMRobertaForCausalLM (XLM-RoBERTa model)
- xlnet → TFXLNetLMHeadModel (XLNet model)
Examples:
>>> from transformers import AutoConfig, TFAutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = TFAutoModelForCausalLM.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = TFAutoModelForCausalLM.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = TFAutoModelForCausalLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )
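The class-selection rule described above (use the config's model_type, else fall back to pattern matching on the name or path) can be sketched as follows. This is a simplified, hypothetical model of the lookup: `MODEL_MAPPING` is a small excerpt of the table above held as a plain dict, and `select_model_class` is an illustrative helper. Trying longer keys first during the fallback is an assumption made here; it matters because "roberta" contains "bert" and "xlm-roberta" contains both "xlm" and "roberta".

```python
from typing import Optional

# Hypothetical excerpt of the model_type -> model class table above.
MODEL_MAPPING = {
    "bert": "TFBertLMHeadModel",
    "gpt2": "TFGPT2LMHeadModel",
    "roberta": "TFRobertaForCausalLM",
    "xlm": "TFXLMWithLMHeadModel",
    "xlm-roberta": "TFXLMRobertaForCausalLM",
}


def select_model_class(model_type: Optional[str], name_or_path: str) -> str:
    """Pick a class name from the config's model_type, else by matching the name or path."""
    if model_type is not None:
        # Preferred path: the configuration declares its model_type explicitly.
        return MODEL_MAPPING[model_type]
    # Fallback: substring match on the name/path, longest key first (assumption),
    # so that "roberta-base" is not captured by the shorter key "bert".
    for key in sorted(MODEL_MAPPING, key=len, reverse=True):
        if key in name_or_path:
            return MODEL_MAPPING[key]
    raise ValueError(f"Unrecognized model in {name_or_path!r}")


print(select_model_class("gpt2", "./my_checkpoint"))  # TFGPT2LMHeadModel
print(select_model_class(None, "xlm-roberta-base"))  # TFXLMRobertaForCausalLM
```

Note that model_type takes precedence: a local directory whose name reveals nothing is still resolved correctly as long as its configuration carries a model_type.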
FlaxAutoModelForCausalLM
This is a generic model class that will be instantiated as one of the model classes of the library (with a causal language modeling head) when created with the from_pretrained() class method or the from_config() class method.
This class cannot be instantiated directly using __init__()
(throws an error).
from_config
( **kwargs )
Parameters
- config (PretrainedConfig): The model class to instantiate is selected based on the configuration class:
- BartConfig configuration class: FlaxBartForCausalLM (BART model)
- BertConfig configuration class: FlaxBertForCausalLM (BERT model)
- BigBirdConfig configuration class: FlaxBigBirdForCausalLM (BigBird model)
- BloomConfig configuration class: FlaxBloomForCausalLM (BLOOM model)
- ElectraConfig configuration class: FlaxElectraForCausalLM (ELECTRA model)
- GPT2Config configuration class: FlaxGPT2LMHeadModel (OpenAI GPT-2 model)
- GPTJConfig configuration class: FlaxGPTJForCausalLM (GPT-J model)
- GPTNeoConfig configuration class: FlaxGPTNeoForCausalLM (GPT Neo model)
- OPTConfig configuration class: FlaxOPTForCausalLM (OPT model)
- RobertaConfig configuration class: FlaxRobertaForCausalLM (RoBERTa model)
- RobertaPreLayerNormConfig configuration class: FlaxRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- XGLMConfig configuration class: FlaxXGLMForCausalLM (XGLM model)
- XLMRobertaConfig configuration class: FlaxXLMRobertaForCausalLM (XLM-RoBERTa model)
Instantiates one of the model classes of the library (with a causal language modeling head) from a configuration.
Note: Loading a model from its configuration file does not load the model weights. It only affects the model's configuration. Use from_pretrained() to load the model weights.
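To make the difference with from_pretrained() concrete, here is a hypothetical, framework-free sketch of configuration-based dispatch. The classes `TinyConfig` and `TinyForCausalLM` and the function `sketch_from_config` are stand-ins invented for illustration, not the library's `BertConfig`, `FlaxBertForCausalLM`, or `from_config`. The point shown is that the model class is chosen from the type of the configuration object and built from that object alone, so its weights are freshly initialized and nothing pretrained is read.

```python
import random


class TinyConfig:
    """Hypothetical stand-in for a configuration class."""

    def __init__(self, hidden_size: int = 4) -> None:
        self.hidden_size = hidden_size


class TinyForCausalLM:
    """Hypothetical stand-in model: weights are randomly initialized from the config."""

    def __init__(self, config: TinyConfig) -> None:
        self.config = config
        self.weights = [random.gauss(0.0, 0.02) for _ in range(config.hidden_size)]


# Hypothetical registry keyed by configuration class, mirroring the list above.
CONFIG_TO_MODEL = {TinyConfig: TinyForCausalLM}


def sketch_from_config(config):
    """Build the registered model class for this config; no weights are downloaded or loaded."""
    model_class = CONFIG_TO_MODEL.get(type(config))
    if model_class is None:
        raise ValueError(f"Unrecognized configuration class {type(config).__name__}")
    return model_class(config)


model = sketch_from_config(TinyConfig(hidden_size=8))
print(type(model).__name__, len(model.weights))  # TinyForCausalLM 8
```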
from_pretrained
( *model_args, **kwargs )
Parameters
- pretrained_model_name_or_path (str or os.PathLike): Can be either:
  - A string, the model id of a pretrained model hosted inside a model repo on huggingface.co. Valid model ids can be located at the root-level, like bert-base-uncased, or namespaced under a user or organization name, like dbmdz/bert-base-german-cased.
  - A path to a directory containing model weights saved using save_pretrained(), e.g., ./my_model_directory/.
  - A path or url to a PyTorch state_dict save file (e.g., ./pt_model/pytorch_model.bin). In this case, from_pt should be set to True and a configuration object should be provided as the config argument. This loading path is slower than converting the PyTorch model into a Flax model once and loading the Flax model afterwards.
- model_args (additional positional arguments, optional): Will be passed along to the underlying model __init__() method.
- config (PretrainedConfig, optional): Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
  - The model is a model provided by the library (loaded with the model id string of a pretrained model).
  - The model was saved using save_pretrained() and is reloaded by supplying the save directory.
  - The model is loaded by supplying a local directory as pretrained_model_name_or_path and a configuration JSON file named config.json is found in the directory.
- cache_dir (str or os.PathLike, optional): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.
- from_pt (bool, optional, defaults to False): Load the model weights from a PyTorch checkpoint save file (see the docstring of the pretrained_model_name_or_path argument).
- force_download (bool, optional, defaults to False): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist.
- resume_download (bool, optional, defaults to False): Whether or not to resume an interrupted download. If False, incompletely received files are deleted; if True, the download resumes from such a file when one exists.
- proxies (Dict[str, str], optional): A dictionary of proxy servers to use by protocol or endpoint, e.g., {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
- output_loading_info (bool, optional, defaults to False): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
- local_files_only (bool, optional, defaults to False): Whether or not to only look at local files (i.e., do not try to download the model).
- revision (str, optional, defaults to "main"): The specific model version to use. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, revision can be any identifier allowed by git.
- trust_remote_code (bool, optional, defaults to False): Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
- code_revision (str, optional, defaults to "main"): The specific revision to use for the code on the Hub, if the code lives in a different repository than the rest of the model. It can be a branch name, a tag name, or a commit id; since huggingface.co uses a git-based system for storing models and other artifacts, code_revision can be any identifier allowed by git.
- kwargs (additional keyword arguments, optional): Can be used to update the configuration object (after it has been loaded) and initialize the model (e.g., output_attentions=True). Behaves differently depending on whether a config is provided or automatically loaded:
  - If a configuration is provided with config, **kwargs will be directly passed to the underlying model's __init__ method (we assume all relevant updates to the configuration have already been done).
  - If a configuration is not provided, kwargs will first be passed to the configuration class initialization function (from_pretrained()). Each key of kwargs that corresponds to a configuration attribute will be used to override said attribute with the supplied kwargs value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's __init__ function.
Instantiate one of the model classes of the library (with a causal language modeling head) from a pretrained model.
The model class to instantiate is selected based on the model_type
property of the config object (either
passed as an argument or loaded from pretrained_model_name_or_path
if possible), or when it's missing, by
falling back to using pattern matching on pretrained_model_name_or_path
:
- bart: FlaxBartForCausalLM (BART model)
- bert: FlaxBertForCausalLM (BERT model)
- big_bird: FlaxBigBirdForCausalLM (BigBird model)
- bloom: FlaxBloomForCausalLM (BLOOM model)
- electra: FlaxElectraForCausalLM (ELECTRA model)
- gpt-sw3: FlaxGPT2LMHeadModel (GPT-Sw3 model)
- gpt2: FlaxGPT2LMHeadModel (OpenAI GPT-2 model)
- gpt_neo: FlaxGPTNeoForCausalLM (GPT Neo model)
- gptj: FlaxGPTJForCausalLM (GPT-J model)
- opt: FlaxOPTForCausalLM (OPT model)
- roberta: FlaxRobertaForCausalLM (RoBERTa model)
- roberta-prelayernorm: FlaxRobertaPreLayerNormForCausalLM (RoBERTa-PreLayerNorm model)
- xglm: FlaxXGLMForCausalLM (XGLM model)
- xlm-roberta: FlaxXLMRobertaForCausalLM (XLM-RoBERTa model)
Examples:
>>> from transformers import AutoConfig, FlaxAutoModelForCausalLM
>>> # Download model and configuration from huggingface.co and cache.
>>> model = FlaxAutoModelForCausalLM.from_pretrained("bert-base-cased")
>>> # Update configuration during loading
>>> model = FlaxAutoModelForCausalLM.from_pretrained("bert-base-cased", output_attentions=True)
>>> model.config.output_attentions
True
>>> # Loading from a PyTorch checkpoint file instead of a Flax model (slower)
>>> config = AutoConfig.from_pretrained("./pt_model/bert_pt_model_config.json")
>>> model = FlaxAutoModelForCausalLM.from_pretrained(
... "./pt_model/bert_pytorch_model.bin", from_pt=True, config=config
... )