
Model card for coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k

A timm-specific CoAtNet image classification model with an MLP Log-CPB (continuous log-coordinate relative position bias, motivated by Swin-V2). Pretrained in timm on ImageNet-12k (an 11,821-class subset of the full ImageNet-22k) and fine-tuned on ImageNet-1k by Ross Wightman.
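To make the Log-CPB idea concrete, here is a minimal pure-Python sketch assuming the Swin-V2 formulation: each relative offset d between two tokens in a window is mapped to sign(d) * log(1 + |d|), and a small MLP turns the resulting 2-D coordinate into one bias per attention head, which is then added to the attention logits. The window size, hidden width, and random weights are illustrative assumptions; timm's actual implementation differs in details such as coordinate normalisation.

```python
import math
import random


def log_relative_coords(window: int) -> list[tuple[float, float]]:
    """Log-spaced relative offsets for every (dy, dx) pair in a window x window grid.

    Offsets range over [-(window - 1), window - 1] on each axis; each one is
    mapped to sign(d) * log(1 + |d|) as in the Swin-V2 log-CPB formulation.
    """
    def log_space(d: int) -> float:
        return math.copysign(math.log1p(abs(d)), d)

    span = range(-(window - 1), window)
    return [(log_space(dy), log_space(dx)) for dy in span for dx in span]


def mlp_bias_table(coords, num_heads: int, hidden: int = 16, seed: int = 0):
    """Map each log-spaced offset to one bias per head with a tiny 2-layer ReLU MLP.

    The weights are random stand-ins (an assumption); in a real model they are learned.
    """
    rng = random.Random(seed)
    w1 = [[rng.gauss(0, 0.5) for _ in range(2)] for _ in range(hidden)]
    w2 = [[rng.gauss(0, 0.5) for _ in range(hidden)] for _ in range(num_heads)]
    table = []
    for cy, cx in coords:
        h = [max(0.0, a * cy + b * cx) for a, b in w1]  # hidden layer + ReLU
        table.append([sum(w * v for w, v in zip(row, h)) for row in w2])
    return table


coords = log_relative_coords(window=7)
bias = mlp_bias_table(coords, num_heads=4)
print(len(coords), len(bias[0]))  # 169 offsets, 4 heads
```

Because the MLP is evaluated on continuous coordinates rather than looked up in a fixed table, the same weights can produce biases for a different window size.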

ImageNet-12k training was performed on TPUs thanks to the support of the TRC program.

Fine-tuning was performed on 8x GPU Lambda Labs cloud instances.

Model Variants in maxxvit.py

MaxxViT covers a number of related model architectures that share a common structure including:

  • CoAtNet - Combining MBConv (depthwise-separable) convolutional blocks in early stages with self-attention transformer blocks in later stages.
  • MaxViT - Uniform blocks across all stages, each containing an MBConv (depthwise-separable) convolution block followed by two self-attention blocks with different partitioning schemes (window followed by grid).
  • CoAtNeXt - A timm specific arch that uses ConvNeXt blocks in place of MBConv blocks in CoAtNet. All normalization layers are LayerNorm (no BatchNorm).
  • MaxxViT - A timm specific arch that uses ConvNeXt blocks in place of MBConv blocks in MaxViT. All normalization layers are LayerNorm (no BatchNorm).
  • MaxxViT-V2 - A MaxxViT variation that removes the window block attention, leaving only ConvNeXt blocks and grid attention, with more width to compensate.
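The difference between MaxViT's window and grid partitioning can be shown on token coordinates alone. In this minimal sketch (attention itself is omitted, and the map size and partition sizes are arbitrary assumptions that must divide the map evenly), window partitioning groups contiguous p x p patches, while grid partitioning gathers a g x g set of tokens spread evenly across the whole map. Attention within a group is therefore local for windows and global but sparse for grids.

```python
from collections import defaultdict


def window_partition(height: int, width: int, p: int) -> dict:
    """Group token coordinates into non-overlapping p x p local windows."""
    groups = defaultdict(list)
    for h in range(height):
        for w in range(width):
            groups[(h // p, w // p)].append((h, w))
    return dict(groups)


def grid_partition(height: int, width: int, g: int) -> dict:
    """Group token coordinates into sparse g x g grids spanning the whole map.

    Tokens in one group are (height // g) rows and (width // g) columns apart,
    so every group covers the full image at a dilated stride.
    """
    stride_h, stride_w = height // g, width // g
    groups = defaultdict(list)
    for h in range(height):
        for w in range(width):
            groups[(h % stride_h, w % stride_w)].append((h, w))
    return dict(groups)


windows = window_partition(4, 4, p=2)
grids = grid_partition(4, 4, g=2)
print(windows[(0, 0)])  # [(0, 0), (0, 1), (1, 0), (1, 1)]
print(grids[(0, 0)])    # [(0, 0), (0, 2), (2, 0), (2, 2)]
```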

Aside from the major variants listed above, there are more subtle changes from model to model. Any model name containing the string "rw" is a timm-specific config with modelling adjustments made to favour PyTorch eager use. These were created while training initial reproductions of the models, so there are variations. All models with the string "tf" exactly match the TensorFlow-based models by the original paper authors, with weights ported to PyTorch. This covers a number of MaxViT models. The official CoAtNet models were never released.
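These naming conventions can be checked mechanically. The helper below is hypothetical (it is not part of timm): it splits a name at the "." into architecture and pretrained tag, and flags "rw" tokens (timm-specific configs, including variants such as "rw2") and "tf" tokens (ported TensorFlow weights).

```python
def describe_model_name(name: str) -> dict:
    """Hypothetical helper: split a timm model name and flag its naming conventions."""
    arch, _, tag = name.partition(".")
    tokens = arch.split("_")
    return {
        "architecture": arch,
        "pretrained_tag": tag or None,
        # 'rw', 'rw2', ... mark timm-specific configs
        "timm_specific": any(t.startswith("rw") for t in tokens),
        # 'tf' marks weights ported from the original TensorFlow models
        "tf_port": "tf" in tokens,
    }


print(describe_model_name("coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k"))
print(describe_model_name("maxvit_base_tf_512.in21k_ft_in1k"))
```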

Model Details

  • Model Type: Image classification / feature backbone
  • Model Stats:
    • Params (M): 41.7
    • GMACs: 8.1
    • Activations (M): 40.1
    • Image size: 224 x 224
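  • Papers:
    • MaxViT: Multi-Axis Vision Transformer: https://arxiv.org/abs/2204.01697
    • CoAtNet: Marrying Convolution and Attention for All Data Sizes: https://arxiv.org/abs/2106.04803
    • Swin Transformer V2: Scaling Up Capacity and Resolution: https://arxiv.org/abs/2111.09883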
  • Dataset: ImageNet-1k
  • Pretrain Dataset: ImageNet-12k
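Here is a quick back-of-envelope reading of these stats. It assumes float32 weights (4 bytes per parameter) and the common convention that one multiply-accumulate counts as two floating-point operations:

```python
params_m = 41.7  # Params (M) from Model Stats
gmacs = 8.1      # GMACs for one 224 x 224 image

# Assumption: float32 weights, 4 bytes each; millions of params x 4 bytes = megabytes (10^6 bytes).
weight_mb = params_m * 4
# Convention: 1 MAC = 1 multiply + 1 add = 2 FLOPs.
gflops = gmacs * 2

print(f"~{weight_mb:.1f} MB of float32 weights, ~{gflops:.1f} GFLOPs per image")
```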

Model Usage

Image Classification

from urllib.request import urlopen
from PIL import Image
import timm
import torch  # needed for torch.topk below

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model('coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k', pretrained=True)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1) * 100, k=5)
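The final line turns the logits into percentages and keeps the five largest. The sketch below performs the same computation for a single image in pure Python. The logits for a toy 6-class problem are made up, whereas the real model returns one logit per ImageNet-1k class.

```python
import math


def softmax_topk_percent(logits: list[float], k: int = 5) -> list[tuple[int, float]]:
    """Pure-Python analogue of torch.topk(output.softmax(dim=1) * 100, k) for one image.

    Returns (class_index, probability_percent) pairs, highest first.
    """
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [100.0 * e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]


# Hypothetical logits for a toy 6-class problem.
print(softmax_topk_percent([2.0, 0.5, 3.0, -1.0, 0.0, 1.0], k=3))
```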

Feature Map Extraction

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
    'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k',
    pretrained=True,
    features_only=True,
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # unsqueeze single image into batch of 1

for o in output:
    # print shape of each feature map in output
    # e.g.:
    #  torch.Size([1, 64, 112, 112])
    #  torch.Size([1, 96, 56, 56])
    #  torch.Size([1, 192, 28, 28])
    #  torch.Size([1, 384, 14, 14])
    #  torch.Size([1, 768, 7, 7])

    print(o.shape)
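Each feature map's spatial size is the input size divided by that stage's reduction (stride), here 2, 4, 8, 16 and 32 for a 224 x 224 input. The minimal sketch below reproduces the example shapes in the comments from the channel widths and strides listed there. On a features_only model, timm reports these values via model.feature_info.channels() and model.feature_info.reduction(), so in practice you would pass those in instead of hard-coding them.

```python
def expected_feature_shapes(img_size: int, channels, reductions, batch: int = 1):
    """Shape of each feature map: (batch, channels, img_size // r, img_size // r)."""
    return [(batch, c, img_size // r, img_size // r) for c, r in zip(channels, reductions)]


# Channel widths and strides read off the example output shapes.
channels = (64, 96, 192, 384, 768)
reductions = (2, 4, 8, 16, 32)
for shape in expected_feature_shapes(224, channels, reductions):
    print(shape)
```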

Image Embeddings

from urllib.request import urlopen
from PIL import Image
import timm

img = Image.open(urlopen(
    'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))

model = timm.create_model(
    'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k',
    pretrained=True,
    num_classes=0,  # remove classifier nn.Linear
)
model = model.eval()

# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)

output = model(transforms(img).unsqueeze(0))  # output is (batch_size, num_features) shaped tensor

# or equivalently (without needing to set num_classes=0)

output = model.forward_features(transforms(img).unsqueeze(0))
# output is unpooled, a (1, 768, 7, 7) shaped tensor

output = model.forward_head(output, pre_logits=True)
# output is a (1, num_features) shaped tensor
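A common use of these pooled embeddings is comparing images by cosine similarity. In this minimal pure-Python sketch, toy 4-dimensional vectors stand in for real embeddings such as output[0].tolist() from the num_classes=0 model:

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two embedding vectors of equal length."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Toy vectors standing in for two pooled image embeddings.
emb_a = [0.2, 0.1, 0.0, 0.7]
emb_b = [0.1, 0.1, 0.1, 0.6]
print(round(cosine_similarity(emb_a, emb_b), 3))
```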

Model Comparison

By Top-1

| model | top1 (%) | top5 (%) | samples / sec | Params (M) | GMAC | Act (M) |
|---|---|---|---|---|---|---|
| maxvit_xlarge_tf_512.in21k_ft_in1k | 88.53 | 98.64 | 21.76 | 475.77 | 534.14 | 1413.22 |
| maxvit_xlarge_tf_384.in21k_ft_in1k | 88.32 | 98.54 | 42.53 | 475.32 | 292.78 | 668.76 |
| maxvit_base_tf_512.in21k_ft_in1k | 88.20 | 98.53 | 50.87 | 119.88 | 138.02 | 703.99 |
| maxvit_large_tf_512.in21k_ft_in1k | 88.04 | 98.40 | 36.42 | 212.33 | 244.75 | 942.15 |
| maxvit_large_tf_384.in21k_ft_in1k | 87.98 | 98.56 | 71.75 | 212.03 | 132.55 | 445.84 |
| maxvit_base_tf_384.in21k_ft_in1k | 87.92 | 98.54 | 104.71 | 119.65 | 73.80 | 332.90 |
| maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k | 87.81 | 98.37 | 106.55 | 116.14 | 70.97 | 318.95 |
| maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k | 87.47 | 98.37 | 149.49 | 116.09 | 72.98 | 213.74 |
| coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k | 87.39 | 98.31 | 160.80 | 73.88 | 47.69 | 209.43 |
| maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k | 86.89 | 98.02 | 375.86 | 116.14 | 23.15 | 92.64 |
| maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k | 86.64 | 98.02 | 501.03 | 116.09 | 24.20 | 62.77 |
| maxvit_base_tf_512.in1k | 86.60 | 97.92 | 50.75 | 119.88 | 138.02 | 703.99 |
| coatnet_2_rw_224.sw_in12k_ft_in1k | 86.57 | 97.89 | 631.88 | 73.87 | 15.09 | 49.22 |
| maxvit_large_tf_512.in1k | 86.52 | 97.88 | 36.04 | 212.33 | 244.75 | 942.15 |
| coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k | 86.49 | 97.90 | 620.58 | 73.88 | 15.18 | 54.78 |
| maxvit_base_tf_384.in1k | 86.29 | 97.80 | 101.09 | 119.65 | 73.80 | 332.90 |
| maxvit_large_tf_384.in1k | 86.23 | 97.69 | 70.56 | 212.03 | 132.55 | 445.84 |
| maxvit_small_tf_512.in1k | 86.10 | 97.76 | 88.63 | 69.13 | 67.26 | 383.77 |
| maxvit_tiny_tf_512.in1k | 85.67 | 97.58 | 144.25 | 31.05 | 33.49 | 257.59 |
| maxvit_small_tf_384.in1k | 85.54 | 97.46 | 188.35 | 69.02 | 35.87 | 183.65 |
| maxvit_tiny_tf_384.in1k | 85.11 | 97.38 | 293.46 | 30.98 | 17.53 | 123.42 |
| maxvit_large_tf_224.in1k | 84.93 | 96.97 | 247.71 | 211.79 | 43.68 | 127.35 |
| coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k | 84.90 | 96.96 | 1025.45 | 41.72 | 8.11 | 40.13 |
| maxvit_base_tf_224.in1k | 84.85 | 96.99 | 358.25 | 119.47 | 24.04 | 95.01 |
| maxxvit_rmlp_small_rw_256.sw_in1k | 84.63 | 97.06 | 575.53 | 66.01 | 14.67 | 58.38 |
| coatnet_rmlp_2_rw_224.sw_in1k | 84.61 | 96.74 | 625.81 | 73.88 | 15.18 | 54.78 |
| maxvit_rmlp_small_rw_224.sw_in1k | 84.49 | 96.76 | 693.82 | 64.90 | 10.75 | 49.30 |
| maxvit_small_tf_224.in1k | 84.43 | 96.83 | 647.96 | 68.93 | 11.66 | 53.17 |
| maxvit_rmlp_tiny_rw_256.sw_in1k | 84.23 | 96.78 | 807.21 | 29.15 | 6.77 | 46.92 |
| coatnet_1_rw_224.sw_in1k | 83.62 | 96.38 | 989.59 | 41.72 | 8.04 | 34.60 |
| maxvit_tiny_rw_224.sw_in1k | 83.50 | 96.50 | 1100.53 | 29.06 | 5.11 | 33.11 |
| maxvit_tiny_tf_224.in1k | 83.41 | 96.59 | 1004.94 | 30.92 | 5.60 | 35.78 |
| coatnet_rmlp_1_rw_224.sw_in1k | 83.36 | 96.45 | 1093.03 | 41.69 | 7.85 | 35.47 |
| maxxvitv2_nano_rw_256.sw_in1k | 83.11 | 96.33 | 1276.88 | 23.70 | 6.26 | 23.05 |
| maxxvit_rmlp_nano_rw_256.sw_in1k | 83.03 | 96.34 | 1341.24 | 16.78 | 4.37 | 26.05 |
| maxvit_rmlp_nano_rw_256.sw_in1k | 82.96 | 96.26 | 1283.24 | 15.50 | 4.47 | 31.92 |
| maxvit_nano_rw_256.sw_in1k | 82.93 | 96.23 | 1218.17 | 15.45 | 4.46 | 30.28 |
| coatnet_bn_0_rw_224.sw_in1k | 82.39 | 96.19 | 1600.14 | 27.44 | 4.67 | 22.04 |
| coatnet_0_rw_224.sw_in1k | 82.39 | 95.84 | 1831.21 | 27.44 | 4.43 | 18.73 |
| coatnet_rmlp_nano_rw_224.sw_in1k | 82.05 | 95.87 | 2109.09 | 15.15 | 2.62 | 20.34 |
| coatnext_nano_rw_224.sw_in1k | 81.95 | 95.92 | 2525.52 | 14.70 | 2.47 | 12.80 |
| coatnet_nano_rw_224.sw_in1k | 81.70 | 95.64 | 2344.52 | 15.14 | 2.41 | 15.41 |
| maxvit_rmlp_pico_rw_256.sw_in1k | 80.53 | 95.21 | 1594.71 | 7.52 | 1.85 | 24.86 |
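Several architectures appear in the table twice: once with ImageNet-1k-only weights and once fine-tuned from a larger pretraining set (ImageNet-12k or ImageNet-21k). The sketch below pairs such rows by the part of the name before the "." and reports the top-1 difference. The rows are a hand-copied subset, and the gap should be read as indicative only, since other training details may differ between the two weight sets.

```python
# (model, top1) pairs copied from the By Top-1 table; a subset only.
rows = [
    ("maxvit_base_tf_512.in21k_ft_in1k", 88.20),
    ("maxvit_base_tf_512.in1k", 86.60),
    ("maxvit_large_tf_384.in21k_ft_in1k", 87.98),
    ("maxvit_large_tf_384.in1k", 86.23),
    ("coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k", 86.49),
    ("coatnet_rmlp_2_rw_224.sw_in1k", 84.61),
    ("coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k", 84.90),  # no in1k-only twin listed
]


def pretraining_gains(rows):
    """Top-1 gain of '_ft_' (pretrain then fine-tune) weights over in1k-only weights of the same architecture."""
    by_arch = {}
    for name, top1 in rows:
        arch, _, tag = name.partition(".")
        kind = "finetuned" if "_ft_" in tag else "in1k_only"
        by_arch.setdefault(arch, {})[kind] = top1
    return {
        arch: round(scores["finetuned"] - scores["in1k_only"], 2)
        for arch, scores in by_arch.items()
        if "finetuned" in scores and "in1k_only" in scores
    }


for arch, gain in pretraining_gains(rows).items():
    print(f"{arch}: +{gain:.2f} top-1")
```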

By Throughput (samples / sec)

| model | top1 (%) | top5 (%) | samples / sec | Params (M) | GMAC | Act (M) |
|---|---|---|---|---|---|---|
| coatnext_nano_rw_224.sw_in1k | 81.95 | 95.92 | 2525.52 | 14.70 | 2.47 | 12.80 |
| coatnet_nano_rw_224.sw_in1k | 81.70 | 95.64 | 2344.52 | 15.14 | 2.41 | 15.41 |
| coatnet_rmlp_nano_rw_224.sw_in1k | 82.05 | 95.87 | 2109.09 | 15.15 | 2.62 | 20.34 |
| coatnet_0_rw_224.sw_in1k | 82.39 | 95.84 | 1831.21 | 27.44 | 4.43 | 18.73 |
| coatnet_bn_0_rw_224.sw_in1k | 82.39 | 96.19 | 1600.14 | 27.44 | 4.67 | 22.04 |
| maxvit_rmlp_pico_rw_256.sw_in1k | 80.53 | 95.21 | 1594.71 | 7.52 | 1.85 | 24.86 |
| maxxvit_rmlp_nano_rw_256.sw_in1k | 83.03 | 96.34 | 1341.24 | 16.78 | 4.37 | 26.05 |
| maxvit_rmlp_nano_rw_256.sw_in1k | 82.96 | 96.26 | 1283.24 | 15.50 | 4.47 | 31.92 |
| maxxvitv2_nano_rw_256.sw_in1k | 83.11 | 96.33 | 1276.88 | 23.70 | 6.26 | 23.05 |
| maxvit_nano_rw_256.sw_in1k | 82.93 | 96.23 | 1218.17 | 15.45 | 4.46 | 30.28 |
| maxvit_tiny_rw_224.sw_in1k | 83.50 | 96.50 | 1100.53 | 29.06 | 5.11 | 33.11 |
| coatnet_rmlp_1_rw_224.sw_in1k | 83.36 | 96.45 | 1093.03 | 41.69 | 7.85 | 35.47 |
| coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k | 84.90 | 96.96 | 1025.45 | 41.72 | 8.11 | 40.13 |
| maxvit_tiny_tf_224.in1k | 83.41 | 96.59 | 1004.94 | 30.92 | 5.60 | 35.78 |
| coatnet_1_rw_224.sw_in1k | 83.62 | 96.38 | 989.59 | 41.72 | 8.04 | 34.60 |
| maxvit_rmlp_tiny_rw_256.sw_in1k | 84.23 | 96.78 | 807.21 | 29.15 | 6.77 | 46.92 |
| maxvit_rmlp_small_rw_224.sw_in1k | 84.49 | 96.76 | 693.82 | 64.90 | 10.75 | 49.30 |
| maxvit_small_tf_224.in1k | 84.43 | 96.83 | 647.96 | 68.93 | 11.66 | 53.17 |
| coatnet_2_rw_224.sw_in12k_ft_in1k | 86.57 | 97.89 | 631.88 | 73.87 | 15.09 | 49.22 |
| coatnet_rmlp_2_rw_224.sw_in1k | 84.61 | 96.74 | 625.81 | 73.88 | 15.18 | 54.78 |
| coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k | 86.49 | 97.90 | 620.58 | 73.88 | 15.18 | 54.78 |
| maxxvit_rmlp_small_rw_256.sw_in1k | 84.63 | 97.06 | 575.53 | 66.01 | 14.67 | 58.38 |
| maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k | 86.64 | 98.02 | 501.03 | 116.09 | 24.20 | 62.77 |
| maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k | 86.89 | 98.02 | 375.86 | 116.14 | 23.15 | 92.64 |
| maxvit_base_tf_224.in1k | 84.85 | 96.99 | 358.25 | 119.47 | 24.04 | 95.01 |
| maxvit_tiny_tf_384.in1k | 85.11 | 97.38 | 293.46 | 30.98 | 17.53 | 123.42 |
| maxvit_large_tf_224.in1k | 84.93 | 96.97 | 247.71 | 211.79 | 43.68 | 127.35 |
| maxvit_small_tf_384.in1k | 85.54 | 97.46 | 188.35 | 69.02 | 35.87 | 183.65 |
| coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k | 87.39 | 98.31 | 160.80 | 73.88 | 47.69 | 209.43 |
| maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k | 87.47 | 98.37 | 149.49 | 116.09 | 72.98 | 213.74 |
| maxvit_tiny_tf_512.in1k | 85.67 | 97.58 | 144.25 | 31.05 | 33.49 | 257.59 |
| maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k | 87.81 | 98.37 | 106.55 | 116.14 | 70.97 | 318.95 |
| maxvit_base_tf_384.in21k_ft_in1k | 87.92 | 98.54 | 104.71 | 119.65 | 73.80 | 332.90 |
| maxvit_base_tf_384.in1k | 86.29 | 97.80 | 101.09 | 119.65 | 73.80 | 332.90 |
| maxvit_small_tf_512.in1k | 86.10 | 97.76 | 88.63 | 69.13 | 67.26 | 383.77 |
| maxvit_large_tf_384.in21k_ft_in1k | 87.98 | 98.56 | 71.75 | 212.03 | 132.55 | 445.84 |
| maxvit_large_tf_384.in1k | 86.23 | 97.69 | 70.56 | 212.03 | 132.55 | 445.84 |
| maxvit_base_tf_512.in21k_ft_in1k | 88.20 | 98.53 | 50.87 | 119.88 | 138.02 | 703.99 |
| maxvit_base_tf_512.in1k | 86.60 | 97.92 | 50.75 | 119.88 | 138.02 | 703.99 |
| maxvit_xlarge_tf_384.in21k_ft_in1k | 88.32 | 98.54 | 42.53 | 475.32 | 292.78 | 668.76 |
| maxvit_large_tf_512.in21k_ft_in1k | 88.04 | 98.40 | 36.42 | 212.33 | 244.75 | 942.15 |
| maxvit_large_tf_512.in1k | 86.52 | 97.88 | 36.04 | 212.33 | 244.75 | 942.15 |
| maxvit_xlarge_tf_512.in21k_ft_in1k | 88.53 | 98.64 | 21.76 | 475.77 | 534.14 | 1413.22 |
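The two tables rank the same models by accuracy and by speed. A model is worth considering when no other model is both more accurate and faster. The sketch below applies that Pareto filter to a hand-copied subset of rows. The throughput numbers depend on hardware and batch settings that this card does not state, so the resulting frontier is specific to this benchmark.

```python
# (model, top1, samples/sec) triples copied from the tables; a subset only.
candidates = [
    ("coatnext_nano_rw_224.sw_in1k", 81.95, 2525.52),
    ("coatnet_nano_rw_224.sw_in1k", 81.70, 2344.52),
    ("maxvit_tiny_rw_224.sw_in1k", 83.50, 1100.53),
    ("coatnet_rmlp_1_rw_224.sw_in1k", 83.36, 1093.03),
    ("coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k", 84.90, 1025.45),
    ("maxvit_tiny_tf_224.in1k", 83.41, 1004.94),
]


def pareto_front(rows):
    """Names of models that no other model beats on both top-1 and throughput."""
    front = []
    for name, acc, speed in rows:
        dominated = any(
            a >= acc and s >= speed and (a > acc or s > speed)
            for n, a, s in rows
            if n != name
        )
        if not dominated:
            front.append(name)
    return front


print(pareto_front(candidates))
```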

Citation

@misc{rw2019timm,
  author = {Ross Wightman},
  title = {PyTorch Image Models},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  doi = {10.5281/zenodo.4414861},
  howpublished = {\url{https://github.com/huggingface/pytorch-image-models}}
}
@article{tu2022maxvit,
  title={MaxViT: Multi-Axis Vision Transformer},
  author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
  journal={ECCV},
  year={2022},
}
@article{dai2021coatnet,
  title={CoAtNet: Marrying Convolution and Attention for All Data Sizes},
  author={Dai, Zihang and Liu, Hanxiao and Le, Quoc V and Tan, Mingxing},
  journal={arXiv preprint arXiv:2106.04803},
  year={2021}
}