You are viewing v4.37.2 version. A newer version v4.46.3 is available.
ResNet

Overview

The ResNet model was proposed in Deep Residual Learning for Image Recognition by Kaiming He, Xiangyu Zhang, Shaoqing Ren and Jian Sun. Our implementation follows the small changes made by Nvidia: we apply the stride=2 for downsampling in the bottleneck’s 3x3 conv and not in the first 1x1. This is generally known as “ResNet v1.5”.
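
The stride placement can be made concrete by tracking the spatial size through one downsampling bottleneck block. The sketch below is plain Python, not the library's code: the helper names `conv_out` and `bottleneck_sizes` are made up for illustration, and the padding values (0 for the 1x1 convs, 1 for the 3x3 conv) are the usual ResNet choices, assumed here.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def bottleneck_sizes(size, stride_in_first_1x1):
    """Spatial size after the first 1x1, the 3x3 and the last 1x1 conv."""
    # Exactly one of the two convs carries the stride of 2.
    first_stride, middle_stride = (2, 1) if stride_in_first_1x1 else (1, 2)
    after_first = conv_out(size, kernel=1, stride=first_stride, padding=0)
    after_3x3 = conv_out(after_first, kernel=3, stride=middle_stride, padding=1)
    after_last = conv_out(after_3x3, kernel=1, stride=1, padding=0)
    return [after_first, after_3x3, after_last]


original = bottleneck_sizes(56, stride_in_first_1x1=True)  # stride in the first 1x1
v1_5 = bottleneck_sizes(56, stride_in_first_1x1=False)     # stride in the 3x3
print(original)  # [28, 28, 28]
print(v1_5)      # [56, 28, 28]
```

Both layouts end at the same resolution. In the v1.5 layout the 3x3 conv still receives the full-resolution feature map, whereas a strided 1x1 conv reads only every other row and column of its input.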

ResNet introduced residual connections, which make it possible to train networks with a previously unseen number of layers (up to 1000). ResNet won the 2015 ILSVRC & COCO competitions, an important milestone in deep computer vision.
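
A residual connection can be sketched in a few lines of plain Python, using lists in place of tensors. The function names are hypothetical and this is only an illustration of the idea: each block outputs F(x) + x, so a block whose branch F outputs zeros is the identity, and stacking many such blocks leaves the signal unchanged.

```python
def residual_block(x, residual_fn):
    """Return residual_fn(x) + x element-wise: the branch learns only the residual."""
    fx = residual_fn(x)
    return [f + xi for f, xi in zip(fx, x)]


def zero_branch(x):
    # A branch whose weights are all zero outputs zeros for any input.
    return [0.0 for _ in x]


def stack(x, residual_fn, depth):
    """Apply `depth` residual blocks that share the same branch function."""
    for _ in range(depth):
        x = residual_block(x, residual_fn)
    return x


signal = [1.0, -2.0, 0.5]
out = stack(signal, zero_branch, depth=1000)
print(out == signal)  # True: 1000 blocks with a zero branch act as the identity
```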

The abstract from the paper is the following:

Deeper neural networks are more difficult to train. We present a residual learning framework to ease the training of networks that are substantially deeper than those used previously. We explicitly reformulate the layers as learning residual functions with reference to the layer inputs, instead of learning unreferenced functions. We provide comprehensive empirical evidence showing that these residual networks are easier to optimize, and can gain accuracy from considerably increased depth. On the ImageNet dataset we evaluate residual nets with a depth of up to 152 layers---8x deeper than VGG nets but still having lower complexity. An ensemble of these residual nets achieves 3.57% error on the ImageNet test set. This result won the 1st place on the ILSVRC 2015 classification task. We also present analysis on CIFAR-10 with 100 and 1000 layers. The depth of representations is of central importance for many visual recognition tasks. Solely due to our extremely deep representations, we obtain a 28% relative improvement on the COCO object detection dataset. Deep residual nets are foundations of our submissions to ILSVRC & COCO 2015 competitions, where we also won the 1st places on the tasks of ImageNet detection, ImageNet localization, COCO detection, and COCO segmentation.

The figure below illustrates the architecture of ResNet. Taken from the original paper.

This model was contributed by Francesco. The TensorFlow version of this model was added by amyeroberts. The original code can be found here.

Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with ResNet.

Image Classification

If you’re interested in submitting a resource to be included here, please feel free to open a Pull Request and we’ll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.

ResNetConfig

class transformers.ResNetConfig

( num_channels = 3 embedding_size = 64 hidden_sizes = [256, 512, 1024, 2048] depths = [3, 4, 6, 3] layer_type = 'bottleneck' hidden_act = 'relu' downsample_in_first_stage = False downsample_in_bottleneck = False out_features = None out_indices = None **kwargs )

Parameters

  • num_channels (int, optional, defaults to 3) — The number of input channels.
  • embedding_size (int, optional, defaults to 64) — Dimensionality (hidden size) for the embedding layer.
  • hidden_sizes (List[int], optional, defaults to [256, 512, 1024, 2048]) — Dimensionality (hidden size) at each stage.
  • depths (List[int], optional, defaults to [3, 4, 6, 3]) — Depth (number of layers) for each stage.
  • layer_type (str, optional, defaults to "bottleneck") — The layer to use, it can be either "basic" (used for smaller models, like resnet-18 or resnet-34) or "bottleneck" (used for larger models like resnet-50 and above).
  • hidden_act (str, optional, defaults to "relu") — The non-linear activation function in each block. If string, "gelu", "relu", "selu" and "gelu_new" are supported.
  • downsample_in_first_stage (bool, optional, defaults to False) — If True, the first stage will downsample the inputs using a stride of 2.
  • downsample_in_bottleneck (bool, optional, defaults to False) — If True, the first conv 1x1 in ResNetBottleNeckLayer will downsample the inputs using a stride of 2.
  • out_features (List[str], optional) — If used as backbone, list of features to output. Can be any of "stem", "stage1", "stage2", etc. (depending on how many stages the model has). If unset and out_indices is set, will default to the corresponding stages. If unset and out_indices is unset, will default to the last stage. Must be in the same order as defined in the stage_names attribute.
  • out_indices (List[int], optional) — If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how many stages the model has). If unset and out_features is set, will default to the corresponding stages. If unset and out_features is unset, will default to the last stage. Must be in the same order as defined in the stage_names attribute.
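
The defaulting rules for out_features and out_indices can be sketched in plain Python. This illustrates the rules described above and is not the library's implementation: the function name `resolve_backbone_outputs` is hypothetical, and the stage names are assumed to be "stem" followed by "stage1", "stage2", and so on, one per entry in depths.

```python
def resolve_backbone_outputs(depths, out_features=None, out_indices=None):
    """Fill in whichever of out_features / out_indices was left unset."""
    # Assumed naming: the stem first, then one name per stage.
    stage_names = ["stem"] + [f"stage{i}" for i in range(1, len(depths) + 1)]
    if out_features is None and out_indices is None:
        # Neither is set: default to the last stage.
        out_indices = [len(stage_names) - 1]
        out_features = [stage_names[-1]]
    elif out_features is None:
        # Only indices given: look up the corresponding stage names.
        out_features = [stage_names[i] for i in out_indices]
    elif out_indices is None:
        # Only names given: look up the corresponding indices.
        out_indices = [stage_names.index(name) for name in out_features]
    return out_features, out_indices


default = resolve_backbone_outputs([3, 4, 6, 3])
from_indices = resolve_backbone_outputs([3, 4, 6, 3], out_indices=[1, 3])
from_names = resolve_backbone_outputs([3, 4, 6, 3], out_features=["stem", "stage2"])
print(default)       # (['stage4'], [4])
print(from_indices)  # (['stage1', 'stage3'], [1, 3])
print(from_names)    # (['stem', 'stage2'], [0, 2])
```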

This is the configuration class to store the configuration of a ResNetModel. It is used to instantiate a ResNet model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the ResNet microsoft/resnet-50 architecture.
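
As a quick check on how depths and layer_type relate to the familiar model names, the sketch below counts weighted layers. It assumes the usual naming convention (one stem convolution, two convolutions per "basic" block or three per "bottleneck" block, and one final classifier layer, with shortcut projections not counted); the function name is hypothetical.

```python
def weighted_layer_count(depths, layer_type):
    """Count layers the way names like resnet-50 are usually derived."""
    convs_per_block = {"basic": 2, "bottleneck": 3}[layer_type]
    # 1 stem convolution + convolutions in every block + 1 classifier layer
    return 1 + convs_per_block * sum(depths) + 1


print(weighted_layer_count([3, 4, 6, 3], "bottleneck"))  # 50 (the defaults)
print(weighted_layer_count([2, 2, 2, 2], "basic"))       # 18
print(weighted_layer_count([3, 4, 6, 3], "basic"))       # 34
```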

Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.

Example:

>>> from transformers import ResNetConfig, ResNetModel

>>> # Initializing a ResNet resnet-50 style configuration
>>> configuration = ResNetConfig()

>>> # Initializing a model (with random weights) from the resnet-50 style configuration
>>> model = ResNetModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
PyTorch

ResNetModel

class transformers.ResNetModel

( config )

Parameters

  • config (ResNetConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The bare ResNet model outputting raw features without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward

( pixel_values: Tensor output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_outputs.BaseModelOutputWithPoolingAndNoAttention or tuple(torch.FloatTensor)

Parameters

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ConvNextImageProcessor.__call__() for details.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.

Returns

transformers.modeling_outputs.BaseModelOutputWithPoolingAndNoAttention or tuple(torch.FloatTensor)

A transformers.modeling_outputs.BaseModelOutputWithPoolingAndNoAttention or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ResNetConfig) and inputs.

  • last_hidden_state (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Sequence of hidden-states at the output of the last layer of the model.

  • pooler_output (torch.FloatTensor of shape (batch_size, hidden_size)) — Last layer hidden-state after a pooling operation on the spatial dimensions.

  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, num_channels, height, width).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

The ResNetModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoImageProcessor, ResNetModel
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
>>> model = ResNetModel.from_pretrained("microsoft/resnet-50")

>>> inputs = image_processor(image, return_tensors="pt")

>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 2048, 7, 7]
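
The 7 x 7 spatial size in the shape above can be reproduced by hand. The sketch below is plain Python with hypothetical helper names. It assumes the standard ResNet stem (a 7x7 stride-2 convolution followed by a 3x3 stride-2 max pool) and that every stage after the first halves the resolution; the first stage downsamples only when downsample_in_first_stage is True.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution or pooling op along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_spatial_size(size, num_stages=4, downsample_in_first_stage=False):
    """Return the spatial size after the stem and after each stage."""
    sizes = {}
    size = conv_out(size, kernel=7, stride=2, padding=3)  # stem conv (assumed)
    size = conv_out(size, kernel=3, stride=2, padding=1)  # stem max pool (assumed)
    sizes["stem"] = size
    for i in range(1, num_stages + 1):
        downsample = downsample_in_first_stage if i == 1 else True
        if downsample:
            # One stride-2 3x3 convolution per downsampling stage.
            size = conv_out(size, kernel=3, stride=2, padding=1)
        sizes[f"stage{i}"] = size
    return sizes


sizes = trace_spatial_size(224)
print(sizes)  # {'stem': 56, 'stage1': 56, 'stage2': 28, 'stage3': 14, 'stage4': 7}
```

With the default hidden_sizes, the last stage has 2048 channels, which gives the [1, 2048, 7, 7] shape for a single 224 x 224 image.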

ResNetForImageClassification

class transformers.ResNetForImageClassification

( config )

Parameters

  • config (ResNetConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for ImageNet.

This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.

forward

( pixel_values: Optional = None labels: Optional = None output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_outputs.ImageClassifierOutputWithNoAttention or tuple(torch.FloatTensor)

Parameters

  • pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ConvNextImageProcessor.__call__() for details.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • labels (torch.LongTensor of shape (batch_size,), optional) — Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels > 1 a classification loss is computed (Cross-Entropy).
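
For reference, the cross-entropy loss mentioned for config.num_labels > 1 can be written out in plain Python. This is a minimal sketch of the formula, not the code the model runs: the function names are hypothetical, and averaging over the batch is an assumption that matches the default reduction of PyTorch's CrossEntropyLoss.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label])."""
    max_logit = max(logits)  # subtract the max for numerical stability
    log_sum_exp = max_logit + math.log(sum(math.exp(v - max_logit) for v in logits))
    return log_sum_exp - logits[label]


def batch_loss(batch_logits, labels):
    """Mean cross-entropy over a batch (assumed reduction)."""
    losses = [cross_entropy(row, y) for row, y in zip(batch_logits, labels)]
    return sum(losses) / len(losses)


# Two classes with equal logits: each class has probability 0.5.
uniform_loss = batch_loss([[0.0, 0.0], [0.0, 0.0]], labels=[0, 1])
print(math.isclose(uniform_loss, math.log(2)))  # True
```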

Returns

transformers.modeling_outputs.ImageClassifierOutputWithNoAttention or tuple(torch.FloatTensor)

A transformers.modeling_outputs.ImageClassifierOutputWithNoAttention or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ResNetConfig) and inputs.

  • loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) — Classification (or regression if config.num_labels==1) loss.
  • logits (torch.FloatTensor of shape (batch_size, config.num_labels)) — Classification (or regression if config.num_labels==1) scores (before SoftMax).
  • hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each stage) of shape (batch_size, num_channels, height, width). Hidden-states (also called feature maps) of the model at the output of each stage.

The ResNetForImageClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoImageProcessor, ResNetForImageClassification
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
>>> model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

>>> inputs = image_processor(image, return_tensors="pt")

>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = logits.argmax(-1).item()
>>> print(model.config.id2label[predicted_label])
tiger cat
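
Because the logits are scores before SoftMax, a common follow-up is to turn them into probabilities and list the top few classes. The sketch below does this in plain Python on a made-up three-class example. The helper names and the small id2label mapping are hypothetical; a real checkpoint has config.num_labels entries.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    max_logit = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - max_logit) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(id2label[i], probs[i]) for i in ranked[:k]]


# Hypothetical three-class mapping and logits for one image.
id2label = {0: "tabby", 1: "tiger cat", 2: "Egyptian cat"}
result = top_k([1.0, 3.0, 2.0], id2label, k=2)
print([label for label, _ in result])  # ['tiger cat', 'Egyptian cat']
```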
TensorFlow

TFResNetModel

class transformers.TFResNetModel

( config: ResNetConfig **kwargs )

Parameters

  • config (ResNetConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

The bare ResNet model outputting raw features without any specific head on top. This model is a TensorFlow tf.keras.layers.Layer sub-class. Use it as a regular TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and behavior.

call

( pixel_values: Tensor output_hidden_states: Optional = None return_dict: Optional = None training: bool = False ) → transformers.modeling_tf_outputs.TFBaseModelOutputWithPoolingAndNoAttention or tuple(tf.Tensor)

Parameters

  • pixel_values (tf.Tensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ConvNextImageProcessor.__call__() for details.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.

Returns

transformers.modeling_tf_outputs.TFBaseModelOutputWithPoolingAndNoAttention or tuple(tf.Tensor)

A transformers.modeling_tf_outputs.TFBaseModelOutputWithPoolingAndNoAttention or a tuple of tf.Tensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ResNetConfig) and inputs.

  • last_hidden_state (tf.Tensor of shape (batch_size, num_channels, height, width)) — Sequence of hidden-states at the output of the last layer of the model.

  • pooler_output (tf.Tensor of shape (batch_size, hidden_size)) — Last layer hidden-state after a pooling operation on the spatial dimensions.

  • hidden_states (tuple(tf.Tensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of tf.Tensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, num_channels, height, width).

    Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

The TFResNetModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoImageProcessor, TFResNetModel
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
>>> model = TFResNetModel.from_pretrained("microsoft/resnet-50")

>>> inputs = image_processor(image, return_tensors="tf")
>>> outputs = model(**inputs)

>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 2048, 7, 7]

TFResNetForImageClassification

class transformers.TFResNetForImageClassification

( config: ResNetConfig **kwargs )

Parameters

  • config (ResNetConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.

ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for ImageNet.

This model is a TensorFlow tf.keras.layers.Layer sub-class. Use it as a regular TensorFlow Module and refer to the TensorFlow documentation for all matter related to general usage and behavior.

call

( pixel_values: Tensor = None labels: Tensor = None output_hidden_states: bool = None return_dict: bool = None training: bool = False ) → transformers.modeling_tf_outputs.TFImageClassifierOutputWithNoAttention or tuple(tf.Tensor)

Parameters

  • pixel_values (tf.Tensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See ConvNextImageProcessor.__call__() for details.
  • output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
  • return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
  • labels (tf.Tensor of shape (batch_size,), optional) — Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels > 1 a classification loss is computed (Cross-Entropy).

Returns

transformers.modeling_tf_outputs.TFImageClassifierOutputWithNoAttention or tuple(tf.Tensor)

A transformers.modeling_tf_outputs.TFImageClassifierOutputWithNoAttention or a tuple of tf.Tensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ResNetConfig) and inputs.

  • loss (tf.Tensor of shape (1,), optional, returned when labels is provided) — Classification (or regression if config.num_labels==1) loss.
  • logits (tf.Tensor of shape (batch_size, config.num_labels)) — Classification (or regression if config.num_labels==1) scores (before SoftMax).
  • hidden_states (tuple(tf.Tensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of tf.Tensor (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each stage) of shape (batch_size, num_channels, height, width). Hidden-states (also called feature maps) of the model at the output of each stage.

The TFResNetForImageClassification forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoImageProcessor, TFResNetForImageClassification
>>> import tensorflow as tf
>>> from datasets import load_dataset

>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
>>> model = TFResNetForImageClassification.from_pretrained("microsoft/resnet-50")

>>> inputs = image_processor(image, return_tensors="tf")
>>> logits = model(**inputs).logits

>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_label = int(tf.math.argmax(logits, axis=-1))
>>> print(model.config.id2label[predicted_label])
tiger cat
JAX

FlaxResNetModel

class transformers.FlaxResNetModel

( config: ResNetConfig input_shape = (1, 224, 224, 3) seed: int = 0 dtype: dtype = <class 'jax.numpy.float32'> _do_init: bool = True **kwargs )

Parameters

  • config (ResNetConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
  • dtype (jax.numpy.dtype, optional, defaults to jax.numpy.float32) — The data type of the computation. Can be one of jax.numpy.float32, jax.numpy.float16 (on GPUs) and jax.numpy.bfloat16 (on TPUs).

    This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified all the computation will be performed with the given dtype.

    Note that this only specifies the dtype of the computation and does not influence the dtype of model parameters.

    If you wish to change the dtype of the model parameters, see to_fp16() and to_bf16().

The bare ResNet model outputting raw features without any specific head on top.

This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

This model is also a flax.linen.Module subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

  • Just-In-Time (JIT) compilation
  • Automatic Differentiation
  • Vectorization
  • Parallelization

__call__

( pixel_values params: dict = None train: bool = False output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPoolingAndNoAttention or tuple(torch.FloatTensor)

Returns

transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPoolingAndNoAttention or tuple(torch.FloatTensor)

A transformers.modeling_flax_outputs.FlaxBaseModelOutputWithPoolingAndNoAttention or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ResNetConfig) and inputs.

  • last_hidden_state (jnp.ndarray of shape (batch_size, num_channels, height, width)) — Sequence of hidden-states at the output of the last layer of the model.
  • pooler_output (jnp.ndarray of shape (batch_size, hidden_size)) — Last layer hidden-state after a pooling operation on the spatial dimensions.
  • hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape (batch_size, num_channels, height, width). Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.

The FlaxResNetPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Examples:

>>> from transformers import AutoImageProcessor, FlaxResNetModel
>>> from PIL import Image
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
>>> model = FlaxResNetModel.from_pretrained("microsoft/resnet-50")
>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state

FlaxResNetForImageClassification

class transformers.FlaxResNetForImageClassification

( config: ResNetConfig input_shape = (1, 224, 224, 3) seed: int = 0 dtype: dtype = <class 'jax.numpy.float32'> _do_init: bool = True **kwargs )

Parameters

  • config (ResNetConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
  • dtype (jax.numpy.dtype, optional, defaults to jax.numpy.float32) — The data type of the computation. Can be one of jax.numpy.float32, jax.numpy.float16 (on GPUs) and jax.numpy.bfloat16 (on TPUs).

    This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If specified all the computation will be performed with the given dtype.

    Note that this only specifies the dtype of the computation and does not influence the dtype of model parameters.

    If you wish to change the dtype of the model parameters, see to_fp16() and to_bf16().

ResNet Model with an image classification head on top (a linear layer on top of the pooled features), e.g. for ImageNet.

This model inherits from FlaxPreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its models (such as downloading, saving and converting weights from PyTorch models).

This model is also a flax.linen.Module subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to general usage and behavior.

Finally, this model supports inherent JAX features such as:

  • Just-In-Time (JIT) compilation
  • Automatic Differentiation
  • Vectorization
  • Parallelization

__call__

( pixel_values params: dict = None train: bool = False output_hidden_states: Optional = None return_dict: Optional = None ) → transformers.modeling_flax_outputs.FlaxImageClassifierOutputWithNoAttention or tuple(torch.FloatTensor)

Returns

transformers.modeling_flax_outputs.FlaxImageClassifierOutputWithNoAttention or tuple(torch.FloatTensor)

A transformers.modeling_flax_outputs.FlaxImageClassifierOutputWithNoAttention or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (ResNetConfig) and inputs.

  • logits (jnp.ndarray of shape (batch_size, config.num_labels)) — Classification (or regression if config.num_labels==1) scores (before SoftMax).
  • hidden_states (tuple(jnp.ndarray), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple of jnp.ndarray (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each stage) of shape (batch_size, num_channels, height, width). Hidden-states (also called feature maps) of the model at the output of each stage.

The FlaxResNetPreTrainedModel forward method overrides the __call__ special method.

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of calling this method directly, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.

Example:

>>> from transformers import AutoImageProcessor, FlaxResNetForImageClassification
>>> from PIL import Image
>>> import jax
>>> import requests

>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
>>> model = FlaxResNetForImageClassification.from_pretrained("microsoft/resnet-50")

>>> inputs = image_processor(images=image, return_tensors="np")
>>> outputs = model(**inputs)
>>> logits = outputs.logits

>>> # model predicts one of the 1000 ImageNet classes
>>> predicted_class_idx = jax.numpy.argmax(logits, axis=-1)
>>> print("Predicted class:", model.config.id2label[predicted_class_idx.item()])