Community Computer Vision Course documentation

ResNet (Residual Network)

Neural networks with more layers were assumed to be more effective, on the expectation that each added layer would improve model performance.

As networks became deeper, the extracted features could be further enriched, as seen with VGG16 and VGG19.

A question arose: “Is learning better networks as easy as stacking more layers?” One obstacle to answering this question, the vanishing gradient problem, was largely addressed by normalized initialization and intermediate normalization layers.

However, a new issue emerged: the degradation problem. As networks became deeper, accuracy saturated and then degraded rapidly. An experiment comparing shallow and deep plain networks showed that the deeper models had higher training and test errors. Because the training error itself increased with depth, this degradation was not caused by overfitting. In principle, a deeper network should do no worse than a shallower one, since the added layers could simply learn the identity function. In practice, the added layers failed to approximate the identity function.

ResNet’s residual connections unlocked the potential of extreme depth, pushing accuracy beyond that of previous architectures.

ResNet Architecture

  • A residual block. Source: ResNet paper.

ResNet’s building blocks are designed so that they can easily represent the identity function: each block preserves its input while learning a correction on top of it. This approach eases weight optimization and counters degradation as the network becomes deeper.

The figure shows the building block of ResNet (source: ResNet paper).

Shortcut connections perform identity mapping, and their output is added to the output of the stacked layers. Identity shortcut connections add neither extra parameters nor computational complexity. They bypass layers, creating direct paths for information flow, and they let the stacked layers learn a residual function F(x) = H(x) - x rather than the desired mapping H(x) directly.

In short: ResNet = plain network + shortcut connections.
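The residual block described above can be sketched in a few lines. This is a minimal illustration, assuming PyTorch and a block whose input and output have the same number of channels; the class name `BasicResidualBlock` and the layer sizes are hypothetical and are not taken from the transformers implementation.

```python
import torch
from torch import nn


class BasicResidualBlock(nn.Module):
    """Two 3x3 conv layers (the residual function F) plus an identity shortcut."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def residual(self, x: torch.Tensor) -> torch.Tensor:
        # F(x): the stacked layers that learn the residual
        out = self.relu(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F(x) + x: the identity shortcut adds no parameters
        return self.relu(self.residual(x) + x)


block = BasicResidualBlock(channels=16).eval()
x = torch.randn(2, 16, 8, 8)
with torch.no_grad():
    y = block(x)
print(y.shape)  # same shape as x, as required for the addition
```

If the weights of the last convolution are all zero, F(x) is zero and the block reduces to ReLU(x), which is why an added residual block can fall back to (nearly) an identity mapping.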

For the addition F(x) + x, F(x) and x must have identical dimensions. When they differ, ResNet employs one of two techniques:

  • Zero-padding shortcuts that add channels with zero values, maintaining dimensions without introducing extra parameters to be learned.
  • Projection shortcuts that use 1x1 convolutions to adjust dimensions when necessary, involving some additional learnable parameters.
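The two shortcut options can be compared with a small sketch, assuming PyTorch. The helper name `zero_padding_shortcut`, the channel counts (64 to 128), and the stride-2 spatial subsampling are illustrative assumptions made here so that both shortcuts produce the same output shape.

```python
import torch
import torch.nn.functional as F
from torch import nn


def zero_padding_shortcut(x: torch.Tensor, out_channels: int, stride: int = 2) -> torch.Tensor:
    """Parameter-free shortcut: subsample spatially, then append all-zero channels."""
    x = x[:, :, ::stride, ::stride]
    extra = out_channels - x.shape[1]
    # F.pad pads from the last dimension backwards:
    # (W_left, W_right, H_top, H_bottom, C_front, C_back)
    return F.pad(x, (0, 0, 0, 0, 0, extra))


# Projection shortcut: a learnable 1x1 convolution that changes channels and resolution
projection_shortcut = nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False)

x = torch.randn(1, 64, 56, 56)
padded = zero_padding_shortcut(x, out_channels=128)
with torch.no_grad():
    projected = projection_shortcut(x)

projection_params = sum(p.numel() for p in projection_shortcut.parameters())
print(padded.shape, projected.shape, projection_params)
```

Both results have shape (1, 128, 28, 28), so either can be added to the output of the stacked layers. The zero-padding version has no weights, while this projection learns 64 x 128 weights.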

In deeper ResNet architectures like ResNet 50, 101, and 152, a specialized “bottleneck building block” is employed to manage parameter complexity and maintain efficiency while enabling even deeper learning.
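To illustrate how a bottleneck keeps the parameter count down, here is a rough sketch assuming PyTorch and the 1x1 reduce, 3x3, 1x1 expand layout from the ResNet paper. The class name `Bottleneck`, the helper `conv_weight_count`, and the widths (256 reduced to 64) are illustrative choices, and the comparison against two 3x3 convolutions at full width is only meant to show the order of magnitude.

```python
import torch
from torch import nn


class Bottleneck(nn.Module):
    """1x1 conv reduces channels, 3x3 conv works at the reduced width, 1x1 conv restores them."""

    def __init__(self, channels: int = 256, reduced: int = 64):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, reduced, kernel_size=1, bias=False),
            nn.BatchNorm2d(reduced),
            nn.ReLU(),
            nn.Conv2d(reduced, reduced, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(reduced),
            nn.ReLU(),
            nn.Conv2d(reduced, channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.relu(self.body(x) + x)


def conv_weight_count(module: nn.Module) -> int:
    """Count only convolution weights, ignoring batch-norm parameters."""
    return sum(m.weight.numel() for m in module.modules() if isinstance(m, nn.Conv2d))


bottleneck = Bottleneck().eval()
# Reference: two 3x3 convolutions applied directly at 256 channels
full_width_pair = nn.Sequential(
    nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
    nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
)

with torch.no_grad():
    out = bottleneck(torch.randn(1, 256, 14, 14))

print(out.shape)
print(conv_weight_count(bottleneck), conv_weight_count(full_width_pair))
```

With these widths the bottleneck uses 69,632 convolution weights against 1,179,648 for the full-width pair, roughly 17 times fewer, while the input and output shapes stay the same.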

ResNet Code

Deep Residual Networks Pre-trained on ImageNet

Below you can see how to load a pre-trained ResNet with an image classification head using the transformers library.

from transformers import ResNetForImageClassification

model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

model.eval()

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. Pixel values have to be scaled to the range [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].
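The normalization described above can be written out by hand. This is a minimal sketch assuming PyTorch and an image already loaded as a (3, H, W) uint8 tensor; the function name `normalize` and the all-white test image are illustrative. In practice a preprocessing class from the library would usually handle this step.

```python
import torch

# Per-channel statistics quoted in the text, shaped to broadcast over (3, H, W)
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def normalize(image_uint8: torch.Tensor) -> torch.Tensor:
    """Scale a (3, H, W) uint8 image to [0, 1], then standardize each channel."""
    scaled = image_uint8.float() / 255.0
    return (scaled - IMAGENET_MEAN) / IMAGENET_STD


# Illustrative input: an all-white 224x224 RGB image
image = torch.full((3, 224, 224), 255, dtype=torch.uint8)
batch = normalize(image).unsqueeze(0)  # add the mini-batch dimension
print(batch.shape)
```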

Here’s a sample execution. This example is available in the Hugging Face documentation.

from transformers import AutoImageProcessor, ResNetForImageClassification
import torch
from datasets import load_dataset

# Load a sample image from the Hugging Face Hub
dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]

# AutoImageProcessor replaces the deprecated AutoFeatureExtractor for vision models
image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")

# Preprocess the image and return PyTorch tensors
inputs = image_processor(image, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

# model predicts one of the 1000 ImageNet classes
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
