Image Classification
mlx-image
Safetensors
MLX
vision
riccardomusmeci commited on
Commit
34ba5ef
·
verified ·
1 Parent(s): afd9479

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +62 -1
README.md CHANGED
@@ -1,3 +1,64 @@
1
  ---
2
- license: mit
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ license: apache-2.0
3
+ tags:
4
+ - mlx
5
+ - mlx-image
6
+ - vision
7
+ - image-classification
8
+ datasets:
9
+ - imagenet-1k
10
+ library_name: mlx-image
11
+
12
  ---
13
+
14
+ # ResNet50
15
+
16
+ ResNet50 is a computer vision model trained on imagenet-1k. It was introduced in the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) and first released in [this repository](https://github.com/KaimingHe/deep-residual-networks).
17
+
18
+ Disclaimer: This is a porting of the torchvision model weights to Apple MLX Framework.
19
+
20
+
21
+ ## How to use
22
+ ```bash
23
+ pip install mlx-image
24
+ ```
25
+
26
+ Here is how to use this model for image classification:
27
+
28
+ ```python
29
+ from mlxim.model import create_model
30
+ from mlxim.io import read_rgb
31
+ from mlxim.transform import ImageNetTransform
32
+
33
+ transform = ImageNetTransform(train=False, img_size=224)
34
+ x = transform(read_rgb("cat.png"))
35
+ x = mx.expand_dims(x, 0)
36
+
37
+ model = create_model("resnet50")
38
+ model.eval()
39
+
40
+ logits = model(x)
41
+ ```
42
+
43
+ You can also use the embeds from last conv layer:
44
+ ```python
45
+ from mlxim.model import create_model
46
+ from mlxim.io import read_rgb
47
+ from mlxim.transform import ImageNetTransform
48
+
49
+ transform = ImageNetTransform(train=False, img_size=224)
50
+ x = transform(read_rgb("cat.png"))
51
+ x = mx.expand_dims(x, 0)
52
+
53
+ # first option
54
+ model = create_model("resnet50", num_classes=0)
55
+ model.eval()
56
+
57
+ embeds = model(x)
58
+
59
+ # second option
60
+ model = create_model("resnet50")
61
+ model.eval()
62
+
63
+ embeds = model.features(x)
64
+ ```