---
license: apache-2.0
tags:
- mlx
- mlx-image
- vision
- image-classification
datasets:
- imagenet-1k
library_name: mlx-image
---
# vit_large_patch14_518.dinov2

A [Vision Transformer](https://arxiv.org/abs/2010.11929v2) image classification model trained on the ImageNet-1k dataset with [DINOv2](https://arxiv.org/abs/2304.07193).

The model was trained in a self-supervised fashion on the ImageNet-1k dataset. Only the backbone was trained; no classification head was trained.

Disclaimer: This is a port of the PyTorch model weights to the Apple MLX framework.

<div align="center">
<img width="100%" alt="DINO illustration" src="dino.gif">
</div>


## How to use
```bash
pip install mlx-image
```

Here is how to use this model for image classification:

```python
import mlx.core as mx

from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

# Preprocess the image at the 518x518 input size this model expects
transform = ImageNetTransform(train=False, img_size=518)
x = transform(read_rgb("cat.png"))
x = mx.expand_dims(x, 0)  # add a batch dimension

model = create_model("vit_large_patch14_518.dinov2")
model.eval()

logits, attn_masks = model(x, attn_masks=True)
```
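
The `img_size=518` above matches the model name: with 14-pixel patches, a 518-pixel side splits into a whole number of patches. A minimal sketch of that arithmetic follows. `patch_grid` is a hypothetical helper written for illustration and is not part of mlx-image, and the single prepended class token is an assumption about the architecture.

```python
def patch_grid(img_size: int, patch_size: int = 14, prefix_tokens: int = 1) -> tuple[int, int]:
    """Return (patches per side, total sequence length) for a square image.

    Hypothetical helper for illustration; not part of mlx-image.
    prefix_tokens=1 assumes a single class token is prepended to the patches.
    """
    if img_size % patch_size != 0:
        raise ValueError(f"{img_size} is not a multiple of patch size {patch_size}")
    side = img_size // patch_size
    return side, side * side + prefix_tokens


side, seq_len = patch_grid(518)
print(side, seq_len)  # 37 1370
```

An input size that is not a multiple of 14 raises `ValueError` in this sketch; how the actual model handles such sizes is not covered here.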

You can also use the embeddings from the layer before the head:
```python
import mlx.core as mx

from mlxim.model import create_model
from mlxim.io import read_rgb
from mlxim.transform import ImageNetTransform

transform = ImageNetTransform(train=False, img_size=518)
x = transform(read_rgb("cat.png"))
x = mx.expand_dims(x, 0)

# First option: create the model with num_classes=0 and call it directly
model = create_model("vit_large_patch14_518.dinov2", num_classes=0)
model.eval()

embeds = model(x)

# Second option: create the default model and call get_features
model = create_model("vit_large_patch14_518.dinov2")
model.eval()

embeds, attn_masks = model.get_features(x)
```
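
Backbone embeddings are commonly compared with cosine similarity, for example to rank images by visual similarity. The sketch below assumes each embedding has already been converted to a flat list of floats; how to do that from `embeds` depends on its shape and is not shown. `cosine_similarity` is a hypothetical helper, not an mlx-image function.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two equal-length, non-zero vectors.

    Hypothetical helper for illustration; not part of mlx-image.
    """
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Toy 3-dimensional vectors stand in for real embeddings.
print(cosine_similarity([1.0, 0.0, 0.0], [1.0, 0.0, 0.0]))  # 1.0
```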

## Attention maps
You can visualize the attention maps using the `attn_masks` returned by the model. See the mlx-image [notebook](https://github.com/riccardomusmeci/mlx-image/blob/main/notebooks/dino_attention.ipynb) for an example.

<div align="center">
<img width="100%" alt="Attention Map" src="attention_maps.png">
</div>
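
To draw an attention map, the attention that the class token pays to each patch is typically laid back out on the patch grid. The sketch below shows only that reshaping step, in plain Python. It assumes a single attention row whose first entry belongs to the class token and whose remaining entries are patches in row-major order. The actual layout of `attn_masks` is not documented here, so check it against the notebook. `cls_attention_to_grid` is a hypothetical helper, not an mlx-image function.

```python
def cls_attention_to_grid(cls_row: list[float], grid_size: int) -> list[list[float]]:
    """Reshape one class-token attention row into a grid_size x grid_size map.

    Hypothetical helper for illustration; not part of mlx-image.
    Assumes cls_row[0] is the class token and the rest are row-major patches.
    """
    patch_scores = cls_row[1:]  # drop the class token's attention to itself
    if len(patch_scores) != grid_size * grid_size:
        raise ValueError(
            f"expected {grid_size * grid_size} patch scores, got {len(patch_scores)}"
        )
    return [
        patch_scores[r * grid_size:(r + 1) * grid_size]
        for r in range(grid_size)
    ]


# Toy 2x2 grid: one class-token entry followed by four patch scores.
grid = cls_attention_to_grid([0.5, 0.1, 0.2, 0.15, 0.05], grid_size=2)
print(grid)  # [[0.1, 0.2], [0.15, 0.05]]
```

For a 518-pixel input with 14-pixel patches, `grid_size` would be 37 under these assumptions.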