---
license: other
license_name: apple-sample-code-license
license_link: LICENSE
library_name: ml-aim
pipeline_tag: image-classification
---

# AIM: Autoregressive Image Models

*Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar,
Joshua M Susskind, and Armand Joulin*


This software project accompanies the research paper, [Scalable Pre-training of Large Autoregressive Image Models](https://arxiv.org/abs/2401.08541).

We introduce **AIM**, a collection of vision models pre-trained with an autoregressive generative objective.
We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its
textual counterpart (i.e., Large Language Models). Specifically, we highlight two findings:
1. the model capacity can be trivially scaled to billions of parameters, and
2. AIM effectively leverages large collections of uncurated image data.

## Installation
Please install PyTorch using the official [installation instructions](https://pytorch.org/get-started/locally/).
Then install the package:
```commandline
pip install git+https://git@github.com/apple/ml-aim.git
```


## Usage
The example below loads the model from the [HuggingFace Hub](https://huggingface.co/docs/hub/) and runs it on a single image:
```python
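from PIL import Image

from aim.torch.models import AIMForImageClassification
from aim.torch.data import val_transforms

# Load an image; replace ... with the path to your image file.
img = Image.open(...)
# Download the pre-trained weights from the HuggingFace Hub.
model = AIMForImageClassification.from_pretrained("apple/aim-7B")
# Build the validation-time preprocessing pipeline.
transform = val_transforms()

# Preprocess the image and add a batch dimension.
inp = transform(img).unsqueeze(0)
# The forward pass returns the classification logits and the features.
logits, features = model(inp)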
```
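
To turn the `logits` from the example above into predictions, one option is a softmax followed by a top-k selection. The sketch below is not part of the `aim` package: `top_k_predictions` is a hypothetical helper, and it assumes `logits` has shape `(batch, num_classes)`. A dummy tensor with 1000 classes stands in for the model output so that the snippet runs on its own.

```python
import torch


def top_k_predictions(logits: torch.Tensor, k: int = 5):
    """Return (probabilities, class indices) of the k highest-scoring classes per image.

    Hypothetical helper; assumes `logits` has shape (batch, num_classes).
    """
    probs = logits.softmax(dim=-1)              # normalize scores per image
    top_probs, top_idx = probs.topk(k, dim=-1)  # sorted from most to least likely
    return top_probs, top_idx


# Stand-in for the model output: a batch of one image and 1000 classes (assumed).
dummy_logits = torch.zeros(1, 1000)
dummy_logits[0, 42] = 10.0  # make class 42 the clear winner

top_probs, top_idx = top_k_predictions(dummy_logits, k=5)
print(top_idx[0].tolist(), top_probs[0].tolist())
```

With the real model, pass the `logits` returned by `model(inp)` instead of `dummy_logits`. Mapping the class indices to label names requires a label list, which this example does not include.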

### ImageNet-1k results (frozen trunk)

The table below reports top-1 classification accuracy on the ImageNet-1k validation set.

<table style="margin: auto">
  <thead>
    <tr>
      <th rowspan="2">model</th>
      <th colspan="2">top-1 IN-1k</th>
    </tr>
    <tr>
      <th>last layer</th>
      <th>best layer</th>
    </tr>
  </thead>

  <tbody>
    <tr>
      <td>AIM-0.6B</td>
      <td>78.5%</td>
      <td>79.4%</td>
    </tr>
    <tr>
      <td>AIM-1B</td>
      <td>80.6%</td>
      <td>82.3%</td>
    </tr>
    <tr>
      <td>AIM-3B</td>
      <td>82.2%</td>
      <td>83.3%</td>
    </tr>
    <tr>
      <td>AIM-7B</td>
      <td>82.4%</td>
      <td>84.0%</td>
    </tr>
  </tbody>
</table>
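
As a quick reading aid, the gap between the two columns can be computed directly from the numbers in the table. The snippet below is only an illustration: the accuracies are copied from the table, and the variable names are arbitrary.

```python
# Top-1 accuracy (%) copied from the table above: (last layer, best layer).
results = {
    "AIM-0.6B": (78.5, 79.4),
    "AIM-1B": (80.6, 82.3),
    "AIM-3B": (82.2, 83.3),
    "AIM-7B": (82.4, 84.0),
}

# Gain in percentage points from using the best layer instead of the last one.
gains = {name: round(best - last, 1) for name, (last, best) in results.items()}

for name, gain in gains.items():
    print(f"{name}: +{gain:.1f} points (best layer vs. last layer)")
```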