TimmWrapper
Overview
Helper class that enables timm models to be loaded and used with the transformers library and its auto classes.
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor
>>> # Load image
>>> image = Image.open(urlopen(
...     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))
>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
>>> # Preprocess image
>>> inputs = image_processor(image)
>>> # Forward pass
>>> with torch.no_grad():
...     logits = model(**inputs).logits
>>> # Get top 5 predictions
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
Resources
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with TimmWrapper.
For a more detailed overview, please read the official blog post on the timm integration.
TimmWrapperConfig
class transformers.TimmWrapperConfig
( architecture: str = 'resnet50', initializer_range: float = 0.02, do_pooling: bool = True, model_args: typing.Optional[dict[str, typing.Any]] = None, **kwargs )
Parameters
-  architecture (str, optional, defaults to "resnet50") — The timm architecture to load.
-  initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
-  do_pooling (bool, optional, defaults to True) — Whether to do pooling for the last_hidden_state in TimmWrapperModel or not.
-  model_args (dict[str, Any], optional) — Additional keyword arguments to pass to the timm.create_model function, e.g. model_args={"depth": 3} for timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k to create a model with 3 blocks. Defaults to None.
This is the configuration class to store the configuration for a timm backbone TimmWrapper.
It is used to instantiate a timm model according to the specified arguments, defining the model.
Configuration objects inherit from PreTrainedConfig and can be used to control the model outputs. Read the documentation from PreTrainedConfig for more information.
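The config loads ImageNet label descriptions and stores them in the id2label attribute. For default ImageNet models, the label2id attribute is set to None because some label descriptions collide (the same name is used for more than one class), so a name-to-index mapping would be ambiguous.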
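To see why a name-to-index mapping is unsafe when descriptions repeat, the following pure-Python sketch inverts a small id2label dictionary. The class names and indices are an illustrative excerpt chosen for this example (ImageNet-1k does use the name "crane" for two different classes); the snippet does not call transformers and is not how the config itself builds its labels.

```python
# Illustrative excerpt of an id2label mapping (assumed for this example only).
id2label = {
    134: "crane",  # the bird
    517: "crane",  # the construction machine
    1: "goldfish",
}

# Naively inverting id2label silently keeps only the last index for each name.
naive_label2id = {name: idx for idx, name in id2label.items()}

# Collect every index per name to expose the collisions.
indices_by_name = {}
for idx, name in id2label.items():
    indices_by_name.setdefault(name, []).append(idx)
duplicates = {name: idxs for name, idxs in indices_by_name.items() if len(idxs) > 1}

print(naive_label2id)  # {'crane': 517, 'goldfish': 1}
print(duplicates)      # {'crane': [134, 517]}
```

Inverting the dictionary drops one of the two "crane" entries without any warning, which is the ambiguity that leaving label2id unset avoids.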
TimmWrapperImageProcessor
class transformers.TimmWrapperImageProcessor
( pretrained_cfg: dict, architecture: typing.Optional[str] = None, **kwargs )
Wrapper class that exposes a timm model's preprocessing transforms, resolved from its pretrained_cfg, as a transformers image processor.
preprocess
( images: typing.Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, list[PIL.Image.Image], list[numpy.ndarray], list[torch.Tensor]], return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = 'pt' )
Preprocess an image or batch of images.
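The exact transforms depend on the checkpoint's pretrained_cfg, which the processor resolves into a timm evaluation pipeline. As a rough sketch, assuming timm's default center-crop evaluation mode, preprocessing resizes the shorter image side to floor(img_size / crop_pct), center-crops an img_size × img_size window, and normalizes each channel with a mean and standard deviation. The helper names, the 640×480 input, img_size=224, crop_pct=0.875, and the ImageNet mean/std values below are assumptions for illustration, not values read from any particular checkpoint.

```python
import math

# Assumed normalization constants (timm's default ImageNet mean/std); a real
# checkpoint's pretrained_cfg may specify different values.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def eval_resize_size(width, height, img_size, crop_pct):
    """(width, height) after resizing the shorter side to floor(img_size / crop_pct)."""
    short_side = math.floor(img_size / crop_pct)
    if width <= height:
        return short_side, int(short_side * height / width)
    return int(short_side * width / height), short_side


def center_crop_box(width, height, crop):
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """Scale 0-255 channel values to [0, 1], then standardize each channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


# Hypothetical 640x480 input with assumed img_size=224 and crop_pct=0.875.
w, h = eval_resize_size(640, 480, img_size=224, crop_pct=0.875)
print((w, h))                      # (341, 256)
print(center_crop_box(w, h, 224))  # (58, 16, 282, 240)
print(normalize_pixel((255, 255, 255)))
```

The sketch only tracks the geometry and a single pixel; the real processor also resamples the image with the interpolation configured for the checkpoint.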
TimmWrapperModel
Wrapper class for timm models to be used in transformers.
forward
( pixel_values: FloatTensor, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Union[bool, list[int], NoneType] = None, return_dict: typing.Optional[bool] = None, do_pooling: typing.Optional[bool] = None, **kwargs ) → transformers.models.timm_wrapper.modeling_timm_wrapper.TimmWrapperModelOutput or tuple(torch.FloatTensor)
Parameters
-  pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size)) — The tensors corresponding to the input images. Pixel values can be obtained using TimmWrapperImageProcessor. See TimmWrapperImageProcessor.__call__() for details (processor_class uses TimmWrapperImageProcessor for processing images).
-  output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. Not supported for timm wrapped models; requesting attentions raises an error.
-  output_hidden_states (bool or list[int], optional) — Whether or not to return the hidden states of all layers; pass a list of layer indices to return only those layers. Only supported when the underlying timm model implements forward_intermediates.
-  return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
-  do_pooling (bool, optional) — Whether to do pooling for the last_hidden_state in TimmWrapperModel or not. If None is passed, the do_pooling value from the config is used.
Returns
transformers.models.timm_wrapper.modeling_timm_wrapper.TimmWrapperModelOutput or tuple(torch.FloatTensor)
A transformers.models.timm_wrapper.modeling_timm_wrapper.TimmWrapperModelOutput or a tuple of
torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various
elements depending on the configuration (TimmWrapperConfig) and inputs.
- last_hidden_state (torch.FloatTensor) — The last hidden state of the model, output before applying the classification head.
- pooler_output (torch.FloatTensor, optional) — The pooled output derived from the last hidden state, if applicable.
- hidden_states (tuple(torch.FloatTensor), optional, returned if output_hidden_states=True is set or if config.output_hidden_states=True) — A tuple containing the intermediate hidden states of the model at the output of each layer or specified layers.
- attentions (tuple(torch.FloatTensor), optional, returned if output_attentions=True is set or if config.output_attentions=True) — A tuple containing the intermediate attention weights of the model at the output of each layer. Note: currently, timm models do not support attention outputs.
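How the pooled output is produced depends on the wrapped architecture's own head. For a CNN such as ResNet it is typically global average pooling over the spatial dimensions, while ViT-style models may instead pool a class token. The following pure-Python sketch assumes global average pooling on a single image's [channels][height][width] feature map and mirrors the documented do_pooling fallback to the config value; the helper names are hypothetical and this is not the wrapper's implementation.

```python
def resolve_do_pooling(do_pooling, config_do_pooling):
    """Documented rule: None falls back to the config's do_pooling value."""
    return config_do_pooling if do_pooling is None else do_pooling


def global_average_pool(feature_map):
    """Average a [channels][height][width] nested list over its spatial dims."""
    pooled = []
    for channel in feature_map:
        values = [v for row in channel for v in row]
        pooled.append(sum(values) / len(values))
    return pooled


def pooler_output(last_hidden_state, do_pooling=None, config_do_pooling=True):
    # Hypothetical helper: global average pooling stands in for the model's own head.
    if not resolve_do_pooling(do_pooling, config_do_pooling):
        return None
    return global_average_pool(last_hidden_state)


# One image, 2 channels, 2x2 spatial map (batch dimension omitted for brevity).
feature_map = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[0.0, 0.0], [0.0, 8.0]],
]
print(pooler_output(feature_map))                    # [2.5, 2.0]
print(pooler_output(feature_map, do_pooling=False))  # None
```

In this sketch, turning pooling off leaves only the unpooled feature map, which corresponds to last_hidden_state in the real outputs.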
The TimmWrapperModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Examples:
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModel, AutoImageProcessor
>>> # Load image
>>> image = Image.open(urlopen(
...     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))
>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModel.from_pretrained(checkpoint).eval()
>>> # Preprocess image
>>> inputs = image_processor(image)
>>> # Forward pass
>>> with torch.no_grad():
...     outputs = model(**inputs)
>>> # Get pooled output
>>> pooled_output = outputs.pooler_output
>>> # Get last hidden state
>>> last_hidden_state = outputs.last_hidden_state
TimmWrapperForImageClassification
Wrapper class for timm models to be used in transformers for image classification.
forward
( pixel_values: FloatTensor, labels: typing.Optional[torch.LongTensor] = None, output_attentions: typing.Optional[bool] = None, output_hidden_states: typing.Union[bool, list[int], NoneType] = None, return_dict: typing.Optional[bool] = None, **kwargs ) → transformers.modeling_outputs.ImageClassifierOutput or tuple(torch.FloatTensor)
Parameters
-  pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size)) — The tensors corresponding to the input images. Pixel values can be obtained using TimmWrapperImageProcessor. See TimmWrapperImageProcessor.__call__() for details (processor_class uses TimmWrapperImageProcessor for processing images).
-  labels (torch.LongTensor of shape (batch_size,), optional) — Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1, a regression loss is computed (mean-squared error); if config.num_labels > 1, a classification loss is computed (cross-entropy).
-  output_attentions (bool, optional) — Whether or not to return the attention tensors of all attention layers. Not supported for timm wrapped models; requesting attentions raises an error.
-  output_hidden_states (bool or list[int], optional) — Whether or not to return the hidden states of all layers; pass a list of layer indices to return only those layers. Only supported when the underlying timm model implements forward_intermediates.
-  return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
-  **kwargs — Additional keyword arguments passed along to the timm model forward.
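The rule described for labels can be sketched in pure Python: a single output (config.num_labels == 1) is scored with mean-squared error, and multiple outputs with cross-entropy over the softmax of the logits. This is a simplified illustration with a hypothetical helper name; it assumes plain Python lists instead of tensors and ignores other settings (such as config.problem_type) that the library may also take into account.

```python
import math


def classification_or_regression_loss(logits, labels, num_labels):
    """Hypothetical, simplified sketch of the loss rule described for labels.

    logits: list of per-example score lists (each of length num_labels)
    labels: list of floats (regression) or ints (class indices)
    """
    if num_labels == 1:
        # Regression: mean squared error between the single score and the target.
        return sum((row[0] - y) ** 2 for row, y in zip(logits, labels)) / len(labels)
    # Classification: mean cross-entropy, i.e. -log softmax(logits)[label].
    total = 0.0
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[y]
    return total / len(labels)


print(classification_or_regression_loss([[2.0], [0.5]], [1.0, 0.5], num_labels=1))  # 0.5
print(round(classification_or_regression_loss([[0.0, 0.0]], [1], num_labels=2), 4))  # 0.6931
```

With two equal logits the cross-entropy is log(2), the loss of a uniform guess over two classes.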
Returns
transformers.modeling_outputs.ImageClassifierOutput or tuple(torch.FloatTensor)
A transformers.modeling_outputs.ImageClassifierOutput or a tuple of
torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various
elements depending on the configuration (TimmWrapperConfig) and inputs.
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Classification (or regression if config.num_labels==1) loss.
- logits (torch.FloatTensor of shape (batch_size, config.num_labels)) — Classification (or regression if config.num_labels==1) scores (before SoftMax).
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each stage) of shape (batch_size, sequence_length, hidden_size). Hidden-states (also called feature maps) of the model at the output of each stage.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, patch_size, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The TimmWrapperForImageClassification forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Examples:
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor
>>> # Load image
>>> image = Image.open(urlopen(
...     'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))
>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
>>> # Preprocess image
>>> inputs = image_processor(image)
>>> # Forward pass
>>> with torch.no_grad():
...     logits = model(**inputs).logits
>>> # Get top 5 predictions
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
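The last line of the example converts logits to percentages with a softmax and keeps the five largest. The same post-processing, plus a lookup of class names, is sketched below in pure Python for a single image. The three-class label map, the logits, and the helper name are hypothetical; with the real model, the indices returned by torch.topk would be looked up in model.config.id2label instead.

```python
import math


def top_k_percent(logits, k, id2label):
    """Softmax one row of logits, scale to percent, and return the k largest."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [100.0 * e / total for e in exps]
    # Indices sorted by probability, highest first (the role torch.topk plays above).
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(id2label[i], probs[i]) for i in top]


# Hypothetical label map and logits for one image (not from a real checkpoint).
id2label = {0: "bagel", 1: "pretzel", 2: "croissant"}
for label, pct in top_k_percent([2.0, 1.0, 0.1], k=2, id2label=id2label):
    print(f"{label}: {pct:.1f}%")
```

Multiplying by 100 only rescales the probabilities, so it does not change which classes are selected.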