---
license: other
license_name: apple-sample-code-license
license_link: LICENSE
datasets:
- imagenet-1k
metrics:
- accuracy
library_name: mlx
pipeline_tag: image-classification
tags:
- large-scale-vision-models
- pytorch
- mlx
- jax
- vision
- ssl
- pre-training
- DFN
---

# AIM: Autoregressive Image Models

*Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar, Joshua M Susskind, and Armand Joulin*

This software project accompanies the research paper, *Scalable Pre-training of Large Autoregressive Image Models*.

We introduce **AIM**, a collection of vision models pre-trained with an autoregressive generative objective.
We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its textual counterpart (i.e., Large Language Models). Specifically, we highlight two findings:

1. the model capacity can be trivially scaled to billions of parameters, and
2. AIM effectively leverages large collections of uncurated image data.

## Installation

Please install PyTorch using the official [installation instructions](https://pytorch.org/get-started/locally/).
Afterward, install the package as:

```commandline
pip install git+https://git@github.com/apple/ml-aim.git
```

We also offer [MLX](https://github.com/ml-explore/mlx) backend support for research and experimentation on Apple silicon.
To enable MLX support, simply run:

```commandline
pip install mlx
```

## Usage

Below we provide an example of usage in [PyTorch](https://pytorch.org/):

```python
from PIL import Image

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()

# Preprocess and add a batch dimension.
inp = transform(img).unsqueeze(0)
# The second output is not used here.
logits, _ = model(inp)
```
and in MLX:

```python
from PIL import Image
import mlx.core as mx

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
# Convert the PyTorch tensor to an MLX array via NumPy.
inp = mx.array(inp.numpy())
logits, _ = model(inp)
```
and in JAX:

```python
from PIL import Image
import jax.numpy as jnp

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
# Convert the PyTorch tensor to a JAX array.
inp = jnp.array(inp)
# With `mutable`, Flax's `apply` returns (outputs, updated_variables).
(logits, _), _ = model.apply(params, inp, mutable=['batch_stats'])
```
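Whichever backend you use, the model returns `logits` for the input batch. The following is a minimal sketch of turning them into top-k class predictions. It assumes the logits have shape `(batch, num_classes)` and have already been converted to NumPy. For example, use `np.asarray(logits)` for MLX or JAX outputs, or `logits.detach().cpu().numpy()` for PyTorch. The helpers `softmax` and `topk_predictions` are hypothetical and not part of the `aim` package.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last (class) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def topk_predictions(logits: np.ndarray, k: int = 5):
    """Return (indices, probabilities) of the k highest-scoring classes per row.

    Hypothetical helper: assumes `logits` has shape (batch, num_classes).
    """
    probs = softmax(logits)
    # argsort is ascending: keep the last k columns, then reverse to descending.
    idx = np.argsort(probs, axis=-1)[:, -k:][:, ::-1]
    top_probs = np.take_along_axis(probs, idx, axis=-1)
    return idx, top_probs


# Dummy scores standing in for real model output (one image, five classes).
dummy = np.array([[0.1, 2.0, -1.0, 3.5, 0.0]])
idx, p = topk_predictions(dummy, k=2)
print(idx)  # [[3 1]]
```

With real AIM outputs, the returned indices are positions in the model's class list (ImageNet-1k for the heads in this card). Mapping them to human-readable labels requires a class-name list, which this sketch does not provide.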
## Pre-trained checkpoints

The pre-trained models can be accessed via [PyTorch Hub](https://pytorch.org/hub/) as:

```python
import torch

aim_600m = torch.hub.load("apple/ml-aim", "aim_600M")
aim_1b = torch.hub.load("apple/ml-aim", "aim_1B")
aim_3b = torch.hub.load("apple/ml-aim", "aim_3B")
aim_7b = torch.hub.load("apple/ml-aim", "aim_7B")
```

### Pre-trained backbones

The following table contains the pre-trained backbones used in our paper.
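
| model    | #params | IN-1k top-1, attn. probe (best layer) | backbone, SHA256 |
|----------|---------|---------------------------------------|------------------|
| AIM-0.6B | 0.6B    | 79.4%                                 | link, 0d6f6b8f   |
| AIM-1B   | 1B      | 82.3%                                 | link, d254ecd3   |
| AIM-3B   | 3B      | 83.3%                                 | link, 8475ce4e   |
| AIM-7B   | 7B      | 84.0%                                 | link, 184ed94c   |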
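
The SHA256 column appears to list the first 8 hex characters of each checkpoint's digest. The following is a minimal sketch for checking a downloaded file against that value. It assumes the listed value is a prefix of the full lowercase hex digest. The helper names and the `BACKBONE_SHA256_PREFIX` mapping are hypothetical, and the values are copied from the table. The same check can be applied to the attention-head checkpoints.

```python
import hashlib
import os
import tempfile

# SHA256 prefixes copied from the backbone table
# (assumed to be the first 8 hex characters of the full digest).
BACKBONE_SHA256_PREFIX = {
    "AIM-0.6B": "0d6f6b8f",
    "AIM-1B": "d254ecd3",
    "AIM-3B": "8475ce4e",
    "AIM-7B": "184ed94c",
}


def sha256_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks (1 MiB by default) so large checkpoints never sit fully in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_prefix(path: str, expected_prefix: str) -> bool:
    """True if the file's lowercase hex digest starts with the listed prefix."""
    return sha256_of_file(path).startswith(expected_prefix.lower())


# Demo on a throwaway file. For a real check, pass the downloaded checkpoint path,
# e.g. matches_prefix("/path/to/backbone_ckpt.pth", BACKBONE_SHA256_PREFIX["AIM-7B"]).
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.bin")
    with open(demo_path, "wb") as f:
        f.write(b"abc")
    demo_digest = sha256_of_file(demo_path)
    demo_ok = matches_prefix(demo_path, "ba7816bf")  # SHA256(b"abc") starts with ba7816bf
```

Streaming in fixed-size chunks keeps memory use constant regardless of checkpoint size, which matters for the multi-billion-parameter backbones.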

### Pre-trained attention heads

The table below contains the classification results on the ImageNet-1k validation set.
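
| model    | top-1 IN-1k (last layer) | top-1 IN-1k (best layer) | attention head, last layer (SHA256) | attention head, best layer (SHA256) |
|----------|--------------------------|--------------------------|-------------------------------------|-------------------------------------|
| AIM-0.6B | 78.5%                    | 79.4%                    | link, 5ce5a341                      | link, ebd45c05                      |
| AIM-1B   | 80.6%                    | 82.3%                    | link, db3be2ad                      | link, f1ed7852                      |
| AIM-3B   | 82.2%                    | 83.3%                    | link, 5c057b30                      | link, ad380e16                      |
| AIM-7B   | 82.4%                    | 84.0%                    | link, 1e5c99ba                      | link, 73ecd732                      |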
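
Each model has two attention heads, one per probe setting. A head checkpoint should therefore be paired with the setting it was trained for. The following is a small sketch of a lookup that keeps those pairs together and computes how much best-layer probing gains over last-layer probing. It assumes the "last layer" and "best layer" columns correspond to the `last` and `best` values of `--probe-layers` used by the evaluation script. The mapping and helper names are hypothetical, and the values are copied from the table.

```python
# (model, probe setting) -> (attention-head SHA256 prefix, IN-1k top-1 %),
# copied from the attention-head table.
ATTN_HEADS = {
    ("AIM-0.6B", "last"): ("5ce5a341", 78.5),
    ("AIM-0.6B", "best"): ("ebd45c05", 79.4),
    ("AIM-1B", "last"): ("db3be2ad", 80.6),
    ("AIM-1B", "best"): ("f1ed7852", 82.3),
    ("AIM-3B", "last"): ("5c057b30", 82.2),
    ("AIM-3B", "best"): ("ad380e16", 83.3),
    ("AIM-7B", "last"): ("1e5c99ba", 82.4),
    ("AIM-7B", "best"): ("73ecd732", 84.0),
}


def expected_head(model: str, probe_layers: str) -> tuple[str, float]:
    """Look up the head checkpoint prefix and reported top-1 for one setting."""
    if probe_layers not in ("last", "best"):
        raise ValueError(f"probe_layers must be 'last' or 'best', got {probe_layers!r}")
    return ATTN_HEADS[(model, probe_layers)]


def best_layer_gain(model: str) -> float:
    """Top-1 improvement (percentage points) of best-layer over last-layer probing."""
    return round(expected_head(model, "best")[1] - expected_head(model, "last")[1], 1)


gains = {m: best_layer_gain(m) for m in ("AIM-0.6B", "AIM-1B", "AIM-3B", "AIM-7B")}
# gains == {"AIM-0.6B": 0.9, "AIM-1B": 1.7, "AIM-3B": 1.1, "AIM-7B": 1.6}
```

The rounding to one decimal removes floating-point noise from the subtraction. It matches the one-decimal precision of the reported accuracies.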

## Reproducing the IN-1k classification results

The command below reproduces the [attention probe results](#pre-trained-attention-heads) on the ImageNet-1k validation set.
We run the evaluation using 1 node with 8 GPUs:

```commandline
torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
  --model=aim-7B \
  --batch-size=64 \
  --data-path=/path/to/imagenet \
  --probe-layers=last \
  --backbone-ckpt-path=/path/to/backbone_ckpt.pth \
  --head-ckpt-path=/path/to/head_ckpt.pth
```

By default, we probe the last 6 layers, as in the command above (`--probe-layers=last`). To probe the best layers instead, pass `--probe-layers=best` together with the matching best-layer attention head checkpoint.
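
To evaluate both probe settings without hand-editing flags, the same command can be built as an argument list, for example to pass to `subprocess.run`. The following is a small sketch of that approach. It only uses flags that appear in the command above, and `attnprobe_command` is a hypothetical helper name. The sketch assumes `--probe-layers` accepts only `last` and `best`, the two values this README mentions. All paths are placeholders.

```python
import shlex


def attnprobe_command(
    model: str,
    probe_layers: str,
    data_path: str,
    backbone_ckpt: str,
    head_ckpt: str,
    batch_size: int = 64,
    nproc_per_node: int = 8,
) -> list[str]:
    """Assemble the torchrun invocation from the README as an argument list.

    Hypothetical helper: it only reuses flags shown in the README command.
    """
    if probe_layers not in ("last", "best"):
        raise ValueError(f"probe_layers must be 'last' or 'best', got {probe_layers!r}")
    return [
        "torchrun", "--standalone", "--nnodes=1", f"--nproc-per-node={nproc_per_node}",
        "main_attnprobe.py",
        f"--model={model}",
        f"--batch-size={batch_size}",
        f"--data-path={data_path}",
        f"--probe-layers={probe_layers}",
        f"--backbone-ckpt-path={backbone_ckpt}",
        f"--head-ckpt-path={head_ckpt}",
    ]


cmd = attnprobe_command(
    model="aim-7B",
    probe_layers="best",
    data_path="/path/to/imagenet",
    backbone_ckpt="/path/to/backbone_ckpt.pth",
    head_ckpt="/path/to/best_layer_head_ckpt.pth",  # placeholder path
)
print(shlex.join(cmd))  # shell-quoted form, handy for logging
```

Keeping the command as a list avoids shell-quoting issues when paths contain spaces. `shlex.join` is only used to print a copy-pasteable version.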